BERT for Classification: Beyond the Next Sentence Prediction Task

Divyanshu Dimri
14 min read · Oct 17, 2023


Abstract

In this article, we look at BERT (Bidirectional Encoder Representations from Transformers), one of the best-known transformer-based deep learning models. BERT has been a game-changer in the NLP community since its debut in 2018. Here we focus on understanding it through the lens of fine-tuning for classification tasks. We will explore how to use what BERT learned during pre-training, particularly the [CLS] representation shaped by its Next Sentence Prediction (NSP) objective, for our own classification problem.

BERT Architectures

BERT is available in two sizes, BERT-BASE and BERT-LARGE. The first difference between them is the number of encoder layers, also called transformer blocks. BERT-BASE has 12 encoder layers, while BERT-LARGE has 24.

The two versions also differ in hidden size and in the number of attention heads. BERT-BASE uses a hidden size of 768 (the dimension of each token's representation) and 12 attention heads. BERT-LARGE uses a hidden size of 1024 and 16 attention heads. In both models, the feed-forward network inside each block expands to 4 times the hidden size (3072 and 4096 units, respectively) and then projects back down.

Both configurations are larger than the base model of the original Transformer paper ("Attention Is All You Need"), which used 6 encoder layers, a hidden size of 512 and 8 attention heads. This increase in scale, combined with large-scale pre-training, has been instrumental in BERT's success on complex NLP tasks.
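To make these sizes concrete, the sketch below counts the parameters of a BERT-style encoder from its configuration. It is a minimal sketch built on assumptions taken from the module printouts later in this post: a 30,522-token vocabulary, 512 positions, a feed-forward inner size of 4H, and (for BERT) 2 segment types plus a pooler. The function name `bert_encoder_params` is our own. The number of attention heads does not appear in the count, because the H-dimensional projections are simply split into heads (768 / 12 = 1024 / 16 = 64 dimensions per head).

```python
def bert_encoder_params(num_layers, hidden, vocab_size=30522, max_positions=512,
                        type_vocab_size=2, with_pooler=True):
    """Count the parameters of a BERT-style encoder (no task head).

    Assumes the layout of the Hugging Face BertModel/DistilBertModel printouts:
    token + position (+ segment) embeddings with a LayerNorm, then num_layers
    blocks of self-attention and a 4H feed-forward network, then an optional pooler.
    """
    ffn = 4 * hidden                                  # feed-forward inner size
    layer_norm = 2 * hidden                           # LayerNorm weight + bias
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + layer_norm
    attention = 4 * (hidden * hidden + hidden)        # Q, K, V and output projections
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    per_layer = attention + feed_forward + 2 * layer_norm
    pooler = hidden * hidden + hidden if with_pooler else 0   # dense layer on [CLS]
    return embeddings + num_layers * per_layer + pooler


print(bert_encoder_params(12, 768))    # BERT-BASE
print(bert_encoder_params(24, 1024))   # BERT-LARGE
# DistilBERT: 6 layers, no segment embeddings, no pooler
print(bert_encoder_params(6, 768, type_vocab_size=0, with_pooler=False))
```

Under these assumptions, BERT-BASE comes to about 109.5M parameters and BERT-LARGE to about 335.1M. These figures are consistent with the roughly 110M and 340M quoted in the BERT paper. The 6-layer variant without segment embeddings or pooler gives about 66.4M, the size usually quoted for DistilBERT.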

In this post we perform classification in two ways. The first uses BERT's pre-trained weights as a fixed feature extractor. The second fine-tunes BERT itself.

Approach

We compare these two methods of classification with BERT. The focus is on the practical differences in how each is implemented and how those differences affect the model's performance. The goal of this comparison is to give technical insight into effective strategies for using BERT in classification tasks.

BERT with pre-trained weights

Let's start with the BERT-BASE representation. For every input position, the model outputs a vector of size 768. The vector at the first position belongs to a special token, which we will explain below. That vector serves as a feature representation of the whole input sentence, and we use it to train our classifier.

  • First, we use the pre-trained model to extract features from the last hidden state (size 768). These features serve as the representation of our input text.

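For this we use DistilBERT, both to extract the pre-trained features and to tokenize the text. DistilBERT is a smaller version of BERT developed and open-sourced by the team at Hugging Face. It has 6 encoder layers instead of BERT-BASE's 12, and it is lighter and faster while roughly matching BERT's performance. One caveat is that DistilBERT was distilled from BERT without the NSP objective. Its [CLS] vector is therefore shaped by that distillation rather than by NSP training directly.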

  • We then use these feature representations to train a standard classifier, such as a random forest or logistic regression, depending on your preference.

So what is this special token? BERT uses two special tokens, [CLS] and [SEP]. For our classification we use the output embedding of the [CLS] token. This is a 768-dimensional vector that acts as a feature representation of the input text.

Why use [CLS] for sentence classification?

During pre-training, the output for [CLS] is what the NSP classifier reads. Through self-attention, the [CLS] position can attend to every other token in the input. NSP training therefore pushes it to act as a global summary of the input sequence.

Let's look at how the [CLS] token is trained via NSP and why it ends up containing a representation of the input sentence or text.

The following example shows how the [CLS] token is used in NSP:

Premise: I am Divyanshu.
Hypothesis: I work for a service based company as a Data-Scientist.

# In order to prepare this input for the NSP task, you would format it as follows:
Input Sequence: [CLS] I am Divyanshu. [SEP] I work for a service based company as a Data-Scientist. [SEP]
  • [CLS] (Classification Token): This token is added at the beginning of the input and marks the start of the sequence. As stated above, it is typically used to capture aggregate information about the entire input.
  • [SEP] (Separator Token): This token is used to separate different segments of text within the input. In the context of the NSP task, it indicates the boundary between the two sentences.

During NSP training, BERT encodes this input sequence. It passes the output for the [CLS] token through a small pooling layer and uses the result to compute a probability distribution over two labels, IsNext and NotNext. In this binary classification task, BERT predicts whether the second sentence actually follows the first. During pre-training, half of the pairs are true consecutive sentences, and in the other half the first sentence is paired with a random sentence from the corpus. This teaches BERT the contextual relationships between sentences and makes it sensitive to text coherence, which is useful in many NLP applications. In our case the application is text classification, and we reuse this already-trained [CLS] representation.
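The sentence-pair format and the 50/50 sampling of NSP pairs described above can be sketched in a few lines of plain Python. This is an illustration, not BERT's actual data pipeline. The function names are hypothetical, sentences are assumed to be already tokenized into lists of strings, and the random NotNext sentence is drawn from a caller-supplied list instead of from a different document in the pre-training corpus.

```python
import random


def build_nsp_input(tokens_a, tokens_b):
    """Pack two tokenized sentences into BERT's pair format with segment IDs."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


def sample_nsp_pair(document, i, other_sentences, rng):
    """Return (sentence_a, sentence_b, label) for sentence i of a document.

    With probability 0.5 sentence_b is the true next sentence (IsNext);
    otherwise it is a random sentence from other_sentences (NotNext).
    """
    sentence_a = document[i]
    if rng.random() < 0.5:
        return sentence_a, document[i + 1], "IsNext"
    return sentence_a, rng.choice(other_sentences), "NotNext"


tokens, segment_ids = build_nsp_input(["i", "am", "divyanshu", "."],
                                      ["i", "work", "as", "a", "data", "-", "scientist", "."])
print(tokens)
print(segment_ids)
```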

import transformers

# For DistilBERT:
model_class, tokenizer_class, pretrained_weights = (
    transformers.DistilBertModel,
    transformers.DistilBertTokenizer,
    'distilbert-base-uncased'
)

# Load pretrained model/tokenizer
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)

In my use case, I classify text into two categories, agri and non-agri. Because of confidentiality I will not share the dataset. For the reader's understanding, it simply contains agri and non-agri text samples.

Why use bert-base-uncased and not bert-base-cased?

The uncased checkpoints lowercase all input text and strip accent marks before tokenization, so the model does not distinguish between uppercase and lowercase words. They are commonly used when case sensitivity is not crucial, especially for English text. In our sentence classification task, word casing carries no useful signal, so we use the uncased version (here, distilbert-base-uncased). The cased checkpoints, such as bert-base-cased, keep the original casing. They are preferable when capitalization carries important information, for example in NER (Named Entity Recognition), where "Apple" the company and "apple" the fruit should be told apart.
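As a rough illustration, the normalization an uncased tokenizer applies before WordPiece splitting amounts to lowercasing followed by accent stripping. The sketch below approximates that step with Python's `unicodedata` module. The function name is hypothetical. The real tokenizer also performs other cleanup, such as whitespace and punctuation splitting, which is omitted here.

```python
import unicodedata


def uncased_normalize(text):
    """Approximate the casing/accent normalization of an *uncased* BERT tokenizer."""
    text = text.lower()
    # NFD splits accented characters into a base letter plus combining marks (category "Mn")
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")


print(uncased_normalize("Café Data-Scientist"))  # both spellings below map to the same tokens
print(uncased_normalize("APPLE") == uncased_normalize("apple"))
```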

Printing the loaded model (print(model)) shows DistilBERT's architecture: 6 TransformerBlock layers, each with multi-head self-attention and a 768 → 3072 → 768 feed-forward network.

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (activation): GELUActivation()
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
)

Before handing the text to BERT, we need to pre-process it into the format BERT requires. This takes three steps:

  • Tokenization

Tokenization breaks sentences into words and sub-words from BERT's vocabulary and adds the special tokens. See the diagram below; it is for explanation only and is unrelated to the dataset.

tokenized = productData["description"].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))
BERT Tokenization by Jalammar

In BERT's vocabulary, [CLS] has token ID 101 and [SEP] has token ID 102. The padding token [PAD] has ID 0, which is why we pad with zeros below.
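BERT's WordPiece tokenizer splits each word greedily. It repeatedly takes the longest prefix found in the vocabulary and marks word-internal pieces with "##". Below is a minimal sketch of that longest-match-first loop for a single word. The tiny vocabulary is made up for illustration; the real bert-base-uncased vocabulary has 30,522 entries. The real tokenizer also lowercases the text and splits it on whitespace and punctuation before this step.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single (already lowercased) word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate   # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]                 # no valid split: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"i", "am", "di", "##vya", "##nshu", "."}   # hypothetical toy vocabulary
print(wordpiece_tokenize("divyanshu", toy_vocab))        # ['di', '##vya', '##nshu']
```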

  • Padding
max_len = 0
for i in tokenized.values:
    if len(i) > max_len:
        max_len = len(i)

padded_tokens = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])

BERT uses absolute position embeddings. Each position index in the input (0, 1, 2, ... up to 511) has its own learned embedding, which is added to the embedding of the token at that position.

It is therefore advised to pad inputs on the right (end of the sequence) rather than on the left (beginning). During pre-training, real text always started at position 0. With left padding, the real tokens are shifted to later positions, so they receive position embeddings that differ from what the model saw during pre-training.
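The effect of the padding side on positions can be checked with a small sketch. After right padding, the real tokens keep positions 0 to n-1; after left padding, they are shifted. The helper names are hypothetical, and the token IDs are the first tokens of the example sequence in this post.

```python
def pad(ids, max_len, side="right", pad_id=0):
    """Pad a list of token IDs to max_len on the given side."""
    padding = [pad_id] * (max_len - len(ids))
    return ids + padding if side == "right" else padding + ids


def positions_of_real_tokens(padded, pad_id=0):
    """Position index (and hence position embedding) that each non-padding token receives."""
    return [pos for pos, tok in enumerate(padded) if tok != pad_id]


ids = [101, 1045, 2572, 102]   # [CLS] i am [SEP]
for side in ("right", "left"):
    padded = pad(ids, 6, side)
    print(side, padded, positions_of_real_tokens(padded))
```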

After padding, our data has shape (number of samples, maximum sequence length):

padded_tokens.shape
(1517, 155)

Continuing the example above, the padded token IDs look like this:

[ 101 1045 2572 4487 7054 14235 1012 1045 2147 2005 1037 2326 2241 2194 2004 1037 2951 1011 7155 1012 102 0 0 0 0 0 0 0]
  • Masking

When we feed padded sequences to BERT, the model must not treat the padding as real content. Padding tokens carry no meaning; they exist only so that all inputs have the same length. We therefore create an attention mask that tells BERT which positions to ignore, with 1 for real tokens and 0 for padding. Masked positions are excluded from self-attention, so the padding does not affect the representations of the real tokens.

attention_mask = np.where(padded_tokens != 0, 1, 0)
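The sketch below shows, for a single attention head, how such a mask enters self-attention. Positions with mask 0 get a large negative number added to their attention logits before the softmax, so they receive effectively zero weight. This is a simplified illustration of the idea used in BERT implementations. The -1e9 constant and the function name are our own choices.

```python
import numpy as np


def masked_attention_weights(scores, attention_mask):
    """Softmax attention weights that ignore padded key positions.

    scores: (seq_len, seq_len) raw attention logits (query x key).
    attention_mask: (seq_len,) array with 1 for real tokens and 0 for padding.
    """
    # Large negative bias on padded key positions, zero elsewhere
    bias = np.where(attention_mask[None, :] == 1, 0.0, -1e9)
    logits = scores + bias
    # Numerically stable softmax over the key dimension
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
weights = masked_attention_weights(rng.normal(size=(5, 5)), np.array([1, 1, 1, 0, 0]))
print(weights.round(3))   # the last two columns (padding) are 0
```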

We then feed the padded token IDs and the attention mask to the model to get the hidden states for all tokens. Of these, the [CLS] output is our focus:

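import torch

input_ids = torch.tensor(padded_tokens)
attention_mask = torch.tensor(attention_mask)

# Inference only: no gradients needed
with torch.no_grad():
    last_hidden_states = model(input_ids, attention_mask=attention_mask)

# Hidden state of the first token ([CLS]) for every sentence
features = last_hidden_states[0][:, 0, :].numpy()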

[CLS] is the first token of every sequence, so its final hidden state serves as the sentence embedding. It can be extracted with last_hidden_states[0][:, 0, :]. Here last_hidden_states[0] is the last hidden state, which has shape (1517, 155, 768) in our case, so the slice yields a (1517, 768) matrix. We save it in the features variable. These vectors become the input features for our logistic regression or random forest model, whichever we choose.
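The slicing itself is plain array indexing over a (batch, sequence, hidden) tensor. The NumPy sketch below uses a tiny toy array in place of the real (1517, 155, 768) output. The function name is hypothetical.

```python
import numpy as np


def cls_features(last_hidden_state):
    """Select the [CLS] vector (sequence position 0) for every example.

    last_hidden_state: (batch_size, seq_len, hidden_size) -> (batch_size, hidden_size)
    """
    return last_hidden_state[:, 0, :]


toy_hidden = np.arange(4 * 3 * 2).reshape(4, 3, 2)   # 4 sentences, 3 tokens, hidden size 2
print(cls_features(toy_hidden).shape)                 # (4, 2)
```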

Before moving forward, let me also define the pre-processing module, which is common to both model architectures.

Pre-Processing Module

import re
import string

import nltk
from nltk.corpus import stopwords, wordnet
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

# The NLTK stopwords, tokenizer, POS-tagger and WordNet data must be downloaded once with nltk.download()


class RemovePunctuation:
    """
    Class that returns the punctuation characters to remove, minus any characters we want to keep
    """

    def __init__(self):
        """
        :param empty: None
        """
        self.punctuation = string.punctuation

    def __call__(self, keep=""):
        """
        :param keep: punctuation characters to keep (e.g. '?'); "" keeps none
        :return: string.punctuation without the characters in keep
        """
        # Build a new string instead of mutating self.punctuation,
        # so repeated calls do not keep shrinking the punctuation set
        return self.punctuation.translate(str.maketrans('', '', keep))


# Accessing the remove_punctuation object
remove_punctuation = RemovePunctuation()


def get_wordnet_pos(tag):
    """Map a Penn Treebank tag from nltk.pos_tag to the WordNet POS the lemmatizer expects."""
    if tag.startswith('J'):
        return wordnet.ADJ
    elif tag.startswith('V'):
        return wordnet.VERB
    elif tag.startswith('N'):
        return wordnet.NOUN
    elif tag.startswith('R'):
        return wordnet.ADV
    else:
        return wordnet.NOUN  # Default to noun if the part of speech is not recognized


class ProcessText(object):

    @staticmethod
    def remove_punctuation_text(text):
        """Split text into word and punctuation runs, strip punctuation, and rejoin the remaining words."""
        res = re.findall(r'\w+|[^\s\w]+', text)
        name = []
        for word in res:
            clean_word = word.translate(str.maketrans('', '', remove_punctuation("")))
            if clean_word != "":
                name.append(clean_word)

        return " ".join(name)

    @staticmethod
    def remove_stopwords(text):
        stop_words = set(stopwords.words('english'))
        words = word_tokenize(text)
        filtered_words = [word for word in words if word.lower() not in stop_words]
        return ' '.join(filtered_words)

    @staticmethod
    def lower_casing(text):
        return text.lower()

    @staticmethod
    def lemmatize_text(text):
        lemmatizer = WordNetLemmatizer()
        words = word_tokenize(text)
        tagged_words = nltk.pos_tag(words)
        lemmatized_words = [lemmatizer.lemmatize(word, pos=get_wordnet_pos(tag)) for word, tag in tagged_words]
        return ' '.join(lemmatized_words)

    @staticmethod
    def remove_duplicates_and_sort(text):
        # Remove duplicate words, keeping each word's first occurrence in its original order
        # (dicts preserve insertion order, so dict.fromkeys keeps the first occurrence of each word)
        return ' '.join(dict.fromkeys(text.split()))

    def remove_numbers(self, text):
        # Use regex to remove all digit sequences
        return re.sub(r'\d+', '', text)

    def include_words_with_len_greater_than_2(self, text):
        # Keep only words longer than 2 characters
        return ' '.join(word for word in text.split() if len(word) > 2)

    def __call__(self, text):
        # Remove any punctuation
        text = self.remove_punctuation_text(text)

        # Convert text into lower case
        text = self.lower_casing(text)

        # Remove stopwords such as "is" and "the" that convey no meaning
        text = self.remove_stopwords(text)

        # Lemmatize: convert words to their base form, using their part of speech
        text = self.lemmatize_text(text)

        # Words are independent of one another in our problem, so duplicate words can be removed
        text = self.remove_duplicates_and_sort(text)

        # Drop numbers, then words of 2 characters or fewer
        cleaned_text = self.include_words_with_len_greater_than_2(self.remove_numbers(text))

        return cleaned_text

Why use POS for Lemmatization?

The same word can have different lemmas depending on its part of speech (noun, verb, adjective or adverb). WordNetLemmatizer treats every word as a noun unless told otherwise. As a result, words such as "flying" and "sitting" come back unchanged, because they are valid nouns; with the verb tag they are reduced to "fly" and "sit". Lemmatization depends on the grammatical role of a word within its sentence, so we POS-tag the text first and pass each tag to the lemmatizer.

Random Forest Classification

Let's say we use a random forest for our classification task.

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Split the dataset into training and testing sets (90% train, 10% test)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=42)

# Create a Random Forest Classifier
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42)

# Train the model on the training data
rf_clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = rf_clf.predict(X_test)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Generate a classification report
print("Classification Report:")
print(classification_report(y_test, y_pred))

# Generate a confusion matrix (stored as cm so the sklearn function is not shadowed)
cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(cm)

Below are sample results for classifying agri and non-agri data on the test set. They are quite good for a small sample size.

Classification Metrics

With logistic regression instead of the random forest, accuracy reached 96.23%. As we will see, fine-tuning the BERT model improves accuracy further.
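The post does not show the logistic-regression code. As a stand-in, here is a from-scratch NumPy sketch of binary logistic regression trained with full-batch gradient descent on a feature matrix such as our [CLS] features. It is not necessarily the implementation behind the 96.23% figure; a library such as scikit-learn would normally be used. The learning rate and epoch count are arbitrary assumptions, and real 768-dimensional features may need standardizing or tuning.

```python
import numpy as np


def train_logistic_regression(X, y, lr=0.1, epochs=500):
    """Fit binary logistic regression with full-batch gradient descent.

    X: (n_samples, n_features) feature matrix, e.g. the [CLS] features.
    y: (n_samples,) array of 0/1 labels.
    """
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    for _ in range(epochs):
        z = np.clip(X @ w + b, -30, 30)   # clip logits to avoid overflow in exp
        p = 1.0 / (1.0 + np.exp(-z))      # predicted probability of class 1
        error = p - y                     # gradient of the log-loss w.r.t. the logits
        w -= lr * (X.T @ error) / n_samples
        b -= lr * error.mean()
    return w, b


def predict_logistic_regression(X, w, b):
    # Class 1 whenever the logit is positive, i.e. probability > 0.5
    return (X @ w + b > 0).astype(int)
```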

Inference Pipeline

class Predict(object):
    def __call__(self, text):
        try:
            textPre_processing = ProcessText()
            processed_description = textPre_processing(text)

            # Encode with [CLS]/[SEP], truncating to the padded length used for training
            max_len = 155
            predToken = tokenizer.encode(processed_description, add_special_tokens=True,
                                         truncation=True, max_length=max_len)

            # Right-pad with the [PAD] ID (0) and build the matching attention mask
            padded_predToken = np.array([predToken + [0] * (max_len - len(predToken))])
            predAttention_mask = np.where(padded_predToken != 0, 1, 0)

            input_idsPred = torch.tensor(padded_predToken)
            attention_maskPred = torch.tensor(predAttention_mask)
            with torch.no_grad():
                last_hidden_statesPred = model(input_idsPred, attention_mask=attention_maskPred)

            # [CLS] feature vector -> random forest trained above
            featuresPred = last_hidden_statesPred[0][:, 0, :].numpy()
            predicted_label = rf_clf.predict(featuresPred)[0]

            return predicted_label

        except Exception as error:
            print("{}".format(str(error)))
            return -1

BERT with Fine-Tuning

The problem stays the same. This time, however, we do not use the last-layer representation as fixed features; instead we fine-tune the complete BERT model for our use case.

Here we reuse the outputs of the DistilBertTokenizer step above: the token IDs, attention masks and labels computed when tokenizing the sentences. The uncased DistilBERT and BERT checkpoints share the same WordPiece vocabulary, so these IDs are also valid input for bert-base-uncased. We save them in a bertInput.pickle file.

import pickle


def save_bertInputs(ids, attention_masks, labels):
    # Store token IDs, attention masks and labels in a dictionary
    data = {
        'input_ids': ids,
        'attention_masks': attention_masks,
        'labels': labels
    }

    # Save the data to a pickle file
    with open('bertInput.pickle', 'wb') as f:
        pickle.dump(data, f)


save_bertInputs(ids, attention_masks, label)

We then load these saved inputs and build data loaders for training and validation:

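import pickle

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Load the inputs saved above and make sure they are tensors
with open('bertInput.pickle', 'rb') as f:
    data = pickle.load(f)
input_ids = torch.as_tensor(data['input_ids'])
attention_masks = torch.as_tensor(data['attention_masks'])
labels = torch.as_tensor(data['labels'])

train_ratio = 0.8

batch_size = 16

# Random (not stratified) permutation of sample indices for the train/validation split
indices = torch.randperm(len(labels))

train_no = int(len(indices) * train_ratio)

train_idx = indices[:train_no]
val_idx = indices[train_no:]

# Train and validation sets; the tensor order must match the unpacking in the training loop
train_set = TensorDataset(
    input_ids[train_idx],
    attention_masks[train_idx],
    labels[train_idx]
)

val_set = TensorDataset(
    input_ids[val_idx],
    attention_masks[val_idx],
    labels[val_idx]
)

# Prepare DataLoaders: shuffled batches for training
train_dataloader = DataLoader(
    train_set,
    batch_sampler=BatchSampler(
        RandomSampler(train_set),
        batch_size=batch_size,
        drop_last=False
    )
)

# Validation order does not matter, so read it sequentially
validation_dataloader = DataLoader(
    val_set,
    batch_sampler=BatchSampler(
        SequentialSampler(val_set),
        batch_size=batch_size,
        drop_last=False
    )
)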

The drop_last argument of BatchSampler handles the case where the number of samples is not divisible by the batch size. For example, with 10 samples and a batch size of 3, 10 % 3 = 1, so the last batch contains only 1 sample.

With drop_last=True that incomplete batch is dropped; with drop_last=False it is kept.

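In PyTorch, batch_size and drop_last normally control the DataLoader's automatic batching for map-style datasets such as TensorDataset. Passing BatchSampler(RandomSampler(dataset), batch_size, drop_last) as batch_sampler, as we do here, produces the same kind of shuffled batches as DataLoader(dataset, batch_size=..., shuffle=True, drop_last=...).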
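The batching rule described above can be mimicked with a small pure-Python helper. `batch_indices` is a hypothetical name. It reproduces the batches of a BatchSampler over a SequentialSampler; a RandomSampler would only shuffle the indices first.

```python
def batch_indices(num_samples, batch_size, drop_last=False):
    """Split indices 0..num_samples-1 into consecutive batches, like
    BatchSampler(SequentialSampler(range(num_samples)), batch_size, drop_last)."""
    batches = [list(range(start, min(start + batch_size, num_samples)))
               for start in range(0, num_samples, batch_size)]
    # Drop the final incomplete batch only when asked to
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


print(batch_indices(10, 3))                  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(batch_indices(10, 3, drop_last=True))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

Assume the full 1,517-sample dataset is split 80/20 as above. The training set then has int(1517 * 0.8) = 1213 samples, which at batch_size=16 gives 75 full batches plus a final batch of 13. Setting drop_last=True would discard that last batch.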

Model Training

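For fine-tuning we use BertForSequenceClassification for both training and inference. It places a dropout layer and a linear classifier on top of BERT's pooled [CLS] output, produced by the BertPooler, which fed the NSP head during pre-training. The [CLS] representation that NSP trained is therefore adapted to our task, and all of BERT's weights are updated together with the new classifier.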
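Concretely, at inference time the model takes the final [CLS] hidden state and passes it through the pooler (a dense layer with tanh). It then applies dropout and a linear classifier, as the module printout below shows. The NumPy sketch mimics that forward path with random weights. The function and weight names are hypothetical. The weights use an (in_features, out_features) layout, whereas PyTorch's nn.Linear stores the transpose. Dropout is omitted because it is the identity in evaluation mode.

```python
import numpy as np


def sequence_classification_logits(last_hidden_state, W_pool, b_pool, W_cls, b_cls):
    """Mimic BertForSequenceClassification's head at inference time.

    last_hidden_state: (batch, seq_len, hidden); W_pool: (hidden, hidden);
    W_cls: (hidden, num_labels). Returns logits of shape (batch, num_labels).
    """
    cls_vectors = last_hidden_state[:, 0, :]          # final hidden state of [CLS]
    pooled = np.tanh(cls_vectors @ W_pool + b_pool)   # BertPooler: dense + tanh
    # Dropout is skipped: it does nothing in eval mode
    return pooled @ W_cls + b_cls                     # classifier: hidden -> num_labels


rng = np.random.default_rng(0)
hidden = rng.normal(size=(3, 5, 8))                   # 3 sequences, 5 tokens, hidden size 8
W_pool, b_pool = rng.normal(size=(8, 8)), np.zeros(8)
W_cls, b_cls = rng.normal(size=(8, 2)), np.zeros(2)
logits = sequence_classification_logits(hidden, W_pool, b_pool, W_cls, b_cls)
print(logits.argmax(axis=-1))                         # predicted class per sequence
```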

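import torch
from transformers import BertForSequenceClassification

# Load the pre-trained BERT model with a freshly initialized 2-class classification head
model_name = 'bert-base-uncased'
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=model_name,
    output_hidden_states=False,
    output_attentions=False,
    num_labels=2,
)

# Move the model to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    eps=1e-08,
    lr=5e-5,
)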
Printing the model shows BERT-BASE's 12 encoder layers, followed by the pooler, dropout and the new 2-way classifier:

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)
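import numpy as np
import torch
from tqdm import trange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epochs = 10
for _ in trange(epochs, desc='Epoch'):

    # ========== Training ==========

    # Set model to training mode (enables dropout)
    model.train()

    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for step, batch in enumerate(train_dataloader):
        # Batches are (input_ids, attention_mask, labels), matching the TensorDataset order
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        optimizer.zero_grad()

        # Forward pass: passing labels makes the model return the cross-entropy loss
        train_output = model(
            b_input_ids,
            token_type_ids=None,
            attention_mask=b_input_mask,
            labels=b_labels
        )

        # Backward pass
        train_output.loss.backward()
        optimizer.step()

        # Update tracking variables
        tr_loss += train_output.loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

    # ========== Validation ==========

    # Set model to evaluation mode (disables dropout)
    model.eval()

    all_preds, all_labels = [], []
    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            # Forward pass
            eval_output = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask)

        logits = eval_output.logits.detach().cpu().numpy()
        all_preds.append(np.argmax(logits, axis=1))
        all_labels.append(b_labels.to('cpu').numpy())

    # Validation metrics, treating label 1 as the positive class
    preds, label_ids = np.concatenate(all_preds), np.concatenate(all_labels)
    tp = np.sum((preds == 1) & (label_ids == 1))
    tn = np.sum((preds == 0) & (label_ids == 0))
    fp = np.sum((preds == 1) & (label_ids == 0))
    fn = np.sum((preds == 0) & (label_ids == 1))
    val_accuracy = (tp + tn) / len(label_ids)
    val_precision = tp / (tp + fp) if tp + fp else 0.0
    val_recall = tp / (tp + fn) if tp + fn else 0.0
    val_specificity = tn / (tn + fp) if tn + fp else 0.0
    print(f"Train loss: {tr_loss / nb_tr_steps:.4f} | Val accuracy: {val_accuracy:.4f} | "
          f"precision: {val_precision:.4f} | recall: {val_recall:.4f} | specificity: {val_specificity:.4f}")

# Save the fine-tuned weights once training is finished
torch.save(model.state_dict(), 'fine_tuned_bert.pth')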

We then load the saved weights and run inference as follows:

Inference

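import torch
from transformers import BertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model architecture (num_labels must match the fine-tuned classifier head)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Load the saved state dictionary
model.load_state_dict(torch.load('fine_tuned_bert.pth', map_location=device))

model.to(device)
model.eval()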

class Predict(object):
    def __call__(self, text):
        try:
            textPre_processing = ProcessText()
            processed_description = textPre_processing(text)

            # Encode with [CLS]/[SEP], truncating to the padded length used for training
            max_len = 155
            predToken = tokenizer.encode(processed_description, add_special_tokens=True,
                                         truncation=True, max_length=max_len)

            # Right-pad with the [PAD] ID (0) and build the matching attention mask
            padded_predToken = np.array([predToken + [0] * (max_len - len(predToken))])
            predAttention_mask = np.where(padded_predToken != 0, 1, 0)

            input_idsPred = torch.tensor(padded_predToken)
            attention_maskPred = torch.tensor(predAttention_mask)

            with torch.no_grad():
                output = model(input_idsPred.to(device), token_type_ids=None,
                               attention_mask=attention_maskPred.to(device))

            # The index of the larger logit is the predicted class (0 or 1)
            prediction = int(output.logits.argmax(dim=-1).item())

            return prediction

        except Exception as error:
            print("{}".format(str(error)))
            return -1

Below are sample results for classifying agri and non-agri data on the test set using the fine-tuned model.

Classification Metrics

Conclusion

In this article, we explored BERT through two different approaches to the same classification task. Fine-tuning the BERT model outperformed the conventional approach of pre-trained BERT features combined with logistic regression or a random forest.

The accuracy difference is modest given our limited number of data points. Even so, the fine-tuned model came out ahead, with higher accuracy and improved precision on our classification task. We attribute this to fine-tuning, which lets every layer adapt to the data instead of relying on fixed features.

This suggests that, with more data points, fine-tuned BERT could widen the gap further. Expanding the dataset is therefore a natural next step for improving its performance.

This blog aims to give readers a deeper understanding of BERT implementation and its different approaches to classification. Now that the model is ready, how do we deploy the model, the tokenizer and all the text pre-processing steps as a single unit that can be used just by calling the model? In the next post (Part II), we will use Hugging Face to deploy our custom pipeline.

Your comments and feedback are greatly appreciated for open communication and further insights. Feel free to reach out or connect with me via my LinkedIn page. Cheers!
