Cross-entropy loss is a widely used loss function for training classification models, particularly for tasks where the model predicts probabilities for multiple classes, such as language modeling. In the context of transformers, cross-entropy loss measures how well the model’s predicted probability distribution over possible next tokens matches the actual target (ground truth) token.
What Cross-entropy Loss Does
Cross-entropy loss quantifies the difference between:
- The true distribution (the target, or correct answer). In language modeling this is typically a one-hot vector that assigns probability 1 to the correct token and 0 to every other token.
- The model’s predicted distribution (its output probabilities), i.e. the probability the model assigns to each possible next token in the vocabulary.
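To make the comparison between the two distributions concrete, here is a minimal sketch that builds both for a hypothetical four-token vocabulary (the tokens and probabilities are made up for illustration) and evaluates the general cross-entropy formula, the negative sum over tokens of the true probability times the log of the predicted probability. Because the target is one-hot, every term except the correct token's drops out.

```python
import math

# Hypothetical toy vocabulary and model output, for illustration only.
vocab = ["cat", "dog", "mat", "sat"]
predicted = [0.7, 0.1, 0.1, 0.1]  # model's predicted distribution; sums to 1
target_index = vocab.index("cat")

# True distribution: one-hot vector with 1 at the correct token, 0 elsewhere.
true_dist = [1.0 if i == target_index else 0.0 for i in range(len(vocab))]

def cross_entropy(true_dist, predicted):
    """General definition: H(y, p) = -sum_i y_i * log(p_i).

    Terms with y_i == 0 contribute nothing, so they are skipped
    (this also avoids evaluating log(0) if some p_i were zero).
    """
    return -sum(y * math.log(p) for y, p in zip(true_dist, predicted) if y > 0)

full = cross_entropy(true_dist, predicted)
# With a one-hot target, only the correct token's term survives.
shortcut = -math.log(predicted[target_index])
print(f"full sum: {full:.4f}, -log p(correct): {shortcut:.4f}")
```

Both values agree, which is why the loss for language modeling is usually written using only the probability of the correct token.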
How Cross-entropy Loss Works
- Calculate Probabilities: The model produces a raw score (logit) for every token in the vocabulary, and a softmax turns these scores into a probability distribution representing how likely each token is to be the correct next token.
- Assign Loss Based on Correct Token: Cross-entropy loss penalizes the model based on how much probability it assigned to the correct token. The higher the probability for the correct token, the lower the loss.
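The two steps above can be sketched in a few lines of plain Python. This is an illustrative sketch, not any particular framework's implementation: the logit values are hypothetical, and the softmax subtracts the largest logit first, a standard trick for numerical stability that does not change the resulting probabilities.

```python
import math

def softmax(logits):
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def token_loss(logits, target_index):
    """Cross-entropy for one next-token prediction: -log p(correct token)."""
    probs = softmax(logits)
    return -math.log(probs[target_index])

# Hypothetical logits over a 4-token vocabulary; index 0 is the correct token.
confident = [4.0, 1.0, 0.5, 0.0]  # strongly favors the correct token
unsure = [1.0, 1.0, 0.5, 0.0]     # spreads probability more evenly

print(round(token_loss(confident, 0), 4), round(token_loss(unsure, 0), 4))
```

The more probability the softmax places on the correct token, the smaller the loss, matching the second step.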
Mathematically, the cross-entropy loss for a single token prediction is:
$$ L = -\log(p_{\text{correct}}) $$
where $p_{\text{correct}}$ is the probability assigned by the model to the correct token.
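When the probabilities come from a softmax, the same loss can be written directly in terms of the model's logits. Writing $z_j$ for the logit of vocabulary token $j$ and $V$ for the vocabulary size, substituting the softmax into the single-token loss gives the "log-sum-exp" form, which is the shape deep learning frameworks typically evaluate because it is numerically stable:

```latex
p_{\text{correct}} = \frac{e^{z_{\text{correct}}}}{\sum_{j=1}^{V} e^{z_j}}
\quad\Longrightarrow\quad
L = -\log(p_{\text{correct}}) = -z_{\text{correct}} + \log \sum_{j=1}^{V} e^{z_j}
```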
Why Cross-entropy Loss is Effective for Language Models
- Encourages Confidence in Correct Prediction: By minimizing cross-entropy loss, the model learns to assign higher probabilities to the correct tokens.
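- Provides Smooth Gradient Signals: When cross-entropy is combined with a softmax output, the gradient of the loss with respect to each logit is simply the predicted probability minus the target value ($p_j - y_j$). Every token in the vocabulary therefore receives a training signal: the correct token's logit is pushed up, and each incorrect token's logit is pushed down in proportion to the probability the model gave it.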
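The gradient property described above can be checked numerically. The sketch below (hypothetical logits, plain Python rather than any framework) computes the analytic softmax-plus-cross-entropy gradient, predicted probability minus one-hot target, and compares it against central finite differences of the loss.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def loss(logits, target):
    # Log-sum-exp form of cross-entropy: -z_target + log(sum_j exp(z_j)).
    m = max(logits)
    return -logits[target] + m + math.log(sum(math.exp(z - m) for z in logits))

def analytic_grad(logits, target):
    # d loss / d z_j = p_j - y_j, where y is the one-hot target.
    probs = softmax(logits)
    return [p - (1.0 if j == target else 0.0) for j, p in enumerate(probs)]

def numeric_grad(logits, target, eps=1e-6):
    # Central finite differences, used only to check the analytic formula.
    grads = []
    for j in range(len(logits)):
        up = list(logits)
        up[j] += eps
        down = list(logits)
        down[j] -= eps
        grads.append((loss(up, target) - loss(down, target)) / (2 * eps))
    return grads

logits = [2.0, 1.0, 0.1, -1.0]  # hypothetical scores; token 0 is correct
target = 0
print([round(g, 4) for g in analytic_grad(logits, target)])
print([round(g, 4) for g in numeric_grad(logits, target)])
```

The correct token's gradient component is negative (gradient descent raises its logit) and every other component is positive (their logits are lowered), so all tokens are adjusted at once.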
Example
If the target token is “cat” and the model assigns probability 0.7 to “cat” (with the remaining 0.3 spread over the other words in the vocabulary), the cross-entropy loss, using the natural logarithm, is:
$$ L = -\log(0.7) \approx 0.357 $$
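If the model instead assigned a higher probability to “cat” (say 0.9), the loss would drop to $-\log(0.9) \approx 0.105$, reflecting a prediction that is closer to the correct answer. Only the probability on the correct token enters the loss; how the remaining probability is divided among the other words does not change its value.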
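The example's arithmetic can be reproduced directly. This short sketch uses a hypothetical helper, `cross_entropy_for_target`, that takes the probability placed on the correct token and uses the natural logarithm, as deep learning frameworks do.

```python
import math

def cross_entropy_for_target(p_target):
    """Single-token cross-entropy, given the probability of the correct token."""
    return -math.log(p_target)  # natural log

# Probabilities assigned to the correct token "cat" in the example.
losses = {p: cross_entropy_for_target(p) for p in (0.7, 0.9)}
for p, value in losses.items():
    print(f"p(cat) = {p}: loss = {value:.3f}")
```

Running it prints a loss of 0.357 for p(cat) = 0.7 and 0.105 for p(cat) = 0.9.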