Loss Overview
Loss Table
Loss functions play a critical role in the performance of your fine-tuned model. Sadly, there is no “one size fits all” loss function. Ideally, this table should help narrow down your choice of loss function(s) by matching them to your data formats.
Note
You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, (sentence_A, sentence_B) pairs with class labels can be converted into (anchor, positive, negative) triplets by sampling sentences with the same or different classes.
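As a rough illustration of such a conversion, the sketch below assumes one possible reading of the data: each (sentence_A, sentence_B) pair carries a binary class label, where 1 means the pair belongs together and 0 means it does not. The function name `pairs_to_triplets` and the `positive_label` and `seed` parameters are hypothetical, and other labeling schemes would need a different grouping rule.

```python
import random
from collections import defaultdict


def pairs_to_triplets(pairs, labels, positive_label=1, seed=0):
    """Build (anchor, positive, negative) triplets from labeled sentence pairs.

    Assumption: label == positive_label marks sentence_B as a positive for
    sentence_A; any other label marks it as a negative.
    """
    rng = random.Random(seed)
    positives = defaultdict(list)
    negatives = defaultdict(list)
    for (sentence_a, sentence_b), label in zip(pairs, labels):
        bucket = positives if label == positive_label else negatives
        bucket[sentence_a].append(sentence_b)

    triplets = []
    for anchor, anchor_positives in positives.items():
        # An anchor without any negative partner cannot form a triplet.
        if anchor not in negatives:
            continue
        for positive in anchor_positives:
            # Sample one negative per (anchor, positive) pair.
            triplets.append((anchor, positive, rng.choice(negatives[anchor])))
    return triplets


pairs = [
    ("A man is eating.", "Someone is having a meal."),
    ("A man is eating.", "Nobody is eating."),
    ("A dog runs.", "An animal is moving."),
]
labels = [1, 0, 1]
triplets = pairs_to_triplets(pairs, labels)
```

Anchors that lack either a positive or a negative partner are dropped, so the resulting triplet dataset can be smaller than the original pair dataset.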
Loss modifiers
These loss functions can be seen as loss modifiers: they work on top of standard loss functions, but apply those loss functions in different ways to try to instil useful properties into the trained embedding model.
For example, models trained with MatryoshkaLoss produce embeddings whose size can be truncated without notable losses in performance, and models trained with AdaptiveLayerLoss still perform well when you remove model layers for faster inference.
| Texts | Labels | Appropriate Loss Functions |
|---|---|---|
| any | any | MatryoshkaLoss, AdaptiveLayerLoss, Matryoshka2dLoss |
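To make "truncating" an embedding concrete, here is a minimal sketch in plain Python. It keeps the first `dim` values and rescales the result to unit length. The re-normalization step is an assumption that suits cosine-similarity comparisons, and the helper name `truncate_embedding` and the example vector are made up for illustration.

```python
import math


def truncate_embedding(embedding, dim):
    """Keep the first `dim` values and re-normalize to unit length."""
    truncated = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in truncated))
    if norm == 0.0:
        return truncated  # nothing to rescale
    return [x / norm for x in truncated]


full = [3.0, 4.0, 12.0, 0.0]  # made-up 4-dimensional embedding
small = truncate_embedding(full, 2)
```

With a model that was not trained for this, dropping trailing dimensions usually costs noticeably more quality, which is the gap these loss modifiers aim to close.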
Distillation
These loss functions are specifically designed to be used when distilling the knowledge from one model into another, for example when fine-tuning a small model to behave more like a larger and stronger one, or when fine-tuning a model to become multilingual.
| Texts | Labels | Appropriate Loss Functions |
|---|---|---|
| sentence | model sentence embeddings | MSELoss |
| (sentence_1, sentence_2, ..., sentence_N) | model sentence embeddings | MSELoss |
| (query, passage_one, passage_two) | gold_sim(query, passage_one) - gold_sim(query, passage_two) | MarginMSELoss |
| (query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive) - gold_sim(query, negative_i) for i in 1..n] | MarginMSELoss |
| (query, positive, negative) | [gold_sim(query, positive), gold_sim(query, negative)] | DistillKLDivLoss, MarginMSELoss |
| (query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive), gold_sim(query, negative_i)...] | DistillKLDivLoss, MarginMSELoss |
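The two label formats in the table differ only in whether the teacher scores are subtracted up front. The sketch below builds both from the same scores, assuming `gold_sim` values have already been computed by some teacher model. The helper names `margin_labels` and `score_labels` and the example scores are hypothetical.

```python
def margin_labels(positive_score, negative_scores):
    """Margin format: gold_sim(query, positive) - gold_sim(query, negative_i)."""
    return [positive_score - negative_score for negative_score in negative_scores]


def score_labels(positive_score, negative_scores):
    """Raw-score format: [gold_sim(query, positive), gold_sim(query, negative_i)...]."""
    return [positive_score, *negative_scores]


# Made-up teacher scores for one query with one positive and two negatives.
teacher_positive = 8.0
teacher_negatives = [2.5, -1.0]

margins = margin_labels(teacher_positive, teacher_negatives)
scores = score_labels(teacher_positive, teacher_negatives)
```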
Commonly used Loss Functions
In practice, not all loss functions get used equally often. The most common scenarios are:
- (anchor, positive) pairs without any labels: MultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. CachedMultipleNegativesRankingLoss is often used to increase the batch size, resulting in superior performance.
- (sentence_A, sentence_B) pairs with a float similarity score: CosineSimilarityLoss is traditionally used a lot, though more recently CoSENTLoss and AnglELoss are used as drop-in replacements with superior performance.
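To show what "in-batch negatives" means, the following sketch computes an InfoNCE-style loss in plain Python. For each anchor, its own positive is the correct "class" and every other positive in the batch acts as a negative. This is an illustrative re-derivation, not the library's implementation. The function names, the toy 2-dimensional embeddings, and the `scale` value of 20.0 are assumptions.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def in_batch_negatives_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where row i's correct column is positive i."""
    losses = []
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, positive) for positive in positives]
        # Numerically stable log-sum-exp over all candidates in the batch.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(
            sum(math.exp(logit - max_logit) for logit in logits)
        )
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]  # each positive is close to its anchor
swapped = [aligned[1], aligned[0]]  # positives paired with the wrong anchor

loss_aligned = in_batch_negatives_loss(anchors, aligned)
loss_swapped = in_batch_negatives_loss(anchors, swapped)
```

Every extra pair in the batch adds one more negative per anchor, which is one way to see why larger batch sizes tend to help with this kind of loss.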
Custom Loss Functions
Advanced users can create and train with their own loss functions. Custom loss functions only have a few requirements:
- They must be a subclass of torch.nn.Module.
- They must have model as the first argument in the constructor.
- They must implement a forward method that accepts sentence_features and labels. The former is a list of tokenized batches, one element for each column. These tokenized batches can be fed directly to the model being trained to produce embeddings. The latter is an optional tensor of labels. The method must return a single loss value or a dictionary of loss components (component names to loss values) that will be summed to produce the final loss value. When returning a dictionary, the individual components will be logged separately in addition to the summed loss, allowing you to monitor the individual components of the loss.
To get full support with the automatic model card generation, you may also wish to implement:
- a get_config_dict method that returns a dictionary of loss parameters.
- a citation property so your work gets cited in all models that train with the loss.
Consider inspecting existing loss functions to get a feel for how loss functions are commonly implemented.
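The requirements above can be sketched as a small class. This is a hypothetical example, not a loss shipped with the library: the name `CosineL1Loss`, its `scale` parameter, and the placeholder citation are invented for illustration. It also assumes two text columns, float similarity labels, and that calling the model on a tokenized batch returns a dictionary containing a "sentence_embedding" tensor.

```python
import torch
from torch import nn


class CosineL1Loss(nn.Module):
    """Hypothetical loss: L1 distance between cosine similarity and the label."""

    def __init__(self, model, scale=1.0):
        # `model` comes first in the constructor, as required.
        super().__init__()
        self.model = model
        self.scale = scale

    def forward(self, sentence_features, labels=None):
        # One tokenized batch per text column; embed each with the model.
        # Assumption: the model output is a dict with "sentence_embedding".
        embeddings = [
            self.model(features)["sentence_embedding"]
            for features in sentence_features
        ]
        similarity = torch.cosine_similarity(embeddings[0], embeddings[1])
        # Return a single scalar loss value.
        return self.scale * (similarity - labels.float()).abs().mean()

    def get_config_dict(self):
        # Loss parameters to surface in the generated model card.
        return {"scale": self.scale}

    @property
    def citation(self):
        # Placeholder only; a real loss would return its own BibTeX entry.
        return "@misc{placeholder, title={Replace with your reference}}"
```

Returning a dictionary such as `{"l1": loss_value}` from `forward` instead of a scalar would, per the description above, have each component logged separately and summed into the final loss.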