Multimodal Training
Sentence Transformer models can handle multimodal inputs (text, images, audio, and video), enabling cross-modal retrieval tasks such as text-to-image search or audio-to-text matching. The key enabler is the Transformer module’s automatic modality detection: it inspects the underlying model’s processor to determine which modalities are supported, then handles preprocessing for each modality transparently.
This means multimodal training uses the exact same pipeline as text-only training: the same losses, the same trainer, and the same evaluation tools. The data collator handles multimodal preprocessing automatically.
Supported Input Types
You can check which modalities a model supports using the modalities property and the supports() method:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("Qwen/Qwen3-VL-Embedding-2B", revision="refs/pr/23")
print(model.modalities)
# => ['text', 'image', 'message']
print(model.supports("image"))
# => True
print(model.supports("audio"))
# => False
The encode() method accepts a variety of input types depending on the model’s supported modalities:
Text: strings.
Image: PIL images, file paths, URLs, or numpy/torch arrays.
Audio: file paths, numpy/torch arrays, dicts with "array" and "sampling_rate" keys, or torchcodec.AudioDecoder instances.
Video: file paths, numpy/torch arrays, dicts with "array" and "video_metadata" keys, or torchcodec.VideoDecoder instances.
Multimodal dicts: a dict mapping modality names to values, e.g. {"text": ..., "audio": ...}. The keys must be "text", "image", "audio", or "video".
Chat messages: a list of dicts with "role" and "content" keys, for multimodal models that use an uncommon chat template to combine text and non-text inputs.
Training
Training a multimodal model follows the same steps as training a text-only model. You can use any compatible loss function, and the trainer and data collator handle multimodal inputs without any special configuration. Datasets can mix modalities across columns, for example a “query” column containing text strings and a “document” column containing PIL images.
Training Example: Document Screenshot Embedding
The training_visual_document_retrieval.py script finetunes Qwen/Qwen3-VL-Embedding-2B on query-document screenshot pairs for visual document retrieval. Here is how it works:
1. Load the model with efficient training settings:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"Qwen/Qwen3-VL-Embedding-2B",
model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
processor_kwargs={"min_pixels": 28 * 28, "max_pixels": 600 * 600},
revision="refs/pr/23",
)
The model_kwargs enable Flash Attention 2 and bfloat16 precision for faster training. The processor_kwargs control image resolution bounds; smaller max_pixels reduces memory usage at the cost of image detail.
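To see what those resolution bounds mean in practice, here is a rough, simplified sketch of how an area bound rescales an image while preserving aspect ratio (an assumption of how such processors behave; the real processor also snaps sizes to patch-size multiples, which this ignores):

```python
import math

def bound_area(width, height, min_pixels=28 * 28, max_pixels=600 * 600):
    """Rescale (width, height) so the area lands in [min_pixels, max_pixels],
    keeping the aspect ratio. Simplified: no patch-size rounding."""
    area = width * height
    if area > max_pixels:
        scale = math.sqrt(max_pixels / area)
        return int(width * scale), int(height * scale)
    if area < min_pixels:
        scale = math.sqrt(min_pixels / area)
        return math.ceil(width * scale), math.ceil(height * scale)
    return width, height

# A full-page 1700x2200 scan gets shrunk to fit the 600*600 pixel budget,
# which is what keeps memory usage bounded during training.
w, h = bound_area(1700, 2200)
```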
2. Load the data from the tomaarsen/llamaindex-vdr-en-train-preprocessed dataset, which pairs text queries with document screenshot images:
from datasets import load_dataset
train_dataset = load_dataset("tomaarsen/llamaindex-vdr-en-train-preprocessed", "train", split="train")
eval_dataset = load_dataset("tomaarsen/llamaindex-vdr-en-train-preprocessed", "eval", split="train")
3. Define the loss function using CachedMultipleNegativesRankingLoss wrapped in MatryoshkaLoss. This combination trains the model for retrieval with in-batch negatives while producing embeddings that remain effective after truncation to smaller dimensions:
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)
loss = MatryoshkaLoss(model, loss, matryoshka_dims=[2048, 1536, 1024, 512, 256, 128, 64])
4. Evaluate using InformationRetrievalEvaluator with text queries against an image corpus, measuring cross-modal retrieval performance:
from sentence_transformers.evaluation import InformationRetrievalEvaluator
eval_evaluator = InformationRetrievalEvaluator(
queries=eval_queries, # dict of text queries
corpus=eval_corpus, # dict of PIL images
relevant_docs=eval_relevant_docs,
name="vdr-eval-hard",
)
5. Train using the standard SentenceTransformerTrainer:
from sentence_transformers import SentenceTransformerTrainer

# `args`: a SentenceTransformerTrainingArguments instance defined elsewhere in the script
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=eval_evaluator,
)
trainer.train()
After training, the model can be evaluated at each Matryoshka dimension separately to measure the performance-efficiency tradeoff.
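Conceptually, evaluating at a smaller dimension means truncating each embedding to its first d values and re-normalizing before computing similarities; Sentence Transformers evaluators can do this via their truncate_dim argument. A library-free sketch with random stand-in vectors (real evaluation would use model embeddings):

```python
import numpy as np

# Random stand-ins for one query embedding and a corpus of 100 documents.
rng = np.random.default_rng(0)
query = rng.standard_normal(2048).astype(np.float32)
corpus = rng.standard_normal((100, 2048)).astype(np.float32)

def cosine_scores(q, docs, dim):
    """Truncate to the first `dim` values, re-normalize, then score."""
    q = q[:dim] / np.linalg.norm(q[:dim])
    d = docs[:, :dim] / np.linalg.norm(docs[:, :dim], axis=1, keepdims=True)
    return d @ q

# Smaller dims are cheaper to store and compare, at some cost in quality.
for dim in [2048, 512, 64]:
    scores = cosine_scores(query, corpus, dim)
```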