Modules

sentence_transformers.sparse_encoder.models defines different building blocks that can be used to create SparseEncoder networks from scratch. For more details, see Training Overview. Note that modules from sentence_transformers.models, such as sentence_transformers.models.Transformer, can also be used for sparse models; see SentenceTransformer > Modules.

SPLADE Pooling

class sentence_transformers.sparse_encoder.models.SpladePooling(pooling_strategy: Literal['max', 'sum'] = 'max', activation_function: Literal['relu', 'log1p_relu'] = 'relu', word_embedding_dimension: int | None = None, chunk_size: int | None = None)

SPLADE Pooling module for creating the sparse embeddings.

This module implements the SPLADE pooling mechanism that:

  1. Takes token logits from a masked language model (MLM).

  2. Applies a sparse transformation using an activation function followed by log1p (i.e., log(1 + activation(MLM_logits))).

  3. Applies a pooling strategy (max or sum) over the token dimension to produce sparse embeddings.

The resulting embeddings are highly sparse and capture lexical information, making them suitable for efficient information retrieval.
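The three steps above can be sketched in NumPy as follows. This is an illustrative re-implementation, not the library's code: the helper name splade_pool is hypothetical, the logits are toy values, and zeroing padding positions with the attention mask before pooling is an assumption about how padded tokens are excluded.

```python
import numpy as np


def splade_pool(mlm_logits, attention_mask, pooling_strategy="max", activation_function="relu"):
    """Illustrative SPLADE pooling (hypothetical helper, not the library implementation).

    mlm_logits:     (batch, seq_len, vocab_size) logits from an MLM head
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    Returns:        (batch, vocab_size) non-negative sparse vectors
    """
    # Step 2a: activation function
    activated = np.maximum(mlm_logits, 0.0)  # ReLU
    if activation_function == "log1p_relu":
        activated = np.log1p(activated)  # log(1 + ReLU(x))
    elif activation_function != "relu":
        raise ValueError(f"Unknown activation_function: {activation_function}")
    # Step 2b: log1p on top of the activation
    weights = np.log1p(activated)
    # Assumption: padding positions are zeroed so they cannot contribute to pooling
    weights = weights * attention_mask[:, :, None]
    # Step 3: pool over the token (sequence) axis
    if pooling_strategy == "max":
        return weights.max(axis=1)
    if pooling_strategy == "sum":
        return weights.sum(axis=1)
    raise ValueError(f"Unknown pooling_strategy: {pooling_strategy}")


# Toy example: batch of 1, 3 tokens (the last one is padding), vocabulary of 4
logits = np.array([[[2.0, -1.0, 0.0, 3.0],
                    [-2.0, 0.5, -1.0, 1.0],
                    [5.0, 5.0, 5.0, 5.0]]])
mask = np.array([[1, 1, 0]])
embedding = splade_pool(logits, mask, pooling_strategy="max")
print(embedding.shape)  # (1, 4)
print(np.count_nonzero(embedding))  # 3: entry 2 has no positive logit among the real tokens
```

Because every weight is non-negative after the activation, dimensions whose logits are never positive stay exactly zero, which is where the sparsity comes from.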

Parameters:
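  • pooling_strategy (str) –

    Pooling method across token dimensions. Choices:

    • max: Max pooling over tokens (default; used by SPLADE v2 and later models).

    • sum: Sum pooling over tokens (used by the original SPLADE model).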

  • activation_function (str) –

    Activation function applied before log1p transformation. Choices:

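    • relu: ReLU activation, giving log(1 + ReLU(x)) (the standard SPLADE formulation).

    • log1p_relu: log(1 + ReLU(x)) variant used in OpenSearch SPLADE models (see arxiv.org/pdf/2504.14839). Combined with the subsequent log1p, this yields log(1 + log(1 + ReLU(x))).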

  • word_embedding_dimension (int, optional) – Dimensionality of the output embeddings, which corresponds to the vocabulary size of the underlying MLM (if needed).

  • chunk_size (int, optional) – Chunk size along the sequence length dimension (i.e., number of tokens per chunk). If None, the entire sequence is processed at once. Smaller chunks reduce memory usage but may slow down training and inference. Default is None.
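The NumPy sketch below illustrates why chunking along the sequence dimension does not change the result: both max and sum can be accumulated chunk by chunk, so only a (batch, chunk_size, vocab_size) slice of intermediate activated values needs to exist at once. The helper name is hypothetical, it covers only the relu activation, and it is not the library's implementation.

```python
import numpy as np


def splade_pool_chunked(mlm_logits, attention_mask, chunk_size=None, pooling_strategy="max"):
    """Hypothetical helper: log(1 + ReLU(x)) plus pooling, one sequence chunk at a time."""
    batch, seq_len, vocab_size = mlm_logits.shape
    step = seq_len if chunk_size is None else chunk_size
    if pooling_strategy not in ("max", "sum"):
        raise ValueError(f"Unknown pooling_strategy: {pooling_strategy}")
    # Running result; zeros are a valid start because all weights are >= 0
    pooled = np.zeros((batch, vocab_size))
    for start in range(0, seq_len, step):
        chunk = mlm_logits[:, start:start + step, :]
        chunk_mask = attention_mask[:, start:start + step, None]
        # Only this chunk's activated values are materialised at any one time
        weights = np.log1p(np.maximum(chunk, 0.0)) * chunk_mask
        if pooling_strategy == "max":
            pooled = np.maximum(pooled, weights.max(axis=1))
        else:
            pooled += weights.sum(axis=1)
    return pooled


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 7, 5))  # (batch, seq_len, vocab_size)
mask = np.ones((2, 7))
mask[1, 5:] = 0  # the second sequence has two padding tokens
full = splade_pool_chunked(logits, mask, chunk_size=None)
chunked = splade_pool_chunked(logits, mask, chunk_size=3)
print(np.allclose(full, chunked))  # True
```

The trade-off mentioned above follows from the loop: each chunk is a separate, smaller operation, which lowers peak memory but adds per-iteration overhead.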

MLM Transformer

class sentence_transformers.sparse_encoder.models.MLMTransformer(model_name_or_path: str, max_seq_length: int | None = None, model_args: dict[str, Any] | None = None, tokenizer_args: dict[str, Any] | None = None, config_args: dict[str, Any] | None = None, cache_dir: str | None = None, do_lower_case: bool = False, tokenizer_name_or_path: str | None = None, backend: str = 'torch')

MLMTransformer adapts a Masked Language Model (MLM) for sparse encoding applications.

This class extends the Transformer class to work specifically with models that have an MLM head (e.g., BERT, RoBERTa) and is designed to be used with SpladePooling for creating SPLADE sparse representations.

MLMTransformer accesses the MLM prediction head to get vocabulary logits for each token, which are later used by SpladePooling to create sparse lexical representations.
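To make the role of the MLM head concrete, here is a NumPy sketch with a hypothetical five-token vocabulary and random weights. The head produces one logit per vocabulary entry for every input token, so after pooling each output dimension corresponds to one vocabulary token. This only illustrates the shapes involved; it is not how MLMTransformer is implemented, and the single linear projection is a simplification.

```python
import numpy as np

# Hypothetical toy vocabulary; bert-base-uncased, for comparison, has 30,522 entries
vocab = ["[PAD]", "the", "cat", "sat", "mat"]
hidden_dim = 8
rng = np.random.default_rng(42)

# Simplified MLM head: one linear projection from hidden size to vocabulary size.
# (BERT's real head first applies a dense layer, an activation and LayerNorm.)
decoder_weight = rng.normal(size=(len(vocab), hidden_dim))
decoder_bias = np.zeros(len(vocab))


def mlm_logits(hidden_states):
    """Map (batch, seq_len, hidden_dim) states to (batch, seq_len, vocab_size) logits."""
    return hidden_states @ decoder_weight.T + decoder_bias


hidden_states = rng.normal(size=(1, 3, hidden_dim))  # stand-in for the encoder output
logits = mlm_logits(hidden_states)
print(logits.shape)  # (1, 3, 5): one score per vocabulary entry for every input token

# SPLADE-style pooling turns the logits into one weight per vocabulary entry
sparse = np.log1p(np.maximum(logits, 0.0)).max(axis=1)[0]
for token, weight in zip(vocab, sparse):
    if weight > 0:
        print(f"{token}: {weight:.3f}")
```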

Parameters:
  • model_name_or_path – Hugging Face model name or local path (https://huggingface.co/models)

  • max_seq_length – Truncate any inputs longer than max_seq_length

  • model_args – Keyword arguments passed to the Hugging Face Transformers model

  • tokenizer_args – Keyword arguments passed to the Hugging Face Transformers tokenizer

  • config_args – Keyword arguments passed to the Hugging Face Transformers config

  • cache_dir – Cache directory for Hugging Face Transformers to store/load models

  • do_lower_case – If True, lowercases the input (regardless of whether the model is cased or not)

  • tokenizer_name_or_path – Name or path of the tokenizer. If None, model_name_or_path is used

  • backend – Backend used for model inference. Currently only torch is supported for this class.

SparseAutoEncoder

class sentence_transformers.sparse_encoder.models.SparseAutoEncoder(input_dim: int, hidden_dim: int = 512, k: int = 8, k_aux: int = 512, normalize: bool = False, dead_threshold: int = 30)

This module implements the Sparse AutoEncoder architecture based on the paper: Beyond Matryoshka: Revisiting Sparse Coding for Adaptive Representation, https://arxiv.org/abs/2503.01776

This module transforms dense embeddings into sparse representations by:

  1. Applying a multi-layer feed-forward network

  2. Applying top-k sparsification to keep only the largest values

  3. Supporting auxiliary losses for training stability (via k_aux parameter)
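The forward pass in steps 1 and 2 can be pictured with the NumPy sketch below. It is a simplification under stated assumptions: a single linear encoder and decoder instead of a multi-layer network, random weights, the hypothetical names ToySparseAutoEncoder and topk_sparsify, and no auxiliary (k_aux) loss or training loop. It only shows how top-k sparsification bounds the number of active dimensions per embedding.

```python
import numpy as np


def topk_sparsify(x, k):
    """Keep the k largest entries in each row of x and set all others to zero."""
    idx = np.argpartition(x, -k, axis=-1)[..., -k:]
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, np.take_along_axis(x, idx, axis=-1), axis=-1)
    return out


class ToySparseAutoEncoder:
    """Hypothetical top-k autoencoder sketch (single linear encoder/decoder), not the library module."""

    def __init__(self, input_dim, hidden_dim=512, k=8, seed=0):
        rng = np.random.default_rng(seed)
        self.k = k
        self.w_enc = rng.normal(scale=input_dim ** -0.5, size=(input_dim, hidden_dim))
        self.b_enc = np.zeros(hidden_dim)
        self.w_dec = self.w_enc.T.copy()  # assumption: decoder initialised as the encoder transpose
        self.b_pre = np.zeros(input_dim)

    def encode(self, x):
        pre_acts = (x - self.b_pre) @ self.w_enc + self.b_enc
        # Top-k first, then ReLU: at most k non-zero latents per embedding
        return np.maximum(topk_sparsify(pre_acts, self.k), 0.0)

    def decode(self, z):
        return z @ self.w_dec + self.b_pre


sae = ToySparseAutoEncoder(input_dim=16, hidden_dim=64, k=8)
dense = np.random.default_rng(1).normal(size=(4, 16))  # four dense embeddings
codes = sae.encode(dense)
reconstruction = sae.decode(codes)
print(codes.shape)  # (4, 64)
print((codes != 0).sum(axis=1))  # each entry is at most 8
```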

Parameters:
  • input_dim – Dimension of the input embeddings.

  • hidden_dim – Dimension of the hidden layers. Defaults to 512.

  • k – Number of top values to keep in the final sparse representation. Defaults to 8.

  • k_aux – Number of top values to keep for auxiliary loss calculation. Defaults to 512.

  • normalize – Whether to apply layer normalization to the input embeddings. Defaults to False.

  • dead_threshold – Threshold for dead neurons. Neurons with non-zero activations below this threshold are considered dead. Defaults to 30.

SparseStaticEmbedding

class sentence_transformers.sparse_encoder.models.SparseStaticEmbedding(tokenizer: PreTrainedTokenizer, weight: torch.Tensor | None = None, frozen: bool = False)

SparseStaticEmbedding module for efficient sparse representations.

This lightweight module computes sparse representations by mapping input tokens to static weights, such as IDF (Inverse Document Frequency) weights. It is designed to encode queries or documents into fixed-size embeddings based on the presence of tokens in the input.
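The mapping described above can be sketched in NumPy as follows. The toy vocabulary, IDF values and helper name are hypothetical, a whitespace lookup stands in for the tokenizer, and treating repeated tokens as a single occurrence is an assumption based on the phrase "based on the presence of tokens"; this is not the module's actual code.

```python
import numpy as np

# Hypothetical toy vocabulary and IDF weights (shape: (vocab_size,))
vocab = {"what": 0, "is": 1, "sparse": 2, "retrieval": 3, "the": 4}
idf = np.array([0.5, 0.1, 2.3, 1.9, 0.05])


def static_sparse_embed(token_ids, weight):
    """Vocabulary-sized vector: weight[t] for each token t present in the input, 0 elsewhere."""
    embedding = np.zeros_like(weight)
    present = np.unique(token_ids)  # presence only, so repeated tokens are not counted twice
    embedding[present] = weight[present]
    return embedding


# Whitespace lookup stands in for a real tokenizer producing input IDs
ids = [vocab[word] for word in "what is sparse retrieval".split()]
query_embedding = static_sparse_embed(ids, idf)
print(query_embedding)
```

No neural network runs at encoding time: the cost is a tokenizer call plus a lookup, which is what makes the module lightweight.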

A common scenario is to use this module for encoding queries and a heavier module such as SPLADE (MLMTransformer + SpladePooling) for encoding documents.
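Because both encoders produce vocabulary-sized sparse vectors, a cheaply encoded query can be scored against a SPLADE-encoded document. The sketch below assumes a dot-product score and stores sparse vectors as {token_id: weight} dicts; the helper name, token IDs and weights are all made up for illustration. Only tokens present in both vectors contribute to the score.

```python
def sparse_dot(query: dict[int, float], doc: dict[int, float]) -> float:
    """Dot product of sparse vectors stored as {token_id: weight}; only shared token IDs contribute."""
    if len(query) > len(doc):
        query, doc = doc, query  # iterate over the smaller vector
    return sum(weight * doc[token_id] for token_id, weight in query.items() if token_id in doc)


# Query side: static (e.g. IDF) weights, computed without running a transformer
query_vec = {2: 2.3, 3: 1.9}  # hypothetical IDs for "sparse" and "retrieval"
# Document side: the kind of output a SPLADE encoder might produce, including expansion terms
doc_vec = {2: 1.2, 3: 0.8, 7: 0.6, 9: 0.3}
print(round(sparse_dot(query_vec, doc_vec), 2))  # 4.28
```

Documents can be encoded offline with the expensive model, so only the inexpensive query-side step happens at search time.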

Parameters:
  • tokenizer (PreTrainedTokenizer) – PreTrainedTokenizer to tokenize input texts into input IDs.

  • weight (torch.Tensor | None) – Static weights for vocabulary tokens (e.g., IDF weights), shape should be (vocab_size,). If None, initializes weights to a vector of ones. Default is None.

  • frozen (bool) – Whether the weights should be frozen (not trainable). Default is False.