Training Overview

Why Finetune?

Finetuning Sparse Encoder models often heavily improves the performance of the model on your use case, because each task requires a different notion of similarity. For example, given news articles:

  • “Apple launches the new iPad”

  • “NVIDIA is gearing up for the next GPU generation”

Then for the following use cases, we may have different notions of similarity:

  • a model for classification of news articles as Economy, Sports, Technology, Politics, etc., should produce similar embeddings for these texts.

  • a model for semantic textual similarity should produce dissimilar embeddings for these texts, as they have different meanings.

  • a model for semantic search would not need a notion of similarity between two documents, as it should only compare queries and documents.

Also see Training Examples for numerous training scripts for common real-world applications that you can adopt.

Training Components

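Training Sparse Encoder models involves between 4 and 6 components: a Model, a Dataset, a Loss Function, Training Arguments (optional), an Evaluator (optional), and a Trainer. Each is described below.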

Model

Sparse Encoder models consist of a sequence of Modules, Sparse Encoder specific Modules or Custom Modules, allowing for a lot of flexibility. If you want to further finetune a SparseEncoder model (i.e. it has a modules.json file), then you don’t have to worry about which modules are used:

from sentence_transformers import SparseEncoder

model = SparseEncoder("naver/splade-cocondenser-ensembledistil")

But if instead you want to train from another checkpoint, or from scratch, then these are the most common architectures you can use:

Splade models use the MLMTransformer followed by a SpladePooling module. The former loads a pretrained Masked Language Modeling transformer model (e.g. BERT, RoBERTa, DistilBERT, ModernBERT, etc.) and the latter pools the output of the MLMHead to produce a single sparse embedding of the size of the vocabulary.

from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling

# Initialize MLM Transformer (use a fill-mask model)
mlm_transformer = MLMTransformer("google-bert/bert-base-uncased")

# Initialize SpladePooling module
splade_pooling = SpladePooling(pooling_strategy="max")

# Create the Splade model
model = SparseEncoder(modules=[mlm_transformer, splade_pooling])

This architecture is the default if you provide a fill-mask model architecture to SparseEncoder, so it’s easier to use the shortcut:

from sentence_transformers import SparseEncoder

model = SparseEncoder("google-bert/bert-base-uncased")
# SparseEncoder(
#   (0): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
#   (1): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
# )
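
To make the pooling step more concrete, here is a minimal pure-Python sketch of max pooling over MLM logits. It assumes the log-saturated ReLU activation described in the SPLADE papers; the toy vocabulary, the logits, and the helper name splade_max_pool are illustrative only and are not part of the library.

```python
import math


def splade_max_pool(token_logits):
    """Pool per-token MLM logits into one vocabulary-sized vector.

    token_logits: list of rows, one per input token, each of length vocab_size.
    Assumed formulation: weight_j = max_i log(1 + relu(logit_ij)).
    """
    vocab_size = len(token_logits[0])
    pooled = [0.0] * vocab_size
    for row in token_logits:
        for j, logit in enumerate(row):
            activated = math.log1p(max(logit, 0.0))  # log(1 + relu(x))
            pooled[j] = max(pooled[j], activated)
    return pooled


# Hypothetical input: 2 tokens scored against a 4-word toy vocabulary
logits = [
    [2.0, -1.0, 0.0, -3.0],
    [0.5, -2.0, 1.0, -0.5],
]
embedding = splade_max_pool(logits)

# Only vocabulary entries with a positive logit for at least one token stay active
active = [j for j, weight in enumerate(embedding) if weight > 0]
print(active)
```

Because negative logits are clipped to zero before pooling, most vocabulary dimensions end up at exactly zero, which is what makes the resulting embedding sparse.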

Inference-free Splade uses a Router module with different modules for queries and documents. Usually for this type of architecture, the document part is a traditional Splade architecture (an MLMTransformer followed by a SpladePooling module) and the query part is a SparseStaticEmbedding module, which just returns a pre-computed score for every token in the query.

from sentence_transformers import SparseEncoder
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling

# Initialize MLM Transformer for document encoding
doc_encoder = MLMTransformer("google-bert/bert-base-uncased")

# Create a router model with different paths for queries and documents
router = Router.for_query_document(
    query_modules=[SparseStaticEmbedding(tokenizer=doc_encoder.tokenizer, frozen=False)],
    # Document path: full MLM transformer + pooling
    document_modules=[doc_encoder, SpladePooling("max")],
)

# Create the inference-free model
model = SparseEncoder(modules=[router], similarity_fn_name="dot")
# SparseEncoder(
#   (0): Router(
#     (query_0_SparseStaticEmbedding): SparseStaticEmbedding({'frozen': False}, dim:30522, tokenizer: BertTokenizerFast)
#     (document_0_MLMTransformer): MLMTransformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertForMaskedLM'})
#     (document_1_SpladePooling): SpladePooling({'pooling_strategy': 'max', 'activation_function': 'relu', 'word_embedding_dimension': None})
#   )
# )

This architecture allows for fast query-time processing using the lightweight SparseStaticEmbedding module, which is trainable and can be seen as a set of linear weights, while documents are processed with the full MLM transformer and SpladePooling.

Tip

Inference-free Splade is particularly useful for search applications where query latency is critical, as it shifts the computational complexity to the document indexing phase which can be done offline.
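
As a rough illustration of why the query side is cheap, the sketch below builds a query vector with a plain table lookup and scores it against document vectors with a sparse dot product. The weight table, token ids, and document vectors are made up, and treating repeated query tokens as a single occurrence is an assumption of this sketch rather than a statement about the library.

```python
def encode_query(token_ids, static_weights):
    """Build a sparse query vector from a per-token weight table (no transformer pass)."""
    return {token_id: static_weights[token_id] for token_id in set(token_ids)}


def sparse_dot(sparse_a, sparse_b):
    """Dot product of two sparse vectors stored as {dimension: weight} dicts."""
    if len(sparse_a) > len(sparse_b):
        sparse_a, sparse_b = sparse_b, sparse_a  # iterate over the smaller vector
    return sum(weight * sparse_b.get(dim, 0.0) for dim, weight in sparse_a.items())


# Hypothetical learned weights for a 6-token vocabulary
static_weights = [0.0, 1.5, 0.2, 2.0, 0.7, 0.1]

# Document vectors would be computed offline by the document path of the Router
doc_vectors = {
    "doc_a": {1: 0.8, 3: 1.0},
    "doc_b": {2: 1.2, 5: 0.4},
}

query = encode_query([1, 3, 3], static_weights)
scores = {name: sparse_dot(query, vector) for name, vector in doc_vectors.items()}
best = max(scores, key=scores.get)
print(best, scores)
```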

Note

When training models with the Router module, you must use the router_mapping argument in the SparseEncoderTrainingArguments to map the training dataset columns to the correct route (“query” or “document”). For example, if your dataset(s) have ["question", "answer"] columns, then you can use the following mapping:

args = SparseEncoderTrainingArguments(
    ...,
    router_mapping={
        "question": "query",
        "answer": "document",
    }
)

Additionally, it is recommended to use a much higher learning rate for the SparseStaticEmbedding module than for the rest of the model. For this, you should use the learning_rate_mapping argument in the SparseEncoderTrainingArguments to map parameter patterns to their learning rates. For example, if you want to use a learning rate of 1e-3 for the SparseStaticEmbedding module and 2e-5 for the rest of the model, you can do this:

args = SparseEncoderTrainingArguments(
    ...,
    learning_rate=2e-5,
    learning_rate_mapping={
        r"SparseStaticEmbedding\.*": 1e-3,
    }
)
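
The sketch below shows one way such a pattern-to-learning-rate mapping could be resolved against parameter names. It assumes that patterns are matched with re.search and that the first matching pattern wins; the parameter names and the helper assign_learning_rates are hypothetical and only serve to illustrate the idea.

```python
import re


def assign_learning_rates(param_names, default_lr, lr_mapping):
    """Return {param_name: lr}, using the first pattern that matches the name."""
    assigned = {}
    for name in param_names:
        lr = default_lr
        for pattern, mapped_lr in lr_mapping.items():
            if re.search(pattern, name):
                lr = mapped_lr
                break
        assigned[name] = lr
    return assigned


# Hypothetical parameter names; real names depend on the model's module layout
names = [
    "0.query_0_SparseStaticEmbedding.weight",
    "0.document_0_MLMTransformer.auto_model.embeddings.word_embeddings.weight",
]
rates = assign_learning_rates(names, 2e-5, {r"SparseStaticEmbedding\.*": 1e-3})
for name, lr in rates.items():
    print(f"{lr:g}  {name}")
```

Parameters whose names match no pattern fall back to the regular learning_rate.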

Contrastive Sparse Representation (CSR) models apply a SparseAutoEncoder module on top of a dense Sentence Transformer model, which usually consists of a Transformer followed by a Pooling module. You can initialize one from scratch like so:

from sentence_transformers import models, SparseEncoder
from sentence_transformers.sparse_encoder.models import SparseAutoEncoder

# Initialize transformer (can be any dense encoder model)
transformer = models.Transformer("google-bert/bert-base-uncased")

# Initialize pooling
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")

# Initialize SparseAutoEncoder module
sae = SparseAutoEncoder(
    input_dim=transformer.get_word_embedding_dimension(),
    hidden_dim=4 * transformer.get_word_embedding_dimension(),
    k=256,  # Number of top values to keep
    k_aux=512,  # Number of top values for auxiliary loss
)

# Create the CSR model
model = SparseEncoder(modules=[transformer, pooling, sae])

Or if your base model is 1) a dense Sentence Transformer model or 2) a non-MLM Transformer model (MLM models are loaded as Splade models by default), then this shortcut will automatically initialize the CSR model for you:

from sentence_transformers import SparseEncoder

model = SparseEncoder("mixedbread-ai/mxbai-embed-large-v1")
# SparseEncoder(
#   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertModel'})
#   (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
#   (2): SparseAutoEncoder({'input_dim': 1024, 'hidden_dim': 4096, 'k': 256, 'k_aux': 512, 'normalize': False, 'dead_threshold': 30})
# )

Warning

Unlike (Inference-free) Splade models, sparse embeddings produced by CSR models don’t have the same size as the vocabulary of the base model. This means you can’t directly interpret which words are activated in your embedding like you can with Splade models, where each dimension corresponds to a specific token in the vocabulary.

Beyond that, CSR models are most effective on dense encoder models that use high-dimensional representations (e.g. 1024-4096 dimensions).
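
The role of the k parameter can be illustrated with a small top-k sparsification sketch. This is a simplification under stated assumptions: the real SparseAutoEncoder also has learned encoder and decoder weights and an auxiliary loss (k_aux), which are omitted here, and the helper top_k_sparsify is hypothetical.

```python
def top_k_sparsify(activations, k):
    """Keep the k largest activations and set all others to zero."""
    if k >= len(activations):
        return list(activations)
    # Indices of the k largest values
    ranked = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)
    keep = set(ranked[:k])
    return [value if i in keep else 0.0 for i, value in enumerate(activations)]


# Hypothetical hidden activations of size 8, keeping the top 3
hidden = [0.1, 2.3, 0.0, 1.7, 0.4, 0.9, 3.1, 0.2]
sparse = top_k_sparsify(hidden, k=3)
print(sparse)
```

With hidden_dim=4096 and k=256, at most 256 of the 4096 dimensions would be non-zero for each embedding.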

Dataset

The SparseEncoderTrainer trains and evaluates using datasets.Dataset (one dataset) or datasets.DatasetDict instances (multiple datasets, see also Multi-dataset training).

If you want to load data from the Hugging Face Hub, then you should use datasets.load_dataset():

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 557850
})
"""

Some datasets (including sentence-transformers/all-nli) require you to provide a “subset” alongside the dataset name. sentence-transformers/all-nli has 4 subsets, each with different data formats: pair, pair-class, pair-score, triplet.

Note

Many Hugging Face datasets that work out of the box with Sentence Transformers have been tagged with sentence-transformers, allowing you to easily find them by browsing to https://huggingface.co/datasets?other=sentence-transformers. We strongly recommend that you browse these datasets to find training datasets that might be useful for your tasks.

If you have local data in common file-formats, then you can load this data easily using datasets.load_dataset():

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")

or:

from datasets import load_dataset

dataset = load_dataset("json", data_files="my_file.json")

If you have local data that requires some extra pre-processing, my recommendation is to initialize your dataset using datasets.Dataset.from_dict() and a dictionary of lists, like so:

from datasets import Dataset

anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

Each key from the dictionary will become a column in the resulting dataset.
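
As a slightly fuller sketch of the preprocessing step, the example below parses tab-separated text, drops incomplete rows, and fills the two lists. The in-memory text stands in for a local file, and the column names question and answer are assumptions made for this example.

```python
import csv
import io

# Stand-in for a local file; in practice you would use open("my_file.tsv", newline="")
raw = io.StringIO(
    "question\tanswer\n"
    "What is SPLADE?\tA sparse neural retrieval model.\n"
    "  \tAn answer without a question.\n"
    "What is a sparse embedding?\tA vector with mostly zero values.\n"
)

anchors = []
positives = []
for row in csv.DictReader(raw, delimiter="\t"):
    question = row["question"].strip()
    answer = row["answer"].strip()
    if not question or not answer:
        continue  # filter out incomplete rows
    anchors.append(question)
    positives.append(answer)

columns = {"anchor": anchors, "positive": positives}
print(len(anchors), "pairs kept")
# dataset = Dataset.from_dict(columns)  # requires `from datasets import Dataset`
```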

Dataset Format

It is important that your dataset format matches your loss function (or that you choose a loss function that matches your dataset format). Verifying whether a dataset format works with a loss function involves two steps:

  1. If your loss function requires a Label according to the Loss Overview table, then your dataset must have a column named “label” or “score”. This column is automatically taken as the label.

  2. All columns not named “label” or “score” are considered Inputs according to the Loss Overview table. The number of remaining columns must match the number of valid inputs for your chosen loss. The names of these columns are irrelevant; only the order matters.

For example, given a dataset with columns ["text1", "text2", "label"] where the “label” column contains a float similarity score between 0 and 1, we can use it with SparseCoSENTLoss, SparseAnglELoss, and SparseCosineSimilarityLoss because it:

  1. has a “label” column as is required for these loss functions.

  2. has 2 non-label columns, exactly the number required by these loss functions.

Be sure to re-order your dataset columns with Dataset.select_columns if your columns are not ordered correctly. For example, if your dataset has ["good_answer", "bad_answer", "question"] as columns, then this dataset can technically be used with a loss that requires (anchor, positive, negative) triplets, but the good_answer column will be taken as the anchor, bad_answer as the positive, and question as the negative.

Additionally, if your dataset has extraneous columns (e.g. sample_id, metadata, source, type), you should remove these with Dataset.remove_columns as they will be used as inputs otherwise. You can also use Dataset.select_columns to keep only the desired columns.
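
The two verification steps can be sketched as a small check on the column names. The helpers split_columns and matches_loss are hypothetical and not part of the library; the expected number of inputs and whether a label is needed would come from the Loss Overview table.

```python
LABEL_COLUMNS = {"label", "score"}


def split_columns(column_names):
    """Split dataset columns into ordered input columns and an optional label column."""
    inputs = [name for name in column_names if name not in LABEL_COLUMNS]
    labels = [name for name in column_names if name in LABEL_COLUMNS]
    return inputs, (labels[0] if labels else None)


def matches_loss(column_names, num_inputs, needs_label):
    """Check a column layout against a loss's expected input count and label need."""
    inputs, label = split_columns(column_names)
    has_required_label = label is not None or not needs_label
    return len(inputs) == num_inputs and has_required_label


# A pair dataset with a score label, for a loss that takes 2 inputs and a label
print(matches_loss(["text1", "text2", "label"], num_inputs=2, needs_label=True))

# An extraneous column is counted as an input, so this layout does not fit a 2-input loss
print(matches_loss(["sample_id", "anchor", "positive"], num_inputs=2, needs_label=False))

# The count can be right while the order is wrong; reorder before training
right_order = ["question", "good_answer", "bad_answer"]
# dataset = dataset.select_columns(right_order)
```

Note that a check like this only catches count and label mismatches. A wrongly ordered triplet dataset passes it, so the column order still has to be verified by hand.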

Loss Function

Loss functions quantify how well a model performs for a given batch of data, allowing an optimizer to update the model weights to produce more favourable (i.e., lower) loss values. This is the core of the training process.

Sadly, there is no single loss function that works best for all use-cases. Instead, which loss function to use greatly depends on your available data and on your target task. See Dataset Format to learn what datasets are valid for which loss functions. Additionally, the Loss Overview will be your best friend to learn about the options.

Warning

To train a SparseEncoder, you need either SpladeLoss or CSRLoss, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is SparseMSELoss, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher’s sparse embedding.
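
To give an intuition for what such a wrapper computes, here is a minimal sketch that adds a weighted sparsity penalty for queries and documents to a main loss value. It assumes the FLOPS regularizer from the SPLADE papers as the penalty; the regularizer actually used by the wrapper may differ, and the helper names and toy embeddings are hypothetical.

```python
def flops_regularizer(batch_embeddings):
    """FLOPS-style penalty: sum over dimensions of the squared mean absolute activation."""
    batch_size = len(batch_embeddings)
    num_dims = len(batch_embeddings[0])
    total = 0.0
    for j in range(num_dims):
        mean_activation = sum(abs(row[j]) for row in batch_embeddings) / batch_size
        total += mean_activation ** 2
    return total


def combined_loss(main_loss, query_embs, doc_embs, query_weight, doc_weight):
    """Main loss plus weighted sparsity penalties for queries and documents."""
    return (
        main_loss
        + query_weight * flops_regularizer(query_embs)
        + doc_weight * flops_regularizer(doc_embs)
    )


# Toy batch of 2 queries and 2 documents with 2 dimensions each
query_embs = [[1.0, 0.0], [1.0, 2.0]]
doc_embs = [[0.0, 2.0], [0.0, 2.0]]
total = combined_loss(0.5, query_embs, doc_embs, query_weight=5e-5, doc_weight=3e-5)
print(total)
```

Larger regularizer weights push the model towards sparser embeddings, usually at some cost to the main loss.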

Most loss functions can be initialized with just the SparseEncoder that you’re training, alongside some optional parameters, e.g.:

from datasets import load_dataset
from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss

# Load a model to train/finetune
model = SparseEncoder("distilbert/distilbert-base-uncased")

# Initialize the SpladeLoss with a SparseMultipleNegativesRankingLoss
# This loss requires pairs of related texts or triplets
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,  # Weight for the query sparsity regularizer
    document_regularizer_weight=3e-5,  # Weight for the document sparsity regularizer
)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
    features: ['query', 'answer'],
    num_rows: 100231
})
"""

Training Arguments

The SparseEncoderTrainingArguments class can be used to specify parameters for influencing training performance as well as defining the tracking/debugging parameters. Although it is optional, it is heavily recommended to experiment with the various useful arguments.



Here is an example of how SparseEncoderTrainingArguments can be initialized:

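from sentence_transformers import SparseEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SparseEncoderTrainingArguments(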
    # Required parameter:
    output_dir="models/splade-distilbert-base-uncased-nq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="splade-distilbert-base-uncased-nq",  # Will be used in W&B if `wandb` is installed
)
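
To get a feeling for what these values mean in practice, the sketch below estimates the number of optimizer steps, warmup steps, and evaluations for the 100,231-row Natural Questions training set shown earlier. It assumes a single device, no gradient accumulation, and that the batch sampler does not change the number of batches; the helper training_schedule is hypothetical.

```python
import math


def training_schedule(num_samples, batch_size, num_epochs, warmup_ratio, eval_steps):
    """Estimate total steps, warmup steps, and number of evaluations for a run."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    num_evaluations = total_steps // eval_steps
    return total_steps, warmup_steps, num_evaluations


total_steps, warmup_steps, num_evaluations = training_schedule(
    num_samples=100_231,
    batch_size=16,
    num_epochs=1,
    warmup_ratio=0.1,
    eval_steps=100,
)
print(total_steps, warmup_steps, num_evaluations)
```

With eval_steps=100 this run would evaluate dozens of times, so a larger eval_steps or a small eval_dataset helps keep the evaluation overhead manageable.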

Evaluator

You can provide the SparseEncoderTrainer with an eval_dataset to get the evaluation loss during training, but it may be useful to get more concrete metrics during training, too. For this, you can use evaluators to assess the model’s performance with useful metrics before, during, or after training. You can use both an eval_dataset and an evaluator, one or the other, or neither. They evaluate based on the eval_strategy and eval_steps Training Arguments.

Here are the implemented Evaluators that come with Sentence Transformers for Sparse Encoder models:

| Evaluator | Required Data |
| --- | --- |
| SparseBinaryClassificationEvaluator | Pairs with class labels. |
| SparseEmbeddingSimilarityEvaluator | Pairs with similarity scores. |
| SparseInformationRetrievalEvaluator | Queries (qid => question), Corpus (cid => document), and relevant documents (qid => set[cid]). |
| SparseNanoBEIREvaluator | No data required. |
| SparseMSEEvaluator | Source sentences to embed with a teacher model and target sentences to embed with the student model. Can be the same texts. |
| SparseRerankingEvaluator | List of {'query': '...', 'positive': [...], 'negative': [...]} dictionaries. |
| SparseTranslationEvaluator | Pairs of sentences in two separate languages. |
| SparseTripletEvaluator | (anchor, positive, negative) triplets. |

Additionally, SequentialEvaluator should be used to combine multiple evaluators into one Evaluator that can be passed to the SparseEncoderTrainer.

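Sometimes you don’t have the required evaluation data to prepare one of these evaluators on your own, but you still want to track how well the model performs on some common benchmarks. In that case, you can use these evaluators with data from Hugging Face. The three examples below use SparseNanoBEIREvaluator, SparseEmbeddingSimilarityEvaluator with STSb, and SparseTripletEvaluator with AllNLI, in that order.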

from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator

# Initialize the evaluator. Unlike most other evaluators, this one loads the relevant datasets
# directly from Hugging Face, so there's no mandatory arguments
dev_evaluator = SparseNanoBEIREvaluator()
# You can run evaluation like so:
# results = dev_evaluator(model)

SparseEmbeddingSimilarityEvaluator with STSb:

from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseEmbeddingSimilarityEvaluator

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = SparseEmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)

SparseTripletEvaluator with AllNLI triplets:

from datasets import load_dataset
from sentence_transformers.evaluation import SimilarityFunction
from sentence_transformers.sparse_encoder.evaluation import SparseTripletEvaluator

# Load triplets from the AllNLI dataset (https://huggingface.co/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")

# Initialize the evaluator
dev_evaluator = SparseTripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_distance_function=SimilarityFunction.DOT,
    name="all-nli-dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
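
For intuition about what a triplet evaluation measures, the sketch below computes a dot-product triplet accuracy on toy embeddings: the fraction of triplets in which the anchor scores higher with the positive than with the negative. The embeddings and helper names are made up, and the metrics reported by the library's evaluator may be defined in more detail than this.

```python
def dot(a, b):
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def triplet_accuracy(anchors, positives, negatives):
    """Fraction of triplets where sim(anchor, positive) > sim(anchor, negative)."""
    correct = sum(
        dot(a, p) > dot(a, n) for a, p, n in zip(anchors, positives, negatives)
    )
    return correct / len(anchors)


# Two toy triplets with 3-dimensional embeddings
anchors = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]
positives = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
negatives = [[0.0, 4.0, 0.0], [0.0, 2.0, 1.0]]

accuracy = triplet_accuracy(anchors, positives, negatives)
print(accuracy)
```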

Tip

When evaluating frequently during training with a small eval_steps, consider using a tiny eval_dataset to minimize evaluation overhead. If you’re concerned about the evaluation set size, a 90-1-9 train-eval-test split can provide a balance, reserving a reasonably sized test set for final evaluations. After training, you can assess your model’s performance using trainer.evaluate(test_dataset) for test loss or initialize a testing evaluator with test_evaluator(model) for detailed test metrics.

If you evaluate after training, but before saving the model, your automatically generated model card will still include the test results.

Warning

When using Distributed Training, the evaluator only runs on the first device, unlike the training and evaluation datasets, which are shared across all devices.

Trainer

The SparseEncoderTrainer is where all previous components come together. We only have to specify the trainer with the model, training arguments (optional), training dataset, evaluation dataset (optional), loss function, evaluator (optional) and we can start training. Let’s have a look at two scripts where all of these components come together, first for a Splade model and then for an Inference-free Splade model:

import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
    "distilbert/distilbert-base-uncased",
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="DistilBERT base trained on Natural-Questions tuples",
    )
)

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,
    document_regularizer_weight=3e-5,
)

# 5. (Optional) Specify training arguments
run_name = "splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

The same training setup for an Inference-free Splade model, using a Router module:

import logging

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.models import Router
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
mlm_transformer = MLMTransformer("distilbert/distilbert-base-uncased", tokenizer_args={"model_max_length": 512})
splade_pooling = SpladePooling(
    pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension()
)
router = Router.for_query_document(
    query_modules=[SparseStaticEmbedding(tokenizer=mlm_transformer.tokenizer, frozen=False)],
    document_modules=[mlm_transformer, splade_pooling],
)

model = SparseEncoder(
    modules=[router],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Inference-free SPLADE distilbert-base-uncased trained on Natural-Questions tuples",
    ),
)

# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(train_dataset)
print(train_dataset[0])

# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=0,
    document_regularizer_weight=3e-4,
)

# 5. (Optional) Specify training arguments
run_name = "inference-free-splade-distilbert-base-uncased-nq"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    learning_rate_mapping={r"SparseStaticEmbedding\.weight": 1e-3},  # Set a higher learning rate for the SparseStaticEmbedding module
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    router_mapping={"query": "query", "answer": "document"},  # Map the column names to the routes
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=200,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

Callbacks

This Sparse Encoder trainer integrates support for various transformers.TrainerCallback subclasses, such as WandbCallback, which automatically logs training metrics to W&B if wandb is installed.

See the Transformers Callbacks documentation for more information on the integrated callbacks and how to write your own callbacks.

Multi-Dataset Training

The top performing models are trained using many datasets at once. Normally, this is rather tricky, as each dataset has a different format. However, SparseEncoderTrainer can train with multiple datasets without having to convert each dataset to the same format. It can even apply different loss functions to each of the datasets. The steps to train with multiple datasets are:

  • Use a dictionary of Dataset instances (or a DatasetDict) as the train_dataset (and optionally also eval_dataset).

  • (Optional) Use a dictionary of loss functions mapping dataset names to losses. Only required if you wish to use a different loss function for different datasets.

Each training/evaluation batch will only contain samples from one of the datasets. The order in which batches are sampled from the multiple datasets is defined by the MultiDatasetBatchSamplers enum, which can be passed to the SparseEncoderTrainingArguments via multi_dataset_batch_sampler. Valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: Round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.

  • MultiDatasetBatchSamplers.PROPORTIONAL (default): Sample from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
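
The difference between the two strategies can be sketched on batch counts alone. This is a simplified illustration and not the library's implementation: the dataset names and batch counts are made up, and drawing in proportion to the remaining batches is just one way to realize proportional sampling.

```python
import random


def round_robin(batches_per_dataset):
    """Cycle over datasets, stopping once any dataset has run out of batches."""
    remaining = dict(batches_per_dataset)
    order = []
    while all(count > 0 for count in remaining.values()):
        for name in remaining:
            order.append(name)
            remaining[name] -= 1
    return order


def proportional(batches_per_dataset, seed=0):
    """Draw datasets in proportion to their remaining batches until all are used."""
    rng = random.Random(seed)
    remaining = dict(batches_per_dataset)
    order = []
    while remaining:
        names = list(remaining)
        name = rng.choices(names, weights=[remaining[n] for n in names])[0]
        order.append(name)
        remaining[name] -= 1
        if remaining[name] == 0:
            del remaining[name]
    return order


# Hypothetical number of batches per dataset
batch_counts = {"nq": 2, "all-nli": 5}

rr_order = round_robin(batch_counts)
prop_order = proportional(batch_counts)
print(rr_order)    # stops after the smaller dataset is exhausted
print(prop_order)  # uses every batch from both datasets
```

In this toy case round-robin uses only 4 of the 7 available batches, while proportional sampling uses all 7 and draws from the larger dataset more often.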

Training Tips

Sparse Encoder models have a few quirks that you should be aware of when training them:

  1. Sparse Encoder models should not be evaluated solely using the evaluation scores, but also with the sparsity of the embeddings. After all, a low sparsity means that the model embeddings are expensive to store and slow to retrieve. This also means that the parameters that determine sparsity (e.g. query_regularizer_weight, document_regularizer_weight in SpladeLoss and beta and gamma in the CSRLoss) should be tuned to achieve a good balance between performance and sparsity. Each Evaluator outputs the active_dims and sparsity_ratio metrics that can be used to assess the sparsity of the embeddings.

  2. It is not recommended to use an Evaluator on an untrained model prior to training, as the sparsity will be very low, and so the memory usage might be unexpectedly high.

  3. The stronger Sparse Encoder models are trained almost exclusively with distillation from a stronger teacher model (e.g. a CrossEncoder model), instead of training directly from text pairs or triplets. See for example the SPLADE-v3 paper, which uses SparseDistillKLDivLoss and SparseMarginMSELoss for distillation.

  4. Whereas the majority of dense embedding models are trained to be used with cosine similarity, SparseEncoder models are commonly trained to be used with dot product to compute similarity. Some losses require you to provide a similarity function, and you might be better off using dot product there. Note that you can often provide the loss with model.similarity or model.similarity_pairwise.
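
As a companion to the first tip, the sketch below shows one plausible way to compute sparsity statistics for a batch of embeddings. The metric names active_dims and sparsity_ratio come from the text above, but the exact definitions used by the evaluators are assumed here: the mean number of non-zero dimensions and the mean fraction of zero dimensions.

```python
def sparsity_stats(embeddings):
    """Return (mean number of non-zero dimensions, mean fraction of zero dimensions)."""
    num_dims = len(embeddings[0])
    active_counts = [sum(1 for value in row if value != 0) for row in embeddings]
    active_dims = sum(active_counts) / len(embeddings)
    sparsity_ratio = 1 - active_dims / num_dims
    return active_dims, sparsity_ratio


# Two toy embeddings with 8 dimensions each
embeddings = [
    [0.0, 1.2, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2.1, 0.0, 0.0, 0.0, 0.0],
]
active_dims, sparsity_ratio = sparsity_stats(embeddings)
print(active_dims, sparsity_ratio)
```

Tracking these two numbers next to the retrieval metrics makes it easier to see when a gain in quality is being bought with much denser embeddings.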