Samplers

BatchSamplers

class sentence_transformers.training_args.BatchSamplers(value)

Stores the acceptable string identifiers for batch samplers.

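The batch sampler is responsible for determining how samples are grouped into batches during training. Valid options are:

  • BatchSamplers.BATCH_SAMPLER: [default] Uses DefaultBatchSampler, the default PyTorch-style batch sampler.

  • BatchSamplers.NO_DUPLICATES: Uses NoDuplicatesBatchSampler, which ensures that no value appears more than once in a batch. Recommended for losses that use in-batch negatives, such as MultipleNegativesRankingLoss.

  • BatchSamplers.GROUP_BY_LABEL: Uses GroupByLabelBatchSampler, which ensures that each batch contains several labels with at least 2 samples each. Recommended for losses that perform in-batch triplet mining, such as BatchAllTripletLoss.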

If you want to use a custom batch sampler, then you can subclass DefaultBatchSampler and pass the class (not an instance) to the batch_sampler argument in SentenceTransformerTrainingArguments (or CrossEncoderTrainingArguments, etc.). Alternatively, you can pass a function that accepts dataset, batch_size, drop_last, valid_label_columns, generator, and seed and returns a DefaultBatchSampler instance.

Usage:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

class sentence_transformers.sampler.DefaultBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)

This sampler is the default batch sampler used in the SentenceTransformer library. It is equivalent to the PyTorch BatchSampler.

Parameters:
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str], optional) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.
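Because DefaultBatchSampler behaves like the PyTorch BatchSampler over shuffled indices, its batching can be pictured with the stdlib-only sketch below. It is a simplified model, not the library's code: default_batches is a hypothetical helper, and Python's random.Random seeded with seed stands in for torch.Generator.

```python
import random


def default_batches(num_samples: int, batch_size: int, drop_last: bool, seed: int = 0) -> list[list[int]]:
    """Shuffle all indices with a seeded RNG, then cut them into consecutive batches."""
    # random.Random(seed) stands in for the torch.Generator seeded with `seed`
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    batches = [indices[i : i + batch_size] for i in range(0, num_samples, batch_size)]
    # With drop_last=True, a final batch smaller than batch_size is discarded
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


print([len(b) for b in default_batches(10, 4, drop_last=False)])  # [4, 4, 2]
print([len(b) for b in default_batches(10, 4, drop_last=True)])   # [4, 4]
```

The same seed always yields the same order, which is what makes runs reproducible.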

class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0, precompute_hashes: bool = False, precompute_num_proc: int | None = None, precompute_batch_size: int = 1000)

This sampler creates batches in which no value appears more than once, even across different columns. This is useful for losses that treat the other samples in a batch as in-batch negatives, because it ensures that those negatives are not duplicates of the anchor or positive sample.
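One way to realize this behavior is a greedy pass: scan the shuffled indices, take a row only if none of its values already occur in the current batch, and defer clashing rows to a later batch. The sketch below is a simplified, hypothetical model (no_duplicate_batches is not a library function). It assumes each row is a dict of text columns with no label column, and it is not necessarily how NoDuplicatesBatchSampler is implemented.

```python
import random


def no_duplicate_batches(rows: list[dict[str, str]], batch_size: int, drop_last: bool = False, seed: int = 0) -> list[list[int]]:
    """Group row indices into batches in which no value occurs twice, across all columns."""
    rng = random.Random(seed)
    remaining = list(range(len(rows)))
    rng.shuffle(remaining)
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for idx in remaining:
            values = set(rows[idx].values())
            if len(batch) < batch_size and seen.isdisjoint(values):
                batch.append(idx)     # no clash: take the row
                seen.update(values)
            else:
                deferred.append(idx)  # batch full or value clash: retry in a later batch
        remaining = deferred
        # A batch can come out short (at the end, or when only clashing rows are left)
        if len(batch) == batch_size or not drop_last:
            batches.append(batch)
    return batches


rows = [
    {"anchor": "What is Python?", "positive": "Python is a programming language."},
    {"anchor": "What is Python?", "positive": "A python is a large snake."},
    {"anchor": "Who wrote Hamlet?", "positive": "Shakespeare wrote Hamlet."},
    {"anchor": "What is the capital of France?", "positive": "Paris."},
]
for batch in no_duplicate_batches(rows, batch_size=2):
    print(batch)  # rows 0 and 1 share an anchor, so they never land in the same batch
```

Each pass always places at least the first remaining row, so the loop is guaranteed to terminate.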

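Recommended for: losses that treat the other samples in a batch as in-batch negatives, such as MultipleNegativesRankingLoss.

Parameters: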
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str], optional) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.

  • precompute_hashes (bool, optional) – If True, precompute xxhash 64-bit values for dataset fields using datasets.map to speed up duplicate checks. Requires xxhash to be installed and uses additional memory: roughly len(dataset) * num_columns * 8 bytes for the dense int64 hash matrix, although actual memory usage may differ in practice. Defaults to False.

  • precompute_num_proc (int, optional) – Number of processes for hashing with datasets.map. If set to None, defaults to min(8, cpu_count - 1) when precompute_hashes is True.

  • precompute_batch_size (int, optional) – Batch size for datasets.map hashing. Defaults to 1000.
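To make the precompute_hashes memory estimate concrete: 1,000,000 rows with 3 columns need about 1,000,000 * 3 * 8 = 24,000,000 bytes (about 24 MB) for the hash matrix. The stdlib-only sketch below computes that estimate and shows the general idea of precomputing one 64-bit integer per field, so that duplicate checks compare integers instead of strings. It swaps in hashlib.blake2b with an 8-byte digest for xxhash so it runs without extra packages; the helper names are hypothetical, and real storage would be a dense int64 array rather than Python tuples.

```python
import hashlib


def hash64(value) -> int:
    """64-bit hash of a value; blake2b with an 8-byte digest stands in for xxhash here."""
    digest = hashlib.blake2b(str(value).encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "little", signed=True)  # fits in an int64


def precompute_hash_rows(rows: list[dict[str, str]]) -> list[tuple[int, ...]]:
    """One tuple of int64-sized hashes per row, in column order (the 'hash matrix')."""
    return [tuple(hash64(v) for v in row.values()) for row in rows]


def estimated_hash_matrix_bytes(num_rows: int, num_columns: int) -> int:
    """Rough size of a dense int64 matrix: rows * columns * 8 bytes."""
    return num_rows * num_columns * 8


rows = [
    {"anchor": "It's nice weather outside today.", "positive": "It's so sunny."},
    {"anchor": "He drove to work.", "positive": "He took the car to the office."},
]
hashes = precompute_hash_rows(rows)
print(len(hashes), len(hashes[0]))               # 2 2
print(estimated_hash_matrix_bytes(1_000_000, 3))  # 24000000
```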

class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: Generator | None = None, seed: int = 0)

Batch sampler that groups samples by label for in-batch triplet mining.

Samples are shuffled within each label, then interleaved in round-robin fashion to produce a stream where labels are well-mixed. This stream is chunked into batches of exactly batch_size. Every batch is guaranteed to contain multiple distinct labels, each with at least 2 samples.

Labels take turns emitting 2 samples each. The stream stops once fewer than 2 labels remain, so the leftover samples of the dominant (most frequent) label end up in the remainder. The result is a good balance of labels within each batch.

Recommended for: losses that mine triplets within each batch, such as BatchAllTripletLoss and BatchHardTripletLoss.

Parameters:
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch. Must be an even number >= 4.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str], optional) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.
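The pairwise round-robin described above can be sketched in plain Python as follows. This is a hypothetical model, not GroupByLabelBatchSampler itself. It assumes labels are visited in first-seen order and models the shuffle with random.Random(seed). It simply keeps full batches, leaving the tail of the stream and any unpaired samples as the remainder, and it does not model drop_last.

```python
import random
from collections import Counter, defaultdict


def group_by_label_batches(labels: list[str], batch_size: int, seed: int = 0) -> list[list[int]]:
    """Round-robin pairs of same-label indices into a stream, then cut full batches."""
    if batch_size < 4 or batch_size % 2 != 0:
        raise ValueError("batch_size must be an even number >= 4")
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    queues = list(by_label.values())  # first-seen label order
    for queue in queues:
        rng.shuffle(queue)            # shuffle within each label

    stream = []
    active = [q for q in queues if len(q) >= 2]
    # Labels take turns emitting 2 samples; stop once fewer than 2 labels can continue
    while len(active) >= 2:
        for queue in active:
            stream.extend(queue[:2])
            del queue[:2]
        active = [q for q in active if len(q) >= 2]

    # Keep only full batches; whatever is left of the stream is the remainder
    full = len(stream) // batch_size * batch_size
    return [stream[i : i + batch_size] for i in range(0, full, batch_size)]


labels = ["a"] * 6 + ["b"] * 4 + ["c"] * 2
for batch in group_by_label_batches(labels, batch_size=4):
    print(sorted(labels[i] for i in batch))
# ['a', 'a', 'b', 'b']
# ['a', 'a', 'c', 'c']
```

Here the stream is a-pair, b-pair, c-pair, a-pair, b-pair. The last b pair and the final a pair (which never enters the stream) are left over, while both full batches contain two labels with two samples each.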

MultiDatasetBatchSamplers

class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)

Stores the acceptable string identifiers for multi-dataset batch samplers.

The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple datasets during training. Valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: Uses RoundRobinBatchSampler, which uses round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.

  • MultiDatasetBatchSamplers.PROPORTIONAL: [default] Uses ProportionalBatchSampler, which samples from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.

If you want to use a custom multi-dataset batch sampler, you can subclass MultiDatasetDefaultBatchSampler and pass the class (not an instance) to the multi_dataset_batch_sampler argument in SentenceTransformerTrainingArguments (or CrossEncoderTrainingArguments, etc.). Alternatively, you can pass a function that accepts dataset (a ConcatDataset), batch_samplers (a list with one batch sampler for each dataset in the ConcatDataset), generator, and seed, and returns a MultiDatasetDefaultBatchSampler instance.

Usage:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import MultiDatasetBatchSamplers
from sentence_transformers.losses import CoSENTLoss
from datasets import Dataset, DatasetDict

model = SentenceTransformer("microsoft/mpnet-base")
train_general = Dataset.from_dict({
    "sentence_A": ["It's nice weather outside today.", "He drove to work."],
    "sentence_B": ["It's so sunny.", "He took the car to the bank."],
    "score": [0.9, 0.4],
})
train_medical = Dataset.from_dict({
    "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
    "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
    "score": [0.8, 0.6, 0.7],
})
train_legal = Dataset.from_dict({
    "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
    "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
    "score": [0.7, 0.8],
})
train_dataset = DatasetDict({
    "general": train_general,
    "medical": train_medical,
    "legal": train_legal,
})

loss = CoSENTLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

class sentence_transformers.sampler.MultiDatasetDefaultBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)

Abstract base batch sampler that yields batches from multiple batch samplers. This class must be subclassed to implement specific sampling strategies, and cannot be used directly.

Parameters:
  • dataset (ConcatDataset) – A concatenation of multiple datasets.

  • batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.

  • generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.

class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)

Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted. With this sampler, it’s unlikely that all samples from each dataset are used, but we do ensure that each dataset is sampled from equally.

Parameters:
  • dataset (ConcatDataset) – A concatenation of multiple datasets.

  • batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.

  • generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.
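The following is a minimal sketch of round-robin sampling. It assumes each per-dataset batch sampler yields local indices, which must be shifted by the starting offset of their dataset inside the ConcatDataset. round_robin_batches is a hypothetical helper that operates on precomputed lists of batches; it is not the RoundRobinBatchSampler API.

```python
from itertools import accumulate, cycle


def round_robin_batches(dataset_sizes: list[int], batches_per_dataset: list[list[list[int]]]):
    """Yield batches from each dataset in turn until any one dataset runs out.

    batches_per_dataset[i] holds batches of *local* indices into dataset i;
    they are shifted into the global index space of the concatenation.
    """
    offsets = [0, *accumulate(dataset_sizes)][:-1]  # start of each dataset in the concatenation
    iterators = [iter(batches) for batches in batches_per_dataset]
    for dataset_idx in cycle(range(len(iterators))):
        batch = next(iterators[dataset_idx], None)
        if batch is None:  # this dataset is exhausted: stop sampling entirely
            return
        yield [idx + offsets[dataset_idx] for idx in batch]


# Hypothetical inputs: datasets of 2, 3 and 2 samples, split into batches of 2
sizes = [2, 3, 2]
per_dataset = [[[0, 1]], [[0, 1], [2]], [[0, 1]]]
print(list(round_robin_batches(sizes, per_dataset)))  # [[0, 1], [2, 3], [5, 6]]
```

The sizes mirror the general, medical, and legal datasets from the usage example above. Because the first dataset runs out after one batch, the second dataset's final batch (global index 4) is never yielded.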

class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: Generator | None = None, seed: int = 0)

Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously. With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently.

Parameters:
  • dataset (ConcatDataset) – A concatenation of multiple datasets.

  • batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.

  • generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.

  • seed (int) – Seed for the random number generator to ensure reproducibility. Defaults to 0.
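The following is a minimal sketch of proportional sampling. It assumes one shuffled "ticket" per batch: each dataset contributes as many tickets as it has batches, so larger datasets are drawn more often and every batch is yielded exactly once. This is one straightforward way to obtain the described behavior, not necessarily how ProportionalBatchSampler is implemented. proportional_batches is a hypothetical helper, and random.Random(seed) stands in for torch.Generator.

```python
import random
from itertools import accumulate


def proportional_batches(dataset_sizes: list[int], batches_per_dataset: list[list[list[int]]], seed: int = 0):
    """Interleave every batch of every dataset in a seeded random order.

    Each dataset gets one ticket per batch it has, so all datasets run out
    at the same moment: when the last ticket is drawn.
    """
    offsets = [0, *accumulate(dataset_sizes)][:-1]  # start of each dataset in the concatenation
    tickets = [i for i, batches in enumerate(batches_per_dataset) for _ in batches]
    random.Random(seed).shuffle(tickets)
    iterators = [iter(batches) for batches in batches_per_dataset]
    for dataset_idx in tickets:
        yield [idx + offsets[dataset_idx] for idx in next(iterators[dataset_idx])]


# Hypothetical inputs: datasets of 2, 3 and 2 samples, split into batches of 2
sizes = [2, 3, 2]
per_dataset = [[[0, 1]], [[0, 1], [2]], [[0, 1]]]
out = list(proportional_batches(sizes, per_dataset, seed=0))
print(sorted(out))  # [[0, 1], [2, 3], [4], [5, 6]]
```

The order of the four batches depends on the seed, but all of them are always used, and each dataset's own batches keep their relative order.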