Fine-Tuning Embedding Models with Unsloth

Introduction

We are finally at the point that many people have been waiting for: small LLMs have become quite powerful and can run on consumer GPUs. With good fine-tuning in a given domain, they even rival some of the best commercially available LLMs.

This combo of being runnable and fine-tuneable on consumer hardware is possible thanks to weight quantization and LoRA adapters, respectively.

This post fine-tunes a text embedding model with the unsloth and Sentence Transformers libraries. Specifically, we fine-tune a set of QLoRA adapters using a contrastive loss on a simple Question and Answer dataset.

The unsloth library

The unsloth library makes it both efficient and affordable to fine-tune transformer networks on consumer hardware.

Unsloth has an ocean of starter notebooks that make it easy for anyone to fine-tune relevant, modern LLMs. Many of the notebooks use quantization setups that even fit on 8GB GPUs. If you went back a few years and told people we'd be able to meaningfully fine-tune powerful, SoTA LLMs on such small cards, it would have sounded outlandish.

Most of their work focuses on fine-tuning decoder models, aka the LLM family of models. This makes sense given the high visibility and ever-increasing capabilities of generative networks.

While generative LLMs receive much attention, there is also the flip side of the architecture coin: encoder models. These are models like BERT that transform sentences into vector embeddings that capture semantic content and relationships.

Encoder models power incredibly useful tools like RAG. Despite the LLM hype, it is RAG engines that are the backbone of most LLM applications currently deployed in the wild.

RAG workhorses

RAG engines rely on text embedding models, aka the encoder side of transformer networks.

There is a great post here from the creators of the recent ModernBERT model that describes how LLMs capture all the hype and fanfare, while encoder models are the actual workhorses for AI products.
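To make this concrete, here is a minimal sketch of the retrieval step at the heart of a RAG engine, using the same all-MiniLM family we fine-tune later in this post (the query and documents are made-up examples):

# minimal sketch of embedding-based retrieval, the core of a RAG engine
from sentence_transformers import SentenceTransformer, util

# load a small off-the-shelf embedding model
retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")

# a toy corpus and query
docs = [
    "NVIDIA reported record data center revenue this quarter.",
    "The company repurchased shares under its buyback program.",
]
query = "How did the data center segment perform?"

# embed both sides and rank documents by cosine similarity
doc_embeddings = retriever.encode(docs, convert_to_tensor=True)
query_embedding = retriever.encode(query, convert_to_tensor=True)
scores = util.cos_sim(query_embedding, doc_embeddings)
print(docs[int(scores.argmax())])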

Unfortunately, as of writing, unsloth does not directly support fine-tuning encoder models. It's been a feature in their pipeline for a while, but they understandably have a ton of other pressing work.

We can still however leverage some recent PRs, along with the Sentence Transformers library, to patch fine-tuning embeddings into unsloth.

In this post, we will fine-tune an all-MiniLM model, specifically the recent all-MiniLM-L12-v2.

Fine-tuning embeddings with unsloth and Sentence Transformers

Let's describe the overall process we'll go through. We first take the all-MiniLM model and wrap it in unsloth's QLoRA adapters. Then we wrap the unsloth-patched model inside a custom Sentence Transformers model. It is this final double-wrapped model that we fine-tune.

Both Sentence Transformers and unsloth actually subclass HuggingFace's Trainer and TrainingArguments. Their APIs and functionality aren't quite identical, but are close enough for our purposes.

Unsloth will handle the QLoRA adapters that make it possible to fine-tune encoder models with a tiny fraction of the parameters that full fine-tuning would have taken.

Sentence Transformers will do the heavy lifting of the learning loop: preparing the input batches, computing the embeddings-specific loss, and handling the weight updates.

Let's get started and put all of this together. First, we need to prepare our environment.

Installing Unsloth

The following command installs unsloth:

pip install unsloth

Note that unsloth is under constant development. It directly patches and modifies many low-level libraries used for LLM inference and training. Because of this, it can be quite tricky to install. The default setup in their Google Colab notebooks does a specific pip installation dance that is quite handy.

You may have good luck with the simple pip install unsloth. Depending on your Linux setup, it might not be so simple. If the simple install fails, mimic the Colab-specific pip installation below.

# only do this if the simple pip install fails
pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
pip install --no-deps unsloth

Once this is ready, we can import unsloth and get started.

# import unsloth first so it can patch in optimizations
import unsloth
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!

Note that it's best practice to import unsloth before anything else. This lets it patch all the lower-level libraries that it needs. Then we'll be using the FastModel class. This class takes a very handy auto_model argument that lets us load actual encoder models.

# loads encoder models
from unsloth import FastModel

Now we can bring in all of our regular imports. We import all of them here for convenience.

# general imports
from pathlib import Path
import torch

# import the huggingface classes
from transformers import BertModel
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType

# import the sentence transformers classes
import sentence_transformers
from sentence_transformers import SentenceTransformerTrainingArguments, SentenceTransformerTrainer, SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator

We can now start setting the variables we'll need. As mentioned, we're using the all-MiniLM-L12-v2 model which is part of the BERT family.

# Model Configuration
BASE_MODEL_ID = 'sentence-transformers/all-MiniLM-L12-v2'
BERT_MODEL = BertModel

# Maximum sequence length of this model
MAX_SEQ_LENGTH = 512
LOAD_IN_4BIT = True  # For QLoRA (4-bit quantization)

Let's load this using FastModel. Note that this is the full model, before we've attached any QLoRA adapters.

# load the base model optimized with unsloth
model, tokenizer = FastModel.from_pretrained(
    model_name = BASE_MODEL_ID,
    auto_model = BERT_MODEL,
    max_seq_length = MAX_SEQ_LENGTH,
    dtype = None, # Auto-detects (BF16/FP16)
    load_in_4bit = LOAD_IN_4BIT,
)
print(f"Loaded {BASE_MODEL_ID} with Unsloth.")
==((====))==  Unsloth 2025.4.7: Fast Bert patching. Transformers: 4.51.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Loaded sentence-transformers/all-MiniLM-L12-v2 with Unsloth.

QLoRA patches

Next, we'll attach the QLoRA weights to be learned using unsloth. Unsloth has a whole set of good, hard-won default arguments for fine-tuning LLMs. From my initial experiments, it seems like some of these will need re-thinking for encoder models. But, they are certainly a solid starting point.

# LORA Configuration
LORA_R = 16          # Rank of the LORA matrices.
LORA_ALPHA = 32      # Rule of thumb: 2 * rank
LORA_DROPOUT = 0.0   # Dropout of 0 is best.
USE_RSLORA = False   # Rank-Stabilized LoRA if desired

# Target modules for adapters
LORA_TARGET_MODULES = ["query", "key", "value", "dense"]
LORA_EXCLUDE_MODULES = [] # put anything you want to skip here

With our QLoRA settings, we can then attach them to the base model.

print("Attaching QLoRA adapters...")
model = FastModel.get_peft_model(
    model,
    r = LORA_R,
    lora_alpha = LORA_ALPHA,
    lora_dropout = LORA_DROPOUT,
    target_modules = LORA_TARGET_MODULES,
    exclude_modules = LORA_EXCLUDE_MODULES,
    use_rslora = USE_RSLORA,
    bias = "none", # Standard practice for LoRA
    use_gradient_checkpointing = "unsloth",
    modules_to_save = None, # Add to train non-LORA modules
    task_type = TaskType.FEATURE_EXTRACTION, # Important!
)
print("LORA adapters added.")
Attaching QLoRA adapters...
Unsloth: Making `model.base_model.model.encoder` require gradients
LORA adapters added.

A key part here is the line task_type = TaskType.FEATURE_EXTRACTION, which prepares the model for the embedding loss.

We can see below how QLoRA only learns a fraction of the model's original parameters, making it feasible to run this training on regular consumer hardware instead of on massive GPU clusters.

# check how many parameters we will actually learn
model.print_trainable_parameters()
trainable params: 1,339,392 || all params: 34,699,392 || trainable%: 3.8600
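
As a quick sanity check, that percentage is just the ratio of adapter parameters to total parameters:

# the trainable percentage is adapter params over total params
trainable, total = 1_339_392, 34_699_392
print(f"{100 * trainable / total:.4f}%")  # ~3.8600%, matching the printout above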

Wrapping unsloth model with Sentence Transformers

To use a loss for embedding models, we need the Sentence Transformers library. Below we can follow the Sentence Transformers documentation to create a custom model.

First, we create a Transformer model. Then we manually patch in our QLoRA unsloth model and tokenizer.

We also need to tell Sentence Transformers how to convert the model's final output into an embedding. This is known as the pooling stage. There are several pooling techniques, but mean pooling has become the go-to choice for sentence embedding models. Mean pooling means we take the token-wise average of the network's final activations and call that collapsed, single vector the embedding.

Lastly, many models include a normalization stage. This determines whether or not we scale each vector to unit length. It's part of the original all-MiniLM pipeline, and in practice I've found it saves a lot of headache to always and only deal with normalized vectors.
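
For intuition, here is a minimal sketch (assuming a padded batch with an attention mask) of what the pooling and normalization stages do; the Sentence Transformers modules below handle this for us:

import torch
import torch.nn.functional as F

def mean_pool_and_normalize(token_embeddings, attention_mask):
    # token_embeddings: (batch, seq_len, hidden), attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).float()       # zero out padding tokens
    summed = (token_embeddings * mask).sum(dim=1)     # sum over the sequence
    counts = mask.sum(dim=1).clamp(min=1e-9)          # number of real tokens per sentence
    embeddings = summed / counts                      # token-wise average (mean pooling)
    return F.normalize(embeddings, p=2, dim=1)        # scale each vector to unit length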

With our three modules ready, we pass them into a SentenceTransformer instance. This creates the final model that can be used by the library's Trainer class.

Note: you can also pass in additional arguments here that would typically be passed to the underlying Hugging Face model, such as the attention implementation.
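
For example, the Transformer module accepts a model_args dict that is forwarded to the underlying Hugging Face from_pretrained call; something like the sketch below should work (treat it as a sketch and verify the exact argument names against your installed Sentence Transformers version):

# sketch: forwarding extra Hugging Face arguments through the Transformer module
transformer_module = sentence_transformers.models.Transformer(
    model_name_or_path=BASE_MODEL_ID,
    max_seq_length=MAX_SEQ_LENGTH,
    model_args={"attn_implementation": "sdpa"},  # assumption: passed through to the model load
)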

Phew. That's a lot. Let's write an annotated function to make this a bit clearer and our lives a bit easier.

# Prepare the ST model
def get_st_unsloth_wrapper(
        model,
        tokenizer,
        base_model_id=BASE_MODEL_ID,
        pooling_mode="mean",
        max_seq_length=MAX_SEQ_LENGTH,
        ):
    print("Initializing Sentence Transformer modules...")

    # 1. Create the Transformer module instance
    transformer_module = sentence_transformers.models.Transformer(
        model_name_or_path=base_model_id,
        max_seq_length=max_seq_length,
    )

    # 2. Replace the internal Hugging Face model with our LORA-patched Unsloth model
    transformer_module.auto_model = model
    transformer_module.tokenizer = tokenizer

    print(f"Manually assigned Unsloth LORA model to sentence_transformers.models.Transformer module.")

    # 3. Create the Pooling module
    hidden_size = model.config.hidden_size
    pooling_module = sentence_transformers.models.Pooling(
        word_embedding_dimension=hidden_size,
        pooling_mode=pooling_mode,
    )
    print(f"Using Pooling module with mode: {pooling_mode}")

    # 4. Add the Normalize module
    normalize_module = sentence_transformers.models.Normalize()
    modules = [transformer_module, pooling_module, normalize_module]

    # 5. Initialize SentenceTransformer with custom modules
    sbert_model = SentenceTransformer(modules=modules)

    print(f"SentenceTransformer wrapper created with custom modules.")
    return sbert_model

# wrap our unsloth model in Sentence Transformers
sbert_model = get_st_unsloth_wrapper(
        model,
        tokenizer,
        max_seq_length=MAX_SEQ_LENGTH,
        base_model_id=BASE_MODEL_ID,
        pooling_mode="mean",
  )
Initializing Sentence Transformer modules...
Manually assigned Unsloth LORA model to sentence_transformers.models.Transformer module.
Using Pooling module with mode: mean
SentenceTransformer wrapper created with custom modules.

Data preparation

We can now focus on the most important part of this whole process: the data. The data setup below is taken from Phil Schmid's excellent guide on fine-tuning embedding models for RAG applications. We mirror this setup because it's a fun, interesting dataset and because we're mainly focused on the unsloth and QLoRA pieces.

The main thing we have to do is properly format the data for the contrastive loss we will be using.

A proper deep dive into contrastive losses is far beyond the scope of this post. Here's an excellent blog post from Lilian Weng that teaches you all the basics (and then some) of these losses.

# from: https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-embedding-model-for-rag.ipynb

# prepare the NVIDIA financial dataset
from datasets import load_dataset

# Load dataset from the hub
dataset = load_dataset("philschmid/finanical-rag-embedding-dataset", split="train")

# rename columns
dataset = dataset.rename_column("question", "anchor")
dataset = dataset.rename_column("context", "positive")

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))

# split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)

# save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")

The key takeaway is that all the hard research into contrastive losses has paid off tremendously: we now have a loss called Multiple Negatives Ranking Loss (MNRL) that makes it possible to train embedding models with loosely, implicitly labeled data like Q&A pairs.

Questions and answers become pairs of reference (anchor) and matching (positive) texts that should be retrieved together.

For any one valid pair, the loss treats the positives from the other examples in the same mini-batch as negatives (so-called in-batch negatives), as in the sketch below.
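
Conceptually, the loss boils down to something like this simplified sketch (not the library implementation; it assumes normalized embeddings and hard-codes the usual scale factor):

import torch
import torch.nn.functional as F

def in_batch_mnrl(anchor_embs, positive_embs, scale=20.0):
    # similarity of every anchor against every positive in the batch
    scores = scale * anchor_embs @ positive_embs.T   # (batch, batch)
    # each anchor's true positive sits on the diagonal; every other column acts as a negative
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)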

This means all you need to start training an embedding model is a good set of Q&A pairs. With how ubiquitous and powerful this kind of data has become thanks to SFT and reasoning-based RL, you can see how we're very close to an insanely powerful data feedback loop.

And we can always do some work to improve this loss by picking or mining better negative examples. But it is pretty outrageous and fortunate how quickly we can set up fine-tuning of embedding models with the MNRL loss.

Let's go ahead and define this powerful loss function.

# define the loss function
loss = MultipleNegativesRankingLoss(sbert_model)
print(f"Using loss: {type(loss).__name__}")
Using loss: MultipleNegativesRankingLoss

The next step is to define and group up all of our training arguments. A rule of thumb is that LoRA can overfit if you train for too many epochs. So we start with just a few, but this is definitely a parameter to explore.

For the batch size, you should use the largest value that fits on your GPU. This is especially important for the MNRL loss, since it draws its negatives from the other examples in the same batch: the larger the batch, the more (and more varied) in-batch negatives each anchor is contrasted against.

## Preparing all of our training arguments

# Training Configuration
NUM_TRAIN_EPOCHS = 4                # Keep to a few epochs; LoRA can overfit with more
PER_DEVICE_TRAIN_BATCH_SIZE = 512   # Adjust based on GPU VRAM and MAX_SEQ_LENGTH.
PER_DEVICE_EVAL_BATCH_SIZE = 1024   # Can usually be higher than the train batch size.
GRADIENT_ACCUMULATION_STEPS = 1     # Increase on smaller cards to simulate larger batches

# don't repeat samples in the same batch given our loss
batch_sampler = BatchSamplers.NO_DUPLICATES if isinstance(loss, MultipleNegativesRankingLoss) else None

The rest of the training arguments are standard for unsloth models. However, as mentioned, QLoRA adapters for encoders are a relatively unexplored space. There are likely far more optimal values, but this is a good start.

# set lower for longer training runs
LEARNING_RATE = 2e-4

WARMUP_RATIO = 0.1                 # percent of warmup steps
OPTIMIZER = "adamw_torch_fused"    # fused AdamW; switch to an 8-bit optimizer (e.g. adamw_bnb_8bit) to save VRAM
LR_SCHEDULER_TYPE = "cosine"       # schedule for the lr
WEIGHT_DECAY = 0.1                 # Weight decay
FP16 = not torch.cuda.is_bf16_supported() # Use FP16 if BF16 is not available
BF16 = torch.cuda.is_bf16_supported()     # Use BF16 on supported GPUs (Ampere+) for stability.

Let's define how we'll evaluate the model. We'll also make an output directory where we'll save the fine-tuned model.

# set the output directory
OUTPUT_DIR = Path("finetuned_embeddings")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# evaluation and saving
EVAL_STEPS = 4           # evaluate every N steps
EVAL_STRATEGY = "steps"
SAVE_STEPS = 4           # save checkpoint every N steps
SAVE_STRATEGY = "steps"
SAVE_TOTAL_LIMIT = 2     # keep only the last N checkpoints
LOGGING_STEPS = 2        # log metrics every N steps

Once again, the specific evaluation setup is taken from Phil Schmid's notebook. This creates a simple evaluator that's meant to mirror how relevant embeddings should be retrieved in RAG applications.

# load test dataset
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert the datasets to dictionaries
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)  # Our corpus (cid => document)
queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)  # Our queries (qid => question)

# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids]))
for q_id in queries:
    relevant_docs[q_id] = [q_id]


evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    score_functions={"cosine": cos_sim},
    name="ir-eval"
)

With all the work done, we can now group up the training arguments.

print("Defining training arguments...")
args = SentenceTransformerTrainingArguments(
    # Core Training Parameters
    output_dir=str(OUTPUT_DIR),
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    optim=OPTIMIZER,
    batch_sampler=batch_sampler,
    fp16=FP16,
    bf16=BF16,
    tf32=True, # NOTE: requires a GPU with TF32 support (Ampere or newer)
    fp16_full_eval=True,
    # Evaluation and Saving
    eval_strategy=EVAL_STRATEGY,
    eval_steps=EVAL_STEPS,
    save_strategy=SAVE_STRATEGY,
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    load_best_model_at_end=True if evaluator else False,
    metric_for_best_model="eval_ir-eval_cosine_ndcg@10" if evaluator and isinstance(evaluator, InformationRetrievalEvaluator) else None,
    greater_is_better=True,
    # Logging and Reporting
    logging_steps=LOGGING_STEPS,
    report_to="tensorboard",
    run_name=f"{BASE_MODEL_ID.split('/')[-1]}-st-finetune",
    seed=42,
)
Defining training arguments...

Preparing the Trainer

We have everything we need to start training:

  • A model.
  • Training and evaluation datasets.
  • A loss function.

We can wrap all of these in a SentenceTransformerTrainer and off we go.

print("Initializing SentenceTransformerTrainer...")
trainer = SentenceTransformerTrainer(
    model=sbert_model, # Pass the standard SentenceTransformer model
    args=args,
    train_dataset=train_dataset.select_columns(["anchor", "positive"]),
    eval_dataset=test_dataset.select_columns(["anchor", "positive"]) if evaluator else None,
    loss=loss,
    evaluator=evaluator,
    callbacks=[],
)
Initializing SentenceTransformerTrainer...

Drumroll... and train the model!

# train the unsloth embeddings
train_res = trainer.train()
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,300 | Num Epochs = 4 | Total steps = 52
O^O/ \_/ \    Batch size per device = 512 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (512 x 1 x 1) = 512
 "-____-"     Trainable parameters = 1,339,392/24,008,832 (5.58% trained)
[52/52 05:30, Epoch 4/4]
Columns: Step | Training Loss | Validation Loss | ir-eval cosine Accuracy@{1,3,5,10} | Precision@{1,3,5,10} | Recall@{1,3,5,10} | NDCG@10 | MRR@10 | MAP@100
4 1.842900 1.183924 0.578571 0.721429 0.777143 0.831429 0.578571 0.240476 0.155429 0.083143 0.578571 0.721429 0.777143 0.831429 0.704366 0.663803 0.668477
8 1.389200 0.715329 0.618571 0.742857 0.801429 0.847143 0.618571 0.247619 0.160286 0.084714 0.618571 0.742857 0.801429 0.847143 0.731694 0.694814 0.699990
12 0.749600 0.507300 0.638571 0.764286 0.810000 0.867143 0.638571 0.254762 0.162000 0.086714 0.638571 0.764286 0.810000 0.867143 0.749741 0.712506 0.717353
16 0.657600 0.428953 0.645714 0.780000 0.824286 0.878571 0.645714 0.260000 0.164857 0.087857 0.645714 0.780000 0.824286 0.878571 0.760336 0.722725 0.727359
20 0.614300 0.382340 0.668571 0.790000 0.838571 0.891429 0.668571 0.263333 0.167714 0.089143 0.668571 0.790000 0.838571 0.891429 0.778032 0.741937 0.745949
24 0.559200 0.357284 0.668571 0.802857 0.847143 0.898571 0.668571 0.267619 0.169429 0.089857 0.668571 0.802857 0.847143 0.898571 0.784138 0.747524 0.750950
28 0.459600 0.345445 0.670000 0.815714 0.858571 0.900000 0.670000 0.271905 0.171714 0.090000 0.670000 0.815714 0.858571 0.900000 0.787243 0.750852 0.754066
32 0.555100 0.337076 0.675714 0.811429 0.864286 0.901429 0.675714 0.270476 0.172857 0.090143 0.675714 0.811429 0.864286 0.901429 0.790867 0.755160 0.758290
36 0.546900 0.331842 0.678571 0.815714 0.867143 0.901429 0.678571 0.271905 0.173429 0.090143 0.678571 0.815714 0.867143 0.901429 0.793002 0.757874 0.761037
40 0.359800 0.328072 0.680000 0.820000 0.868571 0.901429 0.680000 0.273333 0.173714 0.090143 0.680000 0.820000 0.868571 0.901429 0.793829 0.758917 0.762097
44 0.484400 0.327019 0.678571 0.818571 0.868571 0.901429 0.678571 0.272857 0.173714 0.090143 0.678571 0.818571 0.868571 0.901429 0.793415 0.758348 0.761538
48 0.459900 0.326337 0.680000 0.818571 0.870000 0.901429 0.680000 0.272857 0.174000 0.090143 0.680000 0.818571 0.870000 0.901429 0.794044 0.759169 0.762369
52 0.428900 0.326476 0.680000 0.818571 0.865714 0.901429 0.680000 0.272857 0.173143 0.090143 0.680000 0.818571 0.865714 0.901429 0.793913 0.759026 0.762210

Unsloth: Will smartly offload gradients to save VRAM!

Let's save the model to disk so we can use it later.

# save the fine-tuned model for persistence
sbert_model.save_pretrained(str(OUTPUT_DIR))

Comparison with the original model

We need to evaluate the model to know if this whole process actually improved it. We'll compare it against the original model that had no QLoRA adapters. The small snippet below loads up fresh versions of both models, puts them into eval mode, and runs the evaluator on them.

# load the baseline
original_model = SentenceTransformer(BASE_MODEL_ID)
original_model.eval()

# load the fine-tuned model
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
fine_tuned_model.eval()

# evaluate both models
with torch.inference_mode():
  baselines = evaluator(original_model)
  fine_tuned_results = evaluator(fine_tuned_model)

# print their scores
print(f"Original model: {baselines}")
print(f"Fine-tuned model: {fine_tuned_results}")
/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py:167: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
  warnings.warn(
Original model: {'ir-eval_cosine_accuracy@1': 0.5985714285714285, 'ir-eval_cosine_accuracy@3': 0.7271428571428571, 'ir-eval_cosine_accuracy@5': 0.7814285714285715, 'ir-eval_cosine_accuracy@10': 0.8442857142857143, 'ir-eval_cosine_precision@1': 0.5985714285714285, 'ir-eval_cosine_precision@3': 0.24238095238095236, 'ir-eval_cosine_precision@5': 0.15628571428571425, 'ir-eval_cosine_precision@10': 0.08442857142857142, 'ir-eval_cosine_recall@1': 0.5985714285714285, 'ir-eval_cosine_recall@3': 0.7271428571428571, 'ir-eval_cosine_recall@5': 0.7814285714285715, 'ir-eval_cosine_recall@10': 0.8442857142857143, 'ir-eval_cosine_ndcg@10': 0.7169950659589105, 'ir-eval_cosine_mrr@10': 0.6768259637188209, 'ir-eval_cosine_map@100': 0.682233373628609}
Fine-tuned model: {'ir-eval_cosine_accuracy@1': 0.65, 'ir-eval_cosine_accuracy@3': 0.7914285714285715, 'ir-eval_cosine_accuracy@5': 0.8414285714285714, 'ir-eval_cosine_accuracy@10': 0.9042857142857142, 'ir-eval_cosine_precision@1': 0.65, 'ir-eval_cosine_precision@3': 0.26380952380952377, 'ir-eval_cosine_precision@5': 0.16828571428571426, 'ir-eval_cosine_precision@10': 0.09042857142857141, 'ir-eval_cosine_recall@1': 0.65, 'ir-eval_cosine_recall@3': 0.7914285714285715, 'ir-eval_cosine_recall@5': 0.8414285714285714, 'ir-eval_cosine_recall@10': 0.9042857142857142, 'ir-eval_cosine_ndcg@10': 0.7753242341249659, 'ir-eval_cosine_mrr@10': 0.734347505668934, 'ir-eval_cosine_map@100': 0.7375553017174123}

Let's focus on one of the more useful retrieval metrics: NDCG@10. How did the baseline do?

# baseline result
baselines['ir-eval_cosine_ndcg@10']
0.7169950659589105

Now let's check the fine-tuned model.

# fine-tuned results
fine_tuned_results['ir-eval_cosine_ndcg@10']
0.7753242341249659

Great! We improved quite a bit on our baseline! Just how much better are we?

# how much better is the fine-tune?
fine_tuned_results['ir-eval_cosine_ndcg@10'] / baselines['ir-eval_cosine_ndcg@10']
1.0813522588025706

We got an 8% improvement in performance on this metric. That's a pretty solid gain for something that took just over 5 minutes to train. And we could fit this whole process on even an extremely low-end consumer GPU. We created a fine-tuned QLoRA embedding model that solidly outperforms its baseline.

Conclusion

This post showed how we can fine-tune encoders with QLoRA adapters using the unsloth and Sentence Transformers libraries. We trained the models with a tiny fraction of the parameters that full fine-tuning would have otherwise taken.

QLoRA for encoder models is a pretty under-explored area. It is likely that many of the parameters above are not optimal, but a proper sweep and ablation is beyond this post. I mainly wanted to share a way to reliably fine-tune encoder models with unsloth.

Finally, encoder models are usually significantly smaller than their LLM counterparts. This gives us a nice two-for-one, where we can use incredibly large batch sizes during fine-tuning. And because the MNRL loss uses the other examples in the same batch as negatives, larger batches give it far more varied negatives to learn from.