Samplers
BatchSamplers
class sentence_transformers.training_args.BatchSamplers(value)
Stores the acceptable string identifiers for batch samplers.
The batch sampler is responsible for determining how samples are grouped into batches during training. Valid options are:
- BatchSamplers.BATCH_SAMPLER: [default] Uses DefaultBatchSampler, the default PyTorch batch sampler.
- BatchSamplers.NO_DUPLICATES: Uses NoDuplicatesBatchSampler, ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives.
- BatchSamplers.GROUP_BY_LABEL: Uses GroupByLabelBatchSampler, ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple samples from the same label.
If you want to use a custom batch sampler, you can create a new Trainer class that inherits from SentenceTransformerTrainer and overrides the get_batch_sampler() method. The method must return a class instance that supports the __iter__ and __len__ methods. The former should yield a list of indices for each batch, and the latter should return the number of batches.
- Usage:
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
    )
    from sentence_transformers.training_args import BatchSamplers
    from sentence_transformers.losses import MultipleNegativesRankingLoss
    from datasets import Dataset

    model = SentenceTransformer("microsoft/mpnet-base")
    train_dataset = Dataset.from_dict({
        "anchor": ["It's nice weather outside today.", "He drove to work."],
        "positive": ["It's so sunny.", "He took the car to the office."],
    })
    loss = MultipleNegativesRankingLoss(model)

    # Select the batch sampler through the training arguments.
    args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        batch_sampler=BatchSamplers.NO_DUPLICATES,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()
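The custom batch sampler contract described above (an object whose __iter__ yields a list of indices per batch and whose __len__ returns the number of batches) can be illustrated with a minimal sketch. The class name ShuffledBatchSampler and its constructor arguments are hypothetical and not part of sentence_transformers; only the standard library is used.

```python
import math
import random


class ShuffledBatchSampler:
    """Hypothetical custom batch sampler: yields shuffled batches of indices."""

    def __init__(self, num_samples: int, batch_size: int, drop_last: bool, seed: int = 0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.seed = seed

    def __iter__(self):
        # Shuffle all indices once, then cut them into consecutive batches.
        indices = list(range(self.num_samples))
        random.Random(self.seed).shuffle(indices)
        for start in range(0, self.num_samples, self.batch_size):
            batch = indices[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break  # skip the trailing incomplete batch
            yield batch

    def __len__(self):
        # Number of batches that __iter__ will yield.
        if self.drop_last:
            return self.num_samples // self.batch_size
        return math.ceil(self.num_samples / self.batch_size)


sampler = ShuffledBatchSampler(num_samples=10, batch_size=4, drop_last=False)
batches = list(sampler)
```

An overridden get_batch_sampler() could return an instance like this. The arguments that method receives may differ between library versions, so check its signature in the installed release before relying on this sketch.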
class sentence_transformers.sampler.DefaultBatchSampler(*args, **kwargs)
This sampler is the default batch sampler used in the SentenceTransformer library. It is equivalent to the PyTorch BatchSampler.
- Parameters
sampler (Sampler or Iterable) – The sampler used for sampling elements from the dataset, such as SubsetRandomSampler.
batch_size (int) – Number of samples per batch.
drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.
class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = [], generator: torch.Generator = None, seed: int = 0)
This sampler creates batches such that each batch contains samples where the values are unique, even across columns. This is useful when losses consider other samples in a batch to be in-batch negatives, and you want to ensure that the negatives are not duplicates of the anchor/positive sample.
- Recommended for: losses that use in-batch negatives.
- Parameters
dataset (Dataset) – The dataset to sample from.
batch_size (int) – Number of samples per batch.
drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.
valid_label_columns (List[str]) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.
generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.
seed (int, optional) – Seed for the random number generator to ensure reproducibility.
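The idea of keeping values unique within a batch, even across columns, can be sketched with a greedy pass over the rows. This is an illustration only and is not the library's implementation; the function name no_duplicate_batches and the tuple-based rows are assumptions made for the example.

```python
def no_duplicate_batches(rows, batch_size, drop_last=False):
    """Greedy sketch: build batches in which no value occurs twice.

    rows is a list of tuples, one tuple of column values per sample.
    """
    remaining = list(range(len(rows)))
    batches = []
    while remaining:
        batch, seen, leftover = [], set(), []
        for idx in remaining:
            values = set(rows[idx])
            # Accept the row only if the batch has room and none of its
            # values already occur in the batch (in any column).
            if len(batch) < batch_size and not (values & seen):
                batch.append(idx)
                seen |= values
            else:
                leftover.append(idx)
        if len(batch) == batch_size or not drop_last:
            batches.append(batch)
        remaining = leftover  # rejected rows are retried in the next batch
    return batches


rows = [("a1", "p1"), ("a1", "p2"), ("a2", "p1"), ("a3", "p3")]
result = no_duplicate_batches(rows, batch_size=2)
```

In this toy input, rows 1 and 2 each share a value with row 0, so they are deferred to a later batch instead of being used as in-batch negatives for row 0.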
class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = None, generator: torch.Generator = None, seed: int = 0)
This sampler groups samples by their labels and aims to create batches such that each batch contains samples where the labels are as homogeneous as possible. This sampler is meant to be used alongside the Batch...TripletLoss classes, which require that each batch contains at least 2 examples per label class.
- Recommended for: losses that require multiple samples from the same label.
- Parameters
dataset (Dataset) – The dataset to sample from.
batch_size (int) – Number of samples per batch. Must be divisible by 2.
drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.
valid_label_columns (List[str]) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.
generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.
seed (int, optional) – Seed for the random number generator to ensure reproducibility.
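One way to see why an even batch size helps guarantee at least 2 samples per label in every batch is the following sketch. It is not the library's implementation; the function name group_by_label_batches is hypothetical, and shuffling is omitted for brevity.

```python
from collections import defaultdict


def group_by_label_batches(labels, batch_size):
    """Sketch: full batches in which every label present occurs at least twice."""
    if batch_size % 2:
        raise ValueError("batch_size must be divisible by 2")

    # Collect the sample indices belonging to each label.
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)

    # Keep an even number of samples per label and lay the labels out back to
    # back. Every run then starts at an even offset, so an even-sized batch
    # boundary can never leave a single sample of a label on its own.
    ordered = []
    for indices in groups.values():
        usable = len(indices) // 2 * 2
        ordered.extend(indices[:usable])

    # Only full batches are returned in this sketch.
    return [
        ordered[start:start + batch_size]
        for start in range(0, len(ordered) - batch_size + 1, batch_size)
    ]


labels = [0, 0, 0, 1, 1, 2, 1, 1]
label_batches = group_by_label_batches(labels, batch_size=4)
```

Here label 2 has a single sample and is never used, and the third sample of label 0 is dropped to keep the count even.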
MultiDatasetBatchSamplers
class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)
Stores the acceptable string identifiers for multi-dataset batch samplers.
The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple datasets during training. Valid options are:
- MultiDatasetBatchSamplers.ROUND_ROBIN: Uses RoundRobinBatchSampler, which uses round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.
- MultiDatasetBatchSamplers.PROPORTIONAL: [default] Uses ProportionalBatchSampler, which samples from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
- Usage:
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
    )
    from sentence_transformers.training_args import MultiDatasetBatchSamplers
    from sentence_transformers.losses import CoSENTLoss
    from datasets import Dataset, DatasetDict

    model = SentenceTransformer("microsoft/mpnet-base")
    train_general = Dataset.from_dict({
        "sentence_A": ["It's nice weather outside today.", "He drove to work."],
        "sentence_B": ["It's so sunny.", "He took the car to the bank."],
        "score": [0.9, 0.4],
    })
    train_medical = Dataset.from_dict({
        "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
        "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
        "score": [0.8, 0.6, 0.7],
    })
    train_legal = Dataset.from_dict({
        "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
        "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
        "score": [0.7, 0.8],
    })
    # Multiple training datasets are passed together as a DatasetDict.
    train_dataset = DatasetDict({
        "general": train_general,
        "medical": train_medical,
        "legal": train_legal,
    })
    loss = CoSENTLoss(model)

    # Select the multi-dataset batch sampler through the training arguments.
    args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()
class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: torch.Generator = None, seed: int = None)
Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted. With this sampler, it’s unlikely that all samples from each dataset are used, but we do ensure that each dataset is sampled from equally.
- Parameters
dataset (ConcatDataset) – A concatenation of multiple datasets.
batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.
generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.
seed (int, optional) – A seed for the generator. Defaults to None.
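The round-robin behaviour, including the early stop once one dataset runs out of batches, can be sketched in plain Python. The function name round_robin_schedule and the nested-list input are assumptions for illustration; the real sampler works on batch sampler objects and a ConcatDataset.

```python
def round_robin_schedule(batch_lists):
    """Sketch: take one batch from each dataset in turn until one is exhausted.

    batch_lists holds, per dataset, the batches its own sampler would yield.
    Returns (dataset_index, batch) pairs in the order they would be used.
    """
    iterators = [iter(batches) for batches in batch_lists]
    schedule = []
    while True:
        for dataset_idx, iterator in enumerate(iterators):
            try:
                batch = next(iterator)
            except StopIteration:
                # The first exhausted dataset ends the whole epoch.
                return schedule
            schedule.append((dataset_idx, batch))


batch_lists = [
    [[0, 1], [2, 3]],          # dataset 0: 2 batches
    [[0, 1], [2, 3], [4, 5]],  # dataset 1: 3 batches
    [[0, 1]],                  # dataset 2: 1 batch
]
rr_schedule = round_robin_schedule(batch_lists)
```

In this toy input the epoch stops when dataset 2 is asked for a second batch, so the last batch of dataset 1 is never used.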
class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: torch.Generator, seed: int)
Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously. With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently.
- Parameters
dataset (ConcatDataset) – A concatenation of multiple datasets.
batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.
generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.
seed (int, optional) – A seed for the generator. Defaults to None.
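Proportional sampling can be pictured as shuffling one slot per batch across all datasets, so that a dataset with more batches is picked more often and every batch is used exactly once. The sketch below is an assumption-laden illustration, not the library's code; the function name proportional_schedule is hypothetical and the standard library random module stands in for a torch.Generator.

```python
import random


def proportional_schedule(batch_lists, seed=0):
    """Sketch: visit datasets in a shuffled order weighted by their batch count.

    batch_lists holds, per dataset, the batches its own sampler would yield.
    Returns (dataset_index, batch) pairs covering every batch exactly once.
    """
    # One slot per batch, labelled with the dataset it belongs to.
    order = [idx for idx, batches in enumerate(batch_lists) for _ in batches]
    random.Random(seed).shuffle(order)

    # Draw the next unused batch from whichever dataset each slot names.
    iterators = [iter(batches) for batches in batch_lists]
    return [(idx, next(iterators[idx])) for idx in order]


sized_batch_lists = [
    [[0, 1], [2, 3]],                  # dataset 0: 2 batches
    [[0, 1], [2, 3], [4, 5], [6, 7]],  # dataset 1: 4 batches
    [[0, 1]],                          # dataset 2: 1 batch
]
prop_schedule = proportional_schedule(sized_batch_lists)
```

Unlike the round-robin strategy, no batch is left unused here, at the cost of larger datasets dominating the schedule.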