diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index 77eca7de75..6d0ba9b5ff 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -54,6 +54,37 @@ lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).
+### CorDA
+
+[CorDA](https://arxiv.org/pdf/2406.05223) builds task-aware LoRA adapters from a weight decomposition oriented by the context of either the downstream task to learn (instruction-previewed mode, IPM) or the world knowledge to maintain (knowledge-preserved mode, KPM).
+KPM not only achieves better performance than LoRA on fine-tuning tasks, but also mitigates the catastrophic forgetting of pre-trained world knowledge.
+When preserving pre-trained knowledge is not a concern, IPM is favored because it further accelerates convergence and enhances fine-tuning performance.
+
+To use CorDA, set the initialization method to `"corda"`, choose the mode (IPM or KPM), and provide the dataset used to collect the covariance matrices.
+
+```py
+import torch
+from peft import LoraConfig, get_peft_model
+from peft.tuners.lora.config import CordaConfig
+from peft.tuners.lora.corda import preprocess_corda
+
+
+@torch.no_grad()
+def run_model():
+    # Assume `model` and `dataset` are already defined
+    model.eval()
+    for batch in dataset:
+        model(**batch)
+
+
+corda_config = CordaConfig(
+ corda_method="kpm",
+)
+lora_config = LoraConfig(
+ init_lora_weights="corda",
+ corda_config=corda_config,
+)
+preprocess_corda(model, lora_config, run_model=run_model)
+peft_model = get_peft_model(model, lora_config)
+```
+
+For detailed instruction on using CorDA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/corda_finetuning).
+
### OLoRA
[OLoRA](https://arxiv.org/abs/2406.01775) utilizes QR decomposition to initialize the LoRA adapters. OLoRA translates the base weights of the model by a factor of their QR decompositions, i.e., it mutates the weights before performing any training on them. This approach significantly improves stability, accelerates convergence speed, and ultimately achieves superior performance.
diff --git a/examples/corda_finetuning/README.md b/examples/corda_finetuning/README.md
new file mode 100644
index 0000000000..c248e99ae1
--- /dev/null
+++ b/examples/corda_finetuning/README.md
@@ -0,0 +1,246 @@
+# CorDA: Context-Oriented Decomposition Adaptation of Large Language Models for Task-Aware Parameter-Efficient Fine-tuning
+
+## Introduction
+
+
+Existing PEFT methods are mostly agnostic to the context of the task of concern, e.g., a downstream task to learn or some pre-trained world knowledge to maintain.
+[CorDA](https://openreview.net/pdf?id=Gi00NVru6n) builds task-aware LoRA adapters from a weight decomposition oriented by the context of the task concerned.
+
+Concretely, CorDA randomly collects a few data samples (256 by default in our `preprocess.py`) from a target task, e.g., questions from a QA dataset or instructions to write code or solve a math problem, and feeds these samples into a pre-trained LLM. This yields the covariance matrix of the input activation of each linear layer, i.e., $C=XX^T\in\mathcal{R}^{d_{in}\times d_{in}}$.
+We then perform singular value decomposition (SVD) on the weight $W\in \mathcal{R}^{d_{out}\times d_{in}}$ multiplied by the covariance matrix, i.e., $\verb|SVD|(WC) = U\Sigma V^T$. In this way, the context expressed by these representative covariance matrices orients the decomposition, such that the principal components (the singular vectors with the largest singular values) are most associated with the task of concern (please refer to Fig. 2 of our paper for the advantage of this decomposition over plain SVD). To ensure the same inference result as the pre-trained model at the start of adaptation, we multiply the inverse of these covariance matrices with the decomposed components, i.e., $\hat{W}=U\Sigma V^T C^{-1}$.
+
+Thanks to this task awareness, you can choose how to utilize the task-specific principal components. For example, if you want to adapt a model to a new task without losing the knowledge of a question-answering dataset, e.g., TriviaQA or NQ open, you can sample questions from that dataset to collect covariance matrices, and keep the principal components frozen because they compact the ability on this dataset, while using the lowest components with the smallest $r$ singular values to initialize the learnable LoRA adapters. This is achieved by the **knowledge-preserved mode (KPM)** of CorDA, which learns new tasks effectively while keeping the world knowledge you care about as sound as possible. Alternatively, when your primary objective is to maximize performance on the fine-tuning task, disregarding the preservation of world knowledge, the **instruction-previewed mode (IPM)** is favored. In this mode, CorDA uses the instructions and responses from the fine-tuning task (e.g., math or code) to produce the covariance matrices. The principal components with the largest $r$ singular values, which capture the characteristics of the fine-tuning task in advance, can better adapt to the new ability, so they are used to initialize the LoRA adapters, with the remaining components frozen. IPM further accelerates convergence and enhances fine-tuning performance on downstream tasks.
+
+
+The implementations of KPM and IPM are compared as follows:
+
+| Mode | Collect covariance from | LoRA $A$ | LoRA $B$ |
+|---|---|---|---|
+| KPM | questions from the knowledge benchmark to maintain | $A=\sqrt{\Sigma}\_{[-r:]}(V^T C^{-1})\_{[-r:,:]}$ | $B=U_{[:,-r:]}\sqrt{\Sigma}_{[-r:]}$ |
+| IPM | instructions and responses from the downstream task to learn | $A= \sqrt{\Sigma}\_{[:r]} (V^T C^{-1})\_{[:r,:]}$ | $B =U_{[:,:r]} \sqrt{\Sigma}_{[:r]}$ |
+
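+Below is a minimal, self-contained sketch (toy shapes and random data, not the PEFT implementation) of how the decomposition in the table above can be computed for a single linear layer; the damping term added to $C$ is an illustrative detail to keep it invertible:
+
+```py
+import torch
+
+torch.manual_seed(0)
+d_in, d_out, r, num_samples = 16, 32, 4, 64  # toy sizes, illustrative only
+
+W = torch.randn(d_out, d_in)        # pre-trained weight of one linear layer
+X = torch.randn(d_in, num_samples)  # collected input activations for this layer
+C = X @ X.T                         # covariance matrix, shape (d_in, d_in)
+C += 1e-4 * torch.eye(d_in)         # small damping so that C is invertible (illustrative)
+
+U, S, Vt = torch.linalg.svd(W @ C, full_matrices=False)
+W_hat = U @ torch.diag(S) @ Vt @ torch.linalg.inv(C)  # approximately reconstructs W
+
+# KPM: freeze the principal components, learn the r smallest ones
+A_kpm = torch.diag(S[-r:].sqrt()) @ (Vt @ torch.linalg.inv(C))[-r:, :]
+B_kpm = U[:, -r:] @ torch.diag(S[-r:].sqrt())
+
+# IPM: learn the r largest components, freeze the rest
+A_ipm = torch.diag(S[:r].sqrt()) @ (Vt @ torch.linalg.inv(C))[:r, :]
+B_ipm = U[:, :r] @ torch.diag(S[:r].sqrt())
+
+print(torch.dist(W, W_hat))  # small reconstruction error
+```
+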
+### Comparison with alternative methods
+
+The distinction between CorDA and other similar LoRA initialization methods is summarized as follows:
+
+| Method | Initialization for | SVD on | Data-driven | Supports knowledge maintenance |
+| - | - | - | - | - |
+| PiSSA | $A$ and $B$ | weights | no | no |
+| EVA | $A$ | activations | yes | no |
+|CorDA | $A$ and $B$ | weights (oriented by covariance) | yes | yes |
+
+"Supports knowledge maintenance" denotes the ability to explicitly associate a knowledge benchmark with some components of the pre-trained weights after decomposition, and to keep these components frozen during fine-tuning.
+
+### Some Results
+
+- Performance with knowledge-preserved mode (sample from NQopen, fine-tune on Math)
+
+| Method | Model | NQ open | GSM8k | Math | Avg. |
+|---|---|---|---|---|---|
+|Pre-trained|Llama-2-7b| 14.99 | -| - | - |
+|LoRA|Llama-2-7b|1.27| 42.68 | 5.88 | 16.61 |
+|**CorDA (KPM)** |Llama-2-7b| **8.20** | **46.32** | **7.00** | **20.51** |
+|Pre-trained|Llama-2-13b|23.63|-|-|-|
+|LoRA|Llama-2-13b| 16.26 | 57.24 | 8.92 | 27.47 |
+|**CorDA (KPM)** |Llama-2-13b| **19.86** | **59.29** | **9.62** | **29.59** |
+|Pre-trained|Llama-3-8b|13.41|-|-|-|
+|LoRA|Llama-3-8b| 8.75 | 72.33 | 24.04| 35.04 |
+|**CorDA (KPM)** |Llama-3-8b| **9.61** | **74.68** | **25.34** | **36.54** |
+|Pre-trained|Gemma-2-9b|12.85|-|-|-|
+|LoRA|Gemma-2-9b| 9.28 | 83.47 | 42.30| 45.02 |
+|**CorDA (KPM)** |Gemma-2-9b|**10.17** | **84.08** | **42.64** | **45.63** |
+
+- Performance with instruction-previewed mode (sample from Math, fine-tune on Math)
+
+| Method | Model | GSM8k | Math |
+| --- | --- | --- | ---|
+|LoRA| Llama-2-7b | 42.68 | 5.88 |
+|PiSSA | Llama-2-7b | 51.63 | 7.32 |
+| **CorDA (IPM)** | Llama-2-7b | **53.45** | **8.64** |
+|LoRA| Llama-2-13b | 57.24 | 8.92 |
+|PiSSA | Llama-2-13b |60.88 | 11.08|
+| **CorDA (IPM)** | Llama-2-13b | **62.47** |**11.54** |
+|LoRA| Gemma-2-9b | 83.47 | 42.30 |
+|PiSSA | Gemma-2-9b | 84.23 | 43.52|
+| **CorDA (IPM)** | Gemma-2-9b | **84.45** | **43.88** |
+
+
+## Quick Start
+
+### Knowledge-preserved adaptation mode
+
+```py
+import torch
+from peft import LoraConfig, get_peft_model
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft.tuners.lora.config import CordaConfig
+from peft.tuners.lora.corda import preprocess_corda
+from trl import SFTConfig, SFTTrainer
+from datasets import load_dataset
+
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+tokenizer.pad_token_id = tokenizer.eos_token_id
+sampled_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:256]")
+dataset = load_dataset("imdb", split="train[:256]")
+
+
+def run_model():
+    for batch in sampled_dataset:
+        # `batch["text"]` is raw text, so tokenize it before feeding it to the model
+        input_ids = tokenizer(
+            batch["text"], return_tensors="pt", truncation=True, max_length=256
+        ).input_ids.to(model.device)
+        with torch.no_grad():
+            model(input_ids)
+
+
+corda_config = CordaConfig(
+ corda_method="kpm",
+)
+lora_config = LoraConfig(
+ init_lora_weights="corda",
+ corda_config=corda_config,
+)
+preprocess_corda(model, lora_config, run_model=run_model)
+peft_model = get_peft_model(model, lora_config)
+peft_model.print_trainable_parameters()
+
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
+trainer = SFTTrainer(
+ model=peft_model,
+ args=training_args,
+ train_dataset=dataset,
+ tokenizer=tokenizer,
+)
+trainer.train()
+peft_model.save_pretrained("corda-llama-2-7b")
+```
+
+### Instruction-previewed adaptation mode
+
+```py
+# Get the model and dataset in the same way as for KPM...
+
+# Different from KPM, we run the model on the dataset of the downstream task to collect covariance matrices
+def run_model():
+    for batch in dataset:
+        # `batch["text"]` is raw text, so tokenize it before feeding it to the model
+        input_ids = tokenizer(
+            batch["text"], return_tensors="pt", truncation=True, max_length=256
+        ).input_ids.to(model.device)
+        with torch.no_grad():
+            model(input_ids)
+
+# Different from KPM, we set `corda_method` to `"ipm"`
+corda_config = CordaConfig(
+ corda_method="ipm",
+)
+
+# The rest of the training process is identical to KPM...
+```
+
+## Advanced Usage
+
+### Preprocessing
+
+`preprocess.py`: This script builds CorDA adapters for a model and saves the adapter's initial weights, together with the residual model weights, to a specified directory. Example usage:
+
+#### Knowledge-preserved adaptation mode
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
+ --r 128 --seed 233 \
+ --save_model --save_path {path_to_residual_model} \
+ --calib_dataset "nqopen"
+```
+Arguments:
+
+- `--model_id` is the pre-trained model for decomposition.
+- `--r` is the LoRA rank, e.g. 128.
+- `--calib_dataset` specifies the dataset from which samples are drawn to obtain the covariance matrices. KPM uses QA datasets such as `"nqopen"`, `"traivia_qa"`, or other choices.
+- `--save_model` saves the initialized model in `--save_path`.
+
+#### Instruction-previewed adaptation mode
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
+ --r 128 --seed 233 \
+ --save_model --save_path {path_to_residual_model} \
+ --first_eigen --calib_dataset "MetaMATH"
+```
+
+Arguments:
+
+- `--first_eigen` uses the largest $r$ singular values and vectors to initialize the learnable adapter for the instruction-previewed adaptation mode.
+- `--calib_dataset` specifies the dataset from which samples are drawn to obtain the covariance matrices. Instruction-previewed mode uses the dataset of the downstream task you are learning, such as `"MetaMATH"`, `"codefeedback"`, `"WizLMinstruct"`, `"alpaca"`, or other choices.
+
+#### Note about memory consumption
+
+Collecting covariance matrices is performed in `torch.float32` by default. If you would like to reduce the memory consumption of preprocessing, you can set `use_float16_for_covariance=True` in `CordaConfig` to collect the covariance matrices in `torch.float16`. In a few cases this may cause numerical instability, so the initialized model is no longer guaranteed to produce exactly the same inference results as the original model. If you choose `torch.float16`, it is therefore suggested to verify the initialization, e.g., by comparing Wiki/PTB perplexity before and after preprocessing.
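+
+As a sketch, the relevant `CordaConfig` options could be combined as follows; the covariance file path is an illustrative placeholder:
+
+```py
+corda_config = CordaConfig(
+    corda_method="kpm",
+    use_float16_for_covariance=True,  # halves covariance memory; verify model outputs afterwards
+    covariance_file="cache/llama-2-7b-nqopen-cov.pt",  # optional: reuse covariance matrices across ranks
+)
+```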
+
+### Fine-tuning
+
+`corda_finetuning.py`: This script fine-tunes the preprocessed model built above on a downstream task.
+
+Example usage:
+
+```bash
+python corda_finetuning.py \
+ --model_name_or_path {path_to_residual_model} \
+ --output_dir {path_to_output_model} \
+ --corda_mode True \
+ --data_path meta-math/MetaMathQA \
+ --dataset_split "train[:100000]" \
+ --dataset_field query response \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 32 \
+ --save_strategy "steps" \
+ --save_steps 100 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --bf16 True \
+ --tf32 True \
+ --report_to none
+```
+
+### Convert CorDA to LoRA
+
+The main advantage of CorDA lies in the training phase. After training, we recommend converting the CorDA adapter into an equivalent LoRA adapter for sharing and deployment.
+
+```python
+# The fine-tuned matrices $A$ and $B$ in the CorDA adapter are saved and should be combined with the residual model.
+peft_model.save_pretrained(output_dir)
+# Given the matrices $A_0$ and $B_0$, initialized by CorDA and untrained, and the trained matrices $A$ and $B$,
+# we can convert these to LoRA by setting $\Delta W = A \times B - A_0 \times B_0 = [A \mid A_0] \times [B \mid -B_0]^T = A'B'$.
+peft_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion="corda_init")
+```
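+
+As a sanity check of the identity in the comment above (written in the same $A \times B$ convention), here is a small, self-contained sketch with toy shapes verifying that the concatenated factors reproduce $\Delta W$:
+
+```py
+import torch
+
+d_out, d_in, r = 8, 6, 2  # toy dimensions, illustrative only
+A, B = torch.randn(d_out, r), torch.randn(r, d_in)    # trained factors
+A0, B0 = torch.randn(d_out, r), torch.randn(r, d_in)  # untrained CorDA-initialized factors
+
+delta_w = A @ B - A0 @ B0
+A_prime = torch.cat([A, A0], dim=1)   # (d_out, 2r)
+B_prime = torch.cat([B, -B0], dim=0)  # (2r, d_in)
+assert torch.allclose(delta_w, A_prime @ B_prime, atol=1e-6)  # a rank-2r LoRA reproduces the update
+```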
+
+This conversion enables the loading of LoRA on top of a standard base model:
+
+```python
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
+)
+# No SVD is performed during this step, and the base model remains unaltered.
+peft_model = PeftModel.from_pretrained(model, "corda-llama-2-7b-lora")
+```
+
+Utilizing the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, and adapters can be freely added or removed.
+
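+For example, a second converted LoRA can be loaded next to the first one and switched at will (the adapter name and path below are hypothetical):
+
+```py
+peft_model.load_adapter("corda-llama-2-7b-math-lora", adapter_name="math")
+peft_model.set_adapter("math")     # activate the newly loaded adapter
+peft_model.set_adapter("default")  # switch back to the first adapter
+```
+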
+Note that this conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
+
+## Citation
+```
+@inproceedings{yangcorda,
+ title={CorDA: Context-Oriented Decomposition Adaptation of Large Language Models for Task-Aware Parameter-Efficient Fine-tuning},
+ author={Yang, Yibo and Li, Xiaojie and Zhou, Zhongzhu and Song, Shuaiwen Leon and Wu, Jianlong and Nie, Liqiang and Ghanem, Bernard},
+ booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
+ year={2024},
+}
+```
\ No newline at end of file
diff --git a/examples/corda_finetuning/corda_finetuning.py b/examples/corda_finetuning/corda_finetuning.py
new file mode 100644
index 0000000000..8312e100ad
--- /dev/null
+++ b/examples/corda_finetuning/corda_finetuning.py
@@ -0,0 +1,275 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Sequence
+
+import torch
+import transformers
+from datasets import load_dataset
+from transformers import Trainer
+
+from peft import LoraConfig, PeftModel, get_peft_model
+
+
+IGNORE_INDEX = -100
+
+PROMPT = (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+)
+
+
+def get_nb_trainable_parameters(model) -> tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and the number of all parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ num_params = param.ds_numel
+
+ # Due to the design of 4bit linear layers from bitsandbytes
+ # one needs to multiply the number of parameters by 2 to get
+ # the correct number of parameters
+ if param.__class__.__name__ == "Params4bit":
+ num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
+ num_params = num_params * 2 * num_bytes
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+
+ return trainable_params, all_param
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+    dataset_split: str = field(default="train[:100000]", metadata={"help": "Dataset split to use, e.g. `train[:100000]`."})
+ dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})
+ dataloader_num_proc: int = field(default=16, metadata={"help": "Number of processes to load dataset"})
+ dataloader_batch_size: int = field(
+ default=3000,
+ metadata={
+ "help": "batch size to load dataset. To set the batch size for training, you should pass --batch_size argument instead."
+ },
+ )
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=512,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ lora_r: int = field(
+ default=None,
+ metadata={"help": "The rank of LoRA adapter. When passing `None`, CorDA or full fine-tuning is used."},
+ )
+ corda_mode: bool = field(default=True, metadata={"help": "True for CorDA mode"})
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+ ]
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ "input_ids_lens": input_ids_lens,
+ "labels_lens": labels_lens,
+ }
+
+
+def preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocess the data by tokenizing."""
+ examples = [s + t for s, t in zip(sources, targets)]
+ examples_tokenized, sources_tokenized = (_tokenize_fn(strings, tokenizer) for strings in (examples, sources))
+ input_ids = examples_tokenized["input_ids"]
+ labels = copy.deepcopy(input_ids)
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+ label[:source_len] = IGNORE_INDEX
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ }
+
+
+@dataclass
+class DataCollatorForSupervisedDataset:
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+ input_ids = [torch.tensor(x) for x in input_ids]
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = [torch.tensor(x) for x in labels]
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": input_ids.ne(self.tokenizer.pad_token_id),
+ }
+
+
+def train_tokenize_function(examples, tokenizer, query, response):
+ sources = [
+ PROMPT.format_map(
+ {
+ "instruction": instruction,
+ }
+ )
+ for instruction in examples[query]
+ ]
+ targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
+ data_dict = preprocess(sources, targets, tokenizer)
+ return data_dict
+
+
+def train():
+ parser = transformers.HfArgumentParser(TrainingArguments)
+ script_args = parser.parse_args_into_dataclasses()[0]
+ print(script_args)
+
+ if script_args.corda_mode:
+ print("Train in CorDA mode")
+ res_model = transformers.AutoModelForCausalLM.from_pretrained(
+ script_args.model_name_or_path,
+ device_map="auto",
+ )
+ model = PeftModel.from_pretrained(
+ res_model, script_args.model_name_or_path, subfolder="corda_init", is_trainable=True
+ )
+ elif script_args.lora_r is not None:
+ print("Train in LoRA mode")
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ script_args.model_name_or_path,
+ device_map="auto",
+ )
+ lora_config = LoraConfig(
+ r=script_args.lora_r,
+ lora_alpha=script_args.lora_r,
+ init_lora_weights=True, # script_args.init_lora_weights,
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+ lora_dropout=0,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, lora_config)
+ else:
+ print("Train in Full Finetuning mode")
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+ script_args.model_name_or_path,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ trainable_params, all_param = get_nb_trainable_parameters(model)
+ print(
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
+ )
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ script_args.model_name_or_path,
+ model_max_length=script_args.model_max_length,
+ padding_side="right",
+ use_fast=True,
+ trust_remote_code=True,
+ )
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ raw_train_datasets = load_dataset(script_args.data_path, split=script_args.dataset_split)
+ train_dataset = raw_train_datasets.map(
+ train_tokenize_function,
+ batched=True,
+ batch_size=script_args.dataloader_batch_size,
+ num_proc=script_args.dataloader_num_proc,
+ remove_columns=raw_train_datasets.column_names,
+ load_from_cache_file=True,
+ desc="Running tokenizer on train dataset",
+ fn_kwargs={
+ "tokenizer": tokenizer,
+ "query": script_args.dataset_field[0],
+ "response": script_args.dataset_field[1],
+ },
+ )
+
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ data_module = {
+ "train_dataset": train_dataset,
+ "data_collator": data_collator,
+ }
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=script_args, **data_module)
+ trainer.train()
+ trainer.save_state()
+ model.save_pretrained(os.path.join(script_args.output_dir, "ft"))
+
+
+if __name__ == "__main__":
+ train()
diff --git a/examples/corda_finetuning/datautils.py b/examples/corda_finetuning/datautils.py
new file mode 100644
index 0000000000..a01b5500a6
--- /dev/null
+++ b/examples/corda_finetuning/datautils.py
@@ -0,0 +1,237 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+
+import numpy as np
+import torch
+from datasets import load_dataset
+
+
+"""
+doc https://huggingface.co/docs/datasets/loading
+doc https://huggingface.co/docs/datasets/process
+doc https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+"""
+
+
+def set_seed(seed):
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+
+def sample_train_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048):
+ set_seed(seed)
+ if "wikitext2" in name:
+ traindata = load_dataset(
+ "wikitext",
+ "wikitext-2-raw-v1",
+ split="train",
+ )
+ traindata = "\n\n".join(traindata["text"])
+ elif "c4" in name:
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ )
+ traindata = "\n\n".join(traindata["text"])
+ else:
+ raise NotImplementedError
+
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, len(traindata) - seqlen * 2 - 1)
+ j = i + seqlen * 2
+ trainenc = tokenizer(traindata[i:j], return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ trainloader.append(inp)
+ return trainloader
+
+
+def get_redpajama_train(tokenizer, percent=10, seed=3, batch_size=128, max_length=2048):
+ def tokenization(example):
+ return tokenizer(example["text"], truncation=True, max_length=max_length)
+
+ if percent != 100:
+ split = f"train[:{int(850000*percent/100)}]"
+ else:
+ split = "train"
+ dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split=split)
+
+ processed_dataset = dataset.map(tokenization, batched=True, batch_size=batch_size, num_proc=os.cpu_count())
+ return processed_dataset
+
+
+def get_english_quote(dataset_name, tokenizer):
+ data = load_dataset(dataset_name)
+ data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+ return data["train"]
+
+
+def get_qat_dataset(name, tokenizer, data_percent):
+ if name == "red_pajama":
+ data = get_redpajama_train(tokenizer, data_percent)
+
+ elif name == "Abirate/english_quotes":
+ data = get_english_quote(name, tokenizer)
+ else:
+ raise NotImplementedError
+ data = data.shuffle()
+ return data
+
+
+llama_chat_format = """[INST] <<SYS>>
+"Below is an instruction that describes a task. Write a response that appropriately completes the request."
+<</SYS>>
+
+{instruction} [/INST] {response}
+"""
+
+
+def get_calib_data(name, tokenizer, model_id, nsamples, seqlen=2048, seed=3):
+ print(f" get_data_from: {name}, nsamples={nsamples}, seqlen={seqlen}, {seed}")
+ cache_file = f"cache/{name}_{model_id.replace('/','_')}_{nsamples}_{seqlen}_{seed}.pt"
+ if not os.path.exists("cache"):
+ os.makedirs("cache")
+ if os.path.exists(cache_file):
+ print(f"found data file: {cache_file}")
+ traindataset = torch.load(cache_file)
+ print("loaded ...")
+ return traindataset
+ if name == "c4":
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ )
+ tot_text = "\n\n".join(traindata["text"])
+ elif name == "wikitext2":
+ traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+ tot_text = "\n\n".join(traindata["text"])
+ elif name == "ptb":
+ traindata = load_dataset(
+ "ptb_text_only",
+ "penn_treebank",
+ split="train",
+ )
+ tot_text = "\n\n".join(traindata["sentence"])
+ elif name == "traivia_qa":
+ traindata = load_dataset("trivia_qa", "rc", split="train")
+ tot_text = "\n\n".join(traindata["question"])
+ elif name == "nqopen":
+ traindata = load_dataset("nq_open", split="train")
+ tot_text = "\n\n".join(traindata["question"])
+ elif name == "alpaca":
+ selected_data_dict = load_dataset("iboing/alpaca_data", split="train").shuffle(seed=seed).take(nsamples)
+ traindataset = []
+ for example in selected_data_dict:
+ if example.get("input", "") == "":
+ s = llama_chat_format.format(instruction=example["instruction"], response=example["output"])
+ trainenc = tokenizer(s, return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ attention_mask = torch.ones_like(inp)
+ traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
+ print("example instruction:", s)
+ torch.save(traindataset, cache_file)
+ return traindataset
+ elif name == "MetaMATH":
+ selected_data_dict = load_dataset("iboing/MetaMathQA-395K", split="train").shuffle(seed=seed).take(nsamples)
+ traindataset = []
+ for example in selected_data_dict:
+ if example.get("input", "") == "":
+ s = llama_chat_format.format(instruction=example["query"], response=example["response"])
+ trainenc = tokenizer(s, return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ attention_mask = torch.ones_like(inp)
+ traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
+ print("example instruction:", s)
+ torch.save(traindataset, cache_file)
+ return traindataset
+ elif name == "codefeedback":
+ selected_data_dict = (
+ load_dataset("iboing/CodeFeedback-Filtered-Instruction", split="train").shuffle(seed=seed).take(nsamples)
+        )
+        traindataset = []
+        for example in selected_data_dict:
+ if example.get("input", "") == "":
+ s = llama_chat_format.format(instruction=example["query"], response=example["answer"])
+ trainenc = tokenizer(s, return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ attention_mask = torch.ones_like(inp)
+ traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
+ print("example instruction:", s)
+ torch.save(traindataset, cache_file)
+ return traindataset
+ elif name == "WizLMinstruct":
+ selected_data_dict = (
+ load_dataset("iboing/WizardLM_evol_instruct_V2_143k", split="train").shuffle(seed=seed).take(nsamples)
+        )
+        traindataset = []
+        for example in selected_data_dict:
+ if example.get("input", "") == "":
+ s = llama_chat_format.format(
+ instruction=example["conversation"][0]["human"], response=example["conversation"][0]["assistant"]
+ )
+ trainenc = tokenizer(s, return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ attention_mask = torch.ones_like(inp)
+ traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
+ print("example instruction:", s)
+ torch.save(traindataset, cache_file)
+ return traindataset
+ else:
+ raise NotImplementedError
+ print(f"tot_text={len(tot_text)}")
+ traindataset = []
+ for _ in range(nsamples):
+ i = random.randint(0, len(tot_text) - seqlen - 1)
+ j = i + seqlen * 10
+ trainenc = tokenizer(tot_text[i:j], return_tensors="pt")
+ inp = trainenc.input_ids[:, :seqlen]
+ attention_mask = torch.ones_like(inp)
+ traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
+ torch.save(traindataset, cache_file)
+ return traindataset
+
+
+def get_eval_loaders(name, tokenizer):
+ if "wikitext2" in name:
+ testdata = load_dataset(
+ "wikitext",
+ "wikitext-2-raw-v1",
+ split="test",
+ )
+ testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
+ return testenc
+ if "ptb" in name:
+ valdata = load_dataset(
+ "ptb_text_only",
+ "penn_treebank",
+ split="validation",
+ )
+ testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
+ return testenc
+ if "c4" in name:
+ testdata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+ split="validation",
+ )
+ testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
+ return testenc
+ raise NotImplementedError
diff --git a/examples/corda_finetuning/preprocess.py b/examples/corda_finetuning/preprocess.py
new file mode 100644
index 0000000000..01721d296e
--- /dev/null
+++ b/examples/corda_finetuning/preprocess.py
@@ -0,0 +1,162 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import numpy as np
+import torch
+from datautils import get_calib_data
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from peft.mapping import get_peft_model
+from peft.tuners.lora.config import CordaConfig, LoraConfig
+from peft.tuners.lora.corda import preprocess_corda
+
+
+@torch.no_grad()
+def run_model(model, calib_loader):
+ model.eval()
+ for batch in tqdm(calib_loader):
+ batch = {k: v.to(model.device) for k, v in batch.items()}
+ model(**batch)
+
+
+def main(args):
+ # Setting random seed of numpy and torch
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ torch.backends.cudnn.deterministic = True
+
+ # Load model
+ model_id = args.model_id
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
+ )
+
+ # Collect data
+ calib_loader = get_calib_data(args.calib_dataset, tokenizer, model_id, args.calib_loader_size, seed=args.seed)
+
+    # Print the original model structure
+ print("\n---- model before svd ---\n")
+ print(model)
+
+ # Perform decomposition
+ corda_config = CordaConfig(
+ corda_method="ipm" if args.first_eigen else "kpm",
+ )
+ lora_config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+ r=args.r,
+ lora_alpha=args.r,
+ corda_config=corda_config,
+ )
+ preprocess_corda(
+ model,
+ lora_config,
+ run_model=lambda: run_model(model, calib_loader),
+ )
+ model = get_peft_model(model, lora_config)
+
+    # Print the model again to check that the adapter layers are in place after decomposition
+    print("\n---- model after svd ---\n")
+    print(model)
+
+ # Save as hugging face model
+ if args.save_model:
+ assert args.save_path is not None
+ save_path = args.save_path
+
+ # Save CorDA modules
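+        # Reset `init_lora_weights` to True so that loading this initial adapter later does not
+        # subtract the decomposed values again based on the residual model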
+ model.peft_config["default"].init_lora_weights = True
+ model.save_pretrained(os.path.join(save_path, "corda_init"))
+
+ # Save residual model
+ model = model.unload()
+ model.save_pretrained(save_path)
+
+ # Save tokenizer
+ tokenizer.save_pretrained(save_path)
+ print(f"Done building CorDA huggingface model in {save_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_id",
+ type=str,
+ default="meta-llama/Llama-2-7b-hf",
+ help="Pretrained model ID",
+ )
+ parser.add_argument(
+ "--calib_loader_size",
+ type=int,
+ default=256,
+ help="number of samples used for covariance matrices",
+ )
+ parser.add_argument(
+ "--calib_dataset",
+ type=str,
+ default="wikitext2",
+ choices=[
+ "wikitext2",
+ "c4",
+ "ptb",
+ "traivia_qa",
+ "nqopen",
+ "MetaMATH",
+ "codefeedback",
+ "WizLMinstruct",
+ "alpaca",
+ ],
+ help="calibration dataset",
+ )
+ parser.add_argument(
+ "--eval_mmlu",
+ action="store_true",
+ help="evaluate mmlu",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=233,
+ help="random seed",
+ )
+ parser.add_argument(
+ "--r",
+ type=int,
+ default=None,
+ )
+ parser.add_argument(
+ "--first_eigen",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--save_model",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ default=None,
+ )
+ args = parser.parse_args()
+
+ main(args)
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index a48570fe7e..6f7caaed06 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -253,13 +253,13 @@ def save_pretrained(
Whether the process calling this is the main process or not. Will default to `True`. Will not save the
checkpoint if not on the main process, which is important for multi device setups (e.g. DDP).
path_initial_model_for_weight_conversion (`str, *optional*`):
- The path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA
- and before performing any training. When `path_initial_model_for_weight_conversion` is not None, the
- difference in adapter before and after fine-tuning is calculated. This difference can be represented as
- the parameters of a standard LoRA adapter. Using this converted adapter does not require changes to the
- base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with LoRA adapters,
- and the activation or deactivation of any adapters. Note that this conversion is not supported if
- `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
+ The path to the initialized adapter, which is obtained after initializing the model with
+ PiSSA/CorDA/OLoRA and before performing any training. When `path_initial_model_for_weight_conversion`
+ is not None, the difference in adapter before and after fine-tuning is calculated. This difference can
+ be represented as the parameters of a standard LoRA adapter. Using this converted adapter does not
+ require changes to the base model, thus conveniently allowing the use of multiple PiSSA/CorDA/OLoRA
+ adapters with LoRA adapters, and the activation or deactivation of any adapters. Note that this
+ conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the `push_to_hub` method.
@@ -288,10 +288,11 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
raise ValueError(msg)
if not any(
- str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"]
+ str(peft_config.init_lora_weights).lower().startswith(prefix)
+ for prefix in ["pissa", "corda", "olora", "true"]
):
warnings.warn(
- "`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to "
+ "`path_initial_model_for_weight_conversion` only works for converting a PiSSA/CorDA/OLoRA adapter to "
"a LoRA adapter"
)
initial_adapter_name = os.path.basename(path_initial_model_for_weight_conversion)
@@ -302,8 +303,9 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
adapter_name=initial_adapter_name,
)
is_pissa = str(self.peft_config[initial_adapter_name].init_lora_weights).lower().startswith("pissa")
+ is_corda = str(self.peft_config[initial_adapter_name].init_lora_weights).lower() == "corda"
is_olora = str(self.peft_config[initial_adapter_name].init_lora_weights).lower() == "olora"
- if is_pissa or is_olora:
+ if is_pissa or is_corda or is_olora:
raise ValueError(
"The `init_lora_weights` parameter of the initial adapter should be set to `True`. "
"Otherwise, `self.load_adapter` will subtract the decomposed values again based on the "
@@ -1164,7 +1166,7 @@ def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool)
if peft_config.is_prompt_learning and is_trainable:
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
- # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters.
+ # Since PiSSA/CorDA/OLoRA modifies the base weights, it should not be combined with other adapters.
all_configs = [peft_config] + list(self.peft_config.values())
if len(all_configs) > 1:
if any(getattr(config, "init_lora_weights", None) == "pissa" for config in all_configs):
@@ -1174,6 +1176,13 @@ def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool)
"https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora"
)
warnings.warn(msg)
+ elif any(getattr(config, "init_lora_weights", None) == "corda" for config in all_configs):
+ msg = (
+ "CorDA changes the base weights of the model and should thus not be used with other adapters. "
+ "Consider converting the CorDA adapter into a normal LoRA adapter: "
+ "https://github.com/huggingface/peft/tree/main/examples/corda_finetuning#convert-corda-to-lora"
+ )
+ warnings.warn(msg)
elif any(getattr(config, "init_lora_weights", None) == "olora" for config in all_configs):
msg = (
"OLoRA changes the base weights of the model and should thus not be used with other adapters. "
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 051d59e3e6..a62a3420da 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -120,6 +120,77 @@ def __post_init__(self):
raise ValueError("`tau` must be between 0.0 and 1.0.")
+@dataclass
+class CordaConfig:
+ """
+ This is the sub-configuration class to store the configuration of a [`LoraModel`].
+
+ Args:
+ cache_file (`Optional[str]`):
+ File to store the SVD cache. The SVD cache is much smaller than the residual model (for example, residual
+ model of Llama-3-8b is 15GB, while SVD cache is 1.4GB), but with SVD cache and original model weights,
+ residual model weights can be built quickly. If you need to reuse residual model weights with limited
+ storage, you can store the SVD cache instead.
+ covariance_file (`Optional[str]`):
+ File to store the covariance matrix. If you wish to train multiple models with different ranks, but they
+ sample from the same dataset, you can store the covariance matrix and reuse it for different ranks. Note
+ that covariance file is usually large (comparable to model size), so you will need sufficient storage.
+ corda_method (`Literal["ipm", "kpm"]`):
+ Method to build adapter. The KPM (Knowledge-Preserved Mode) not only achieves better performance than LoRA
+ on fine-tuning tasks, but also mitigates the catastrophic forgetting of pre-trained world knowledge. When
+ preserving pre-trained knowledge is not a concern, the IPM (Instruction-Previewed Mode) is favored because
+ it can further accelerate convergence and enhance the fine-tuning performance. Defaults to `'ipm'`.
+ verbose (`bool`):
+ If true, prints the progress of CorDA initialization. Defaults to `False`.
+ use_float16_for_covariance (`bool`):
+ If true, uses float16 for the covariance matrix. This can reduce the memory usage of the covariance matrix
+ by half, but may lead to numerical instability. Defaults to `False`.
+ """
+
+ cache_file: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "File to store the SVD cache. The SVD cache is much smaller than the residual model (for example, "
+ "residual model of Llama-3-8b is 15GB, while SVD cache is 1.4GB), but with SVD cache and original model "
+ "weights, residual model weights can be built quickly. If you need to reuse residual model weights with "
+ "limited storage, you can store the SVD cache instead."
+ )
+ },
+ )
+ covariance_file: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "File to store the covariance matrix. If you wish to train multiple models with different ranks, but "
+ "they sample from the same dataset, you can store the covariance matrix and reuse it for different ranks. "
+ "Note that covariance file is usually large (comparable to model size), so you will need sufficient storage."
+ )
+ },
+ )
+ corda_method: Literal["ipm", "kpm"] = field(
+ default="ipm",
+ metadata={
+ "help": (
+ "Method to build adapter. The KPM not only achieves better performance than LoRA on fine-tuning tasks, but "
+ "also mitigates the catastrophic forgetting of pre-trained world knowledge. When preserving pre-trained "
+ "knowledge is not a concern, the IPM is favored because it can further accelerate convergence and enhance "
+ "the fine-tuning performance."
+ )
+ },
+ )
+ verbose: bool = field(default=False, metadata={"help": "If true, prints the progress of CorDA initialization."})
+ use_float16_for_covariance: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "If true, uses float16 for the covariance matrix. This can reduce the memory usage of the covariance matrix "
+ "by half, but may lead to numerical instability."
+ )
+ },
+ )
+
+
@dataclass
class LoraConfig(PeftConfig):
"""
@@ -157,7 +228,7 @@ class LoraConfig(PeftConfig):
Otherwise, it will use the original default value of `lora_alpha/r`.
modules_to_save (`List[str]`):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
- init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]`):
+ init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]`):
How to initialize the weights of the adapter layers. Passing True (default) results in the default
initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian
initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to
@@ -171,7 +242,10 @@ class LoraConfig(PeftConfig):
leading to further enhancements. Passing `'pissa_niter_[number of iters]'` initiates Fast-SVD-based PiSSA
initialization, where `[number of iters]` indicates the number of subspace iterations to perform FSVD, and
must be a nonnegative integer. When `[number of iters]` is set to 16, it can complete the initialization of
- a 7B model within seconds, and the training effect is approximately equivalent to using SVD.
+ a 7B model within seconds, and the training effect is approximately equivalent to using SVD. Passing
+ `'corda'` results in the initialization of Context-Oriented
+ Decomposition Adaptation, which converges even more rapidly than PiSSA in Instruction-Previewed Mode,
+ and preserves world knowledge better than LoRA in Knowledge-Preserved Mode.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
@@ -199,6 +273,9 @@ class LoraConfig(PeftConfig):
eva_config (`Optional[EvaConfig]`):
The configuration of EVA. At a minimum the dataset argument needs to be set (use the same dataset as for
finetuning).
+ corda_config (`Optional[CordaConfig]`):
+ The configuration of CorDA. If this is not None, then CorDA will be used to build the adapter layers. Also
+ pass `init_lora_weights='corda'`.
use_dora (`bool`):
Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
@@ -265,7 +342,7 @@ class LoraConfig(PeftConfig):
},
)
init_lora_weights: (
- bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
+ bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
) = field(
default=True,
metadata={
@@ -279,6 +356,7 @@ class LoraConfig(PeftConfig):
"Passing `'pissa'` results in PiSSA initialization."
"Passing `'pissa_niter_[number of iters]'` initiates Fast-SVD-based PiSSA initialization, "
"where [number of iters] indicates the number of subspace iterations to perform fsvd, and must be a nonnegative integer."
+ "Passing `'corda'` results in CorDA initialization."
"Pass `'loftq'` to use LoftQ initialization"
),
},
@@ -361,6 +439,15 @@ class LoraConfig(PeftConfig):
)
},
)
+ corda_config: Optional[CordaConfig] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The configuration of CorDA. If this is passed, then CorDA will be used to build the adapter layers. "
+ "Also set `init_lora_weights='corda'` in this case."
+ )
+ },
+ )
use_dora: bool = field(
default=False,
metadata={
@@ -461,6 +548,14 @@ def __post_init__(self):
elif self.init_lora_weights != "eva" and self.eva_config is not None:
warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.")
+ elif self.init_lora_weights == "corda" and self.corda_config is None:
+ warnings.warn(
+ "`init_lora_weights` is 'corda' but `corda_config` is not specified. Using default CorDA config."
+ )
+ self.corda_config = CordaConfig()
+ elif self.init_lora_weights != "corda" and self.corda_config is not None:
+ warnings.warn("`corda_config` specified but will be ignored when `init_lora_weights` is not 'corda'.")
+
if self.lora_bias:
if self.init_lora_weights not in (True, False):
raise ValueError(
@@ -470,7 +565,7 @@ def __post_init__(self):
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")
- # Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
+    # Using post training conversion of modified base weights to restore their initial values (PiSSA, CorDA, OLoRA) cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
# this when they'll eventually call save_pretrained (i.e. if they'll pass
# path_initial_model_for_weight_conversionl). Therefore, we only warn but don't raise an error here.
@@ -480,11 +575,12 @@ def __post_init__(self):
and (
(isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
or (self.init_lora_weights == "olora")
+ or (self.init_lora_weights == "corda")
)
):
msg = (
"Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion of modified "
- "base weights (PiSSA, OLoRA) means that you won't be able to pass "
+                "base weights (PiSSA, CorDA, OLoRA) means that you won't be able to pass "
"`path_initial_model_for_weight_conversion` to `save_pretrained` to restore the initial values of the "
"base weights; if you intend to do this, please ensure not to use rslora or rank_pattern/alpha_pattern."
)
diff --git a/src/peft/tuners/lora/corda.py b/src/peft/tuners/lora/corda.py
new file mode 100644
index 0000000000..0d1d70b1a8
--- /dev/null
+++ b/src/peft/tuners/lora/corda.py
@@ -0,0 +1,359 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Reference code: https://github.com/iboing/CorDA/blob/main/cordalib/decomposition.py
+# Reference paper: https://arxiv.org/abs/2406.05223
+
+import os
+from typing import Any, Callable, Iterable, Optional
+
+import torch
+import torch.nn as nn
+from dataclasses import dataclass
+from tqdm import tqdm
+
+from peft.tuners.lora.config import LoraConfig
+from peft.tuners.lora.model import LoraModel
+from peft.utils.other import get_pattern_key
+
+
+@dataclass
+class CordaEigens:
+ S_WC: torch.Tensor
+ U_WC: torch.Tensor
+ V_WC: torch.Tensor
+
+
+def target_modules(model: nn.Module, config: LoraConfig) -> Iterable[tuple[str, nn.Module]]:
+    """
+    Iterate over (name, module) pairs of CorDA target modules in a model. A module is a target if its name matches
+    `config.target_modules` and it is an `nn.Linear`.
+    """
+ for name, module in model.named_modules():
+ if LoraModel._check_target_module_exists(config, name) and isinstance(module, nn.Linear):
+ yield name, module
+
+
+def get_model_device(model: nn.Module) -> str:
+ if hasattr(model, "module"): # Handle DeepSpeed/DataParallel
+ model = model.module
+ return next(iter(model.parameters())).device.type
+
+
+@torch.no_grad()
+def preprocess_corda(
+ model: nn.Module,
+ lora_config: LoraConfig,
+ run_model: Optional[Callable[[], None]] = None,
+ hooked_model: Optional[nn.Module] = None,
+):
+ """
+ Build necessary CorDA fields for a model.
+
+ Args:
+ model (`nn.Module`):
+ Model to preprocess.
+ lora_config (`LoraConfig`):
+ Lora configuration of the model. `lora_config.corda_config` should be set.
+ run_model (`Optional[Callable[[], None]]`):
+ Callback to run the model when building covariance. Typically you should run model inference on your sample
+ dataset in this callback. Experiments have shown 256 samples to be a good default dataset size. `run_model`
+ can be `None` only if covariance file in `lora_config.corda_config` is already created.
+ hooked_model (`Optional[nn.Module]`):
+ Model to hook when building covariance. If none, original model will be hooked. This is only useful when
+ you want to hook a different model than the one you are training, typically you should leave this `None`.
+
+ Upon completion, the following fields are set for each target module:
+ corda_method (`Literal["ipm", "kpm"]`):
+ CorDA method to apply. "ipm" for Instruction-Previewed Mode, "kpm" for Knowledge-Preserved Mode.
+ rank (`int`):
+ Rank of CorDA to apply.
+ eigens.S_WC (`torch.Tensor`):
+ Singular values of the weight matrix.
+ eigens.U_WC (`torch.Tensor`):
+ Left singular vectors of the weight matrix.
+ eigens.V_WC (`torch.Tensor`):
+ Right singular vectors of the weight matrix, multiplied by inverse of covariance matrix.
+ """
+ cache_file = lora_config.corda_config.cache_file
+ covariance_file = lora_config.corda_config.covariance_file
+ corda_method = lora_config.corda_config.corda_method
+ verbose = lora_config.corda_config.verbose
+
+ # If cache exists, skip building
+ if cache_file is not None and os.path.exists(cache_file) and os.path.getsize(cache_file) > 0:
+ cache = torch.load(cache_file, map_location=get_model_device(model))
+ for name, module in target_modules(model, lora_config):
+ module.corda_method = cache[f"{name}.corda_method"]
+ module.rank = cache[f"{name}.rank"]
+ module.eigens = CordaEigens(
+ S_WC=cache[f"{name}.eigens.S_WC"],
+ U_WC=cache[f"{name}.eigens.U_WC"],
+ V_WC=cache[f"{name}.eigens.V_WC"],
+ )
+ else:
+ # Specify CorDA method for each layer
+ if corda_method is None:
+ raise ValueError("corda_method is required when cache_file is not provided.")
+ for name, module in target_modules(model, lora_config):
+ module.corda_method = corda_method
+
+ # Specify CorDA rank for each layer
+ for name, module in target_modules(model, lora_config):
+ r_key = get_pattern_key(lora_config.rank_pattern.keys(), name)
+ module.rank = lora_config.rank_pattern.get(r_key, lora_config.r)
+
+ # Calculate covariance matrix
+ calib_cov_distribution(model, lora_config, run_model, hooked_model, covariance_file)
+
+ # Calculate eigens
+ collect_eigens(model, lora_config, verbose)
+
+ # Crop CorDA eigens so that there's less to save
+ crop_corda_eigens(model, lora_config)
+
+ # Save cache to disk
+ if cache_file is not None:
+ cache: dict[str, Any] = {}
+ for name, module in target_modules(model, lora_config):
+ cache[f"{name}.corda_method"] = module.corda_method
+ cache[f"{name}.rank"] = module.rank
+ cache[f"{name}.eigens.S_WC"] = module.eigens.S_WC
+ cache[f"{name}.eigens.U_WC"] = module.eigens.U_WC
+ cache[f"{name}.eigens.V_WC"] = module.eigens.V_WC
+
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ torch.save(cache, cache_file)
+
+
+@torch.no_grad()
+def calib_cov_distribution(
+ model: nn.Module,
+ config: LoraConfig,
+ run_model: Optional[Callable[[], None]],
+ hooked_model: Optional[nn.Module],
+ covariance_file: Optional[str],
+):
+ if covariance_file is not None and os.path.exists(covariance_file) and os.path.getsize(covariance_file) > 0:
+ all_covariance_matrix = torch.load(covariance_file, map_location=get_model_device(model))
+ for name, module in target_modules(model, config):
+ module.covariance_matrix = all_covariance_matrix[name]
+ return
+
+ if run_model is None:
+ raise ValueError("run_model must be specified when covariance file and cache file aren't built.")
+ if hooked_model is None:
+ hooked_model = model
+ hooked_model.eval()
+
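+ # Each forward pass through a hooked module accumulates statistics on that module: the input's X^T X is
+ # added to `covariance_matrix`, the input mean and std are accumulated, and `sample_count` is incremented,
+ # so the totals can be averaged once all calibration batches have been run.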
+ def hook(module, input, output):
+ input = input[0].detach().squeeze(0).data ## (context_length = 2048, dim)
+ if not config.corda_config.use_float16_for_covariance:
+ input = input.float()
+ input = input / torch.max(input).abs()
+
+ # check if input is valid
+ if torch.isnan(input).any() or torch.isinf(input).any():
+ raise ValueError("Invalid value found in input, please check your input data.")
+
+ # calculate covariance and check if it's valid
+ covariance = input.t().matmul(input)
+ if torch.isnan(covariance).any() or torch.isinf(covariance).any():
+ raise ValueError(
+ "Invalid value found in covariance. Please file an issue at https://github.com/huggingface/peft/issues."
+ )
+
+ # calculate mean and std
+ mean = input.mean(0)
+ std = input.std(0)
+
+ # add to module
+ module.sample_count += 1
+ module.covariance_matrix += covariance
+ module.mean += mean
+ module.std += std
+
+ # free memory
+ del covariance, input
+
+ handles = []
+ for name, module in target_modules(hooked_model, config):
+ module.sample_count = 0
+ module.covariance_matrix = 0
+ module.mean = 0
+ module.std = 0
+ handles.append(module.register_forward_hook(hook))
+
+ run_model()
+
+ # Clear the hooks
+ for handle in handles:
+ handle.remove()
+
+ # In some edge cases you might need to hook a model different from the one you add adapters to;
+ # in that case, specify `hooked_model` and set it to a different model than `model`.
+ if hooked_model is not model:
+ targets = {}
+ for name, module in target_modules(model, config):
+ targets[name] = module
+ for name, module in target_modules(hooked_model, config):
+ # There can be modules that are only used in inference but not in training.
+ # Exclude modules not present in the target model to prevent a KeyError in this case.
+ if name in targets:
+ targets[name].sample_count = module.sample_count
+ targets[name].covariance_matrix = module.covariance_matrix
+ targets[name].mean = module.mean
+ targets[name].std = module.std
+
+ # Divide by sample count
+ for name, module in target_modules(model, config):
+ module.covariance_matrix /= module.sample_count
+ module.mean /= module.sample_count
+ module.std /= module.sample_count
+
+ # Save covariance to disk
+ if covariance_file is not None:
+ all_covariance_matrix = {}
+ for name, module in target_modules(model, config):
+ all_covariance_matrix[name] = module.covariance_matrix
+ os.makedirs(os.path.dirname(covariance_file), exist_ok=True)
+ torch.save(all_covariance_matrix, covariance_file)
+
+
+@torch.no_grad()
+def collect_eigens(
+ model: nn.Module,
+ config: LoraConfig,
+ verbose: bool,
+):
+ """Call collect_eigens_for_layer and store result in key `eigens` of each layer."""
+ linear_modules = []
+ for name, module in target_modules(model, config):
+ linear_modules.append((name, module))
+ if verbose:
+ linear_modules = tqdm(linear_modules, desc="Collecting eigens")
+ for name, module in linear_modules:
+ module.eigens = collect_eigens_for_layer(module, config)
+
+
+@torch.no_grad()
+def collect_eigens_for_layer(
+ linear: nn.Linear,
+ config: LoraConfig,
+) -> CordaEigens:
+ w = linear.weight.data.float()
+ out_dim = w.size(0)
+ in_dim = w.size(1)
+ min_dim = min(in_dim, out_dim)
+
+ if not hasattr(linear, "covariance_matrix"):
+ raise ValueError(
+ "Covariance matrix not found in linear module. Please do not call this function directly, "
+ "instead call `preprocess_corda`. If your usage is correct but this error still encounters, "
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ covariance_matrix = linear.covariance_matrix.float()
+
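+ # Damp the covariance before inversion: add a fraction of its mean diagonal value to the diagonal, and keep
+ # doubling the damping factor until the computed inverse reproduces the identity within the error threshold.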
+ damp = 0.01
+ while True:
+ compensate = torch.diag(
+ torch.ones(covariance_matrix.size(0)).to(covariance_matrix.device)
+ * torch.mean(torch.diag(covariance_matrix))
+ * damp
+ )
+ fix_covariance_matrix = covariance_matrix + compensate
+ cov_inv = torch.linalg.inv(fix_covariance_matrix)
+ inv_error = torch.dist(
+ fix_covariance_matrix @ cov_inv, torch.eye(covariance_matrix.size(0)).to(get_model_device(linear))
+ ).item()
+ if inv_error < 0.05:
+ break
+ else:
+ damp = damp * 2
+ w = w @ fix_covariance_matrix ## w: out_dim, in_dim; covariance_matrix: in_dim, in_dim
+
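+ # SVD of the covariance-weighted weight W @ C; folding C^{-1} into V below means U @ diag(S) @ V.T
+ # reconstructs the original weight (up to the damping added above).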
+ U, S, Vh = torch.linalg.svd(w, full_matrices=False)
+ V = (Vh @ cov_inv).transpose(0, 1)
+
+ # Sanity check: U and V are temporarily large here; they are cropped to the target rank later in `crop_corda_eigens`
+ r = min_dim
+ if U.size(0) != out_dim or U.size(1) != r:
+ raise ValueError(
+ f"Matrix U size mismatch: {U.size()} vs. ({out_dim}, {r}), "
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if S.size(0) != r:
+ raise ValueError(
+ f"Matrix S size mismatch: {S.size()} vs. ({r},), "
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if V.size(0) != in_dim or V.size(1) != r:
+ raise ValueError(
+ f"Matrix V size mismatch: {V.size()} vs. ({in_dim}, {r}), "
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+
+ # Offload U and V to CPU, as they consume too much memory
+ U = U.cpu()
+ V = V.cpu()
+ return CordaEigens(
+ S_WC=S,
+ U_WC=U,
+ V_WC=V,
+ )
+
+
+@torch.no_grad()
+def crop_corda_eigens(model: nn.Module, config: LoraConfig):
+ for name, module in target_modules(model, config):
+ # Saving a sliced tensor also writes the whole underlying tensor to disk,
+ # so it's necessary to clone the tensors.
+ # Reference: https://github.com/pytorch/pytorch/issues/40157
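+ # `torch.linalg.svd` returns singular values in descending order, so IPM keeps the leading `rank`
+ # components (largest singular values) and KPM keeps the trailing `rank` components (smallest ones).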
+ if module.corda_method == "ipm":
+ module.eigens.S_WC = module.eigens.S_WC[: module.rank].clone()
+ module.eigens.U_WC = module.eigens.U_WC[:, : module.rank].clone().to(get_model_device(model))
+ module.eigens.V_WC = module.eigens.V_WC[:, : module.rank].clone().to(get_model_device(model))
+ elif module.corda_method == "kpm":
+ module.eigens.S_WC = module.eigens.S_WC[-module.rank :].clone()
+ module.eigens.U_WC = module.eigens.U_WC[:, -module.rank :].clone().to(get_model_device(model))
+ module.eigens.V_WC = module.eigens.V_WC[:, -module.rank :].clone().to(get_model_device(model))
+ else:
+ raise ValueError(f"Invalid corda_method found: {module.corda_method}, it should be 'ipm' or 'kpm'.")
+
+ # Sanity check
+ if module.eigens.S_WC.size(0) != module.rank:
+ raise ValueError(
+ f"rank mismatch: {module.eigens.S_WC.size(0)} vs. {module.rank},"
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if module.eigens.U_WC.size(0) != module.weight.size(0):
+ raise ValueError(
+ f"U size mismatch: {module.eigens.U_WC.size(0)} vs. {module.weight.size(0)},"
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if module.eigens.U_WC.size(1) != module.rank:
+ raise ValueError(
+ f"U size mismatch: {module.eigens.U_WC.size(1)} vs. {module.rank},"
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if module.eigens.V_WC.size(0) != module.weight.size(1):
+ raise ValueError(
+ f"V size mismatch: {module.eigens.V_WC.size(0)} vs. {module.weight.size(1)},"
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if module.eigens.V_WC.size(1) != module.rank:
+ raise ValueError(
+ f"V size mismatch: {module.eigens.V_WC.size(1)} vs. {module.rank},"
+ "please file an issue at https://github.com/huggingface/peft/issues."
+ )
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 8965c5eefa..9cce04ccf8 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -140,6 +140,9 @@ def update_layer(
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
self.pissa_init(adapter_name, init_lora_weights)
+ elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
+ with gather_params_ctx(self.get_base_layer().weight):
+ self.corda_init(adapter_name, init_lora_weights)
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
with gather_params_ctx(self.get_base_layer().weight):
self.olora_init(adapter_name)
@@ -266,6 +269,77 @@ def pissa_init(self, adapter_name, init_lora_weights):
weight = transpose(weight.to(dtype), self.fan_in_fan_out)
self.get_base_layer().weight.data = weight
+ def corda_init(self, adapter_name, init_lora_weights):
+ linear = self.get_base_layer()
+ weight = linear.weight
+ dtype = weight.dtype
+ if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+ raise TypeError(
+ "Please initialize CorDA under float32, float16, or bfloat16. "
+ "Subsequently, re-quantize the residual model to help minimize quantization errors."
+ )
+ weight = weight.to(torch.float32)
+ out_dim = weight.data.size(0)
+ in_dim = weight.data.size(1)
+
+ # Retrieve the decomposition of WC computed from the covariance matrix by `preprocess_corda`
+ if not hasattr(linear, "eigens"):
+ raise ValueError(
+ "`eigens` attribute not found for layer, please run `preprocess_corda` first. "
+ "More information can be found at examples/corda_finetuning/README.md."
+ )
+ eigens = linear.eigens
+ U = eigens.U_WC
+ S = eigens.S_WC
+ V = eigens.V_WC
+ r = self.r[adapter_name]
+
+ # nan or inf check
+ if torch.isnan(S).any() or torch.isinf(S).any():
+ raise ValueError(
+ "Invalid value found in matrix S. Please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if torch.isnan(U).any() or torch.isinf(U).any():
+ raise ValueError(
+ "Invalid value found in matrix U. Please file an issue at https://github.com/huggingface/peft/issues."
+ )
+ if torch.isnan(V).any() or torch.isinf(V).any():
+ raise ValueError(
+ "Invalid value found in matrix V. Please file an issue at https://github.com/huggingface/peft/issues."
+ )
+
+ # Sanity check
+ if U.size(0) != out_dim or U.size(1) != r:
+ raise ValueError(
+ f"Matrix U size mismatch: {U.size()} vs. ({out_dim}, {r}). Please make sure the `lora_config` and "
+ "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
+ "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
+ )
+ if S.size(0) != r:
+ raise ValueError(
+ f"Matrix S size mismatch: {S.size()} vs. ({r},). Please make sure the `lora_config` and `model` argument "
+ "of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache in `preprocess_corda`, "
+ "please make sure the cache is built with the same model and LoRA rank."
+ )
+ if V.size(0) != in_dim or V.size(1) != r:
+ raise ValueError(
+ f"Matrix V size mismatch: {V.size()} vs. ({in_dim}, {r}). Please make sure the `lora_config` and "
+ "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
+ "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
+ )
+
+ # Apply alpha
+ S /= self.scaling[adapter_name]
+
+ # Init lora_A and lora_B weights
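+ # lora_A = diag(sqrt(S)) @ V.T and lora_B = U @ diag(sqrt(S)); since S was divided by the scaling factor
+ # above, scaling * lora_B @ lora_A equals the decomposed component subtracted from the base weight below,
+ # so the freshly initialized adapter leaves the model output unchanged.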
+ lora_A = V.t().mul(S.sqrt().view(-1, 1)).contiguous()
+ lora_B = U.mul(S.sqrt()).contiguous()
+ self.lora_A[adapter_name].weight.data = lora_A
+ self.lora_B[adapter_name].weight.data = lora_B
+ weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
+ weight = weight.to(dtype)
+ self.get_base_layer().weight.data = weight
+
def loftq_init(self, adapter_name):
from peft.utils.loftq_utils import loftq_init
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index ba45202778..b249974302 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -902,9 +902,9 @@ def unload(self) -> torch.nn.Module:
def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None):
"""
- This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA |
- OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus
- converting [PiSSA | OLoRA] to LoRA.
+ This function can calculate the updates of the PiSSA/CorDA/OLoRA by comparing the parameters of the
+ PiSSA/CorDA/OLoRA adapter in `output_state_dict` with the initial values of PiSSA/CorDA/OLoRA in
+ `adapter_name`, thus converting PiSSA/CorDA/OLoRA to LoRA.
"""
for name, param in self.model.named_parameters():
if (
diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py
index e12609d008..f47c565588 100644
--- a/src/peft/tuners/lora/tp_layer.py
+++ b/src/peft/tuners/lora/tp_layer.py
@@ -158,6 +158,9 @@ def update_layer(
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
self.pissa_init(adapter_name, init_lora_weights)
+ elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
+ with gather_params_ctx(self.get_base_layer().weight):
+ self.corda_init(adapter_name, init_lora_weights)
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
with gather_params_ctx(self.get_base_layer().weight):
self.olora_init(adapter_name)
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index a1c5b8de60..516e7a3751 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -55,6 +55,8 @@
inject_adapter_in_model,
set_peft_model_state_dict,
)
+from peft.tuners.lora.config import CordaConfig
+from peft.tuners.lora.corda import preprocess_corda
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
@@ -1817,6 +1819,566 @@ def test_warning_naming_conflict_save_and_load(self, recwarn, tmp_path):
assert any(expected_msg in str(w.message) for w in recwarn.list)
+class TestCordaInitialization:
+ """Test class to check the initialization of CorDA adapters."""
+
+ torch_device = infer_device()
+
+ def get_model(self):
+ class MyModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # choose a large weight so that averages are close to expected values
+ self.linear = nn.Linear(1000, 1000)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ return MyModule().eval().to(self.torch_device)
+
+ @pytest.fixture
+ def data(self):
+ # larger data is required to pass KPM test
+ torch.manual_seed(233)
+ return torch.rand(1000, 1000).to(self.torch_device)
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_sample_count(self, data, corda_method):
+ original_model = self.get_model()
+ model = deepcopy(original_model)
+
+ corda_config = CordaConfig(
+ corda_method=corda_method,
+ )
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=corda_config,
+ )
+ preprocess_corda(
+ model,
+ config,
+ run_model=lambda: [model(data), model(data)], # running model twice to test `sample_count`
+ hooked_model=model,
+ )
+
+ # covariance of linear should be data.T @ data
+ layer = model.linear
+ assert hasattr(layer, "covariance_matrix")
+ assert torch.allclose(layer.covariance_matrix, data.T @ data, atol=1e-06)
+
+ # sample count of linear should be 2
+ assert hasattr(layer, "sample_count")
+ assert layer.sample_count == 2
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_hook_unregister(self, data, corda_method):
+ original_model = self.get_model()
+ model = deepcopy(original_model)
+
+ hook_call_count = 0
+
+ def hook(*args):
+ nonlocal hook_call_count
+ hook_call_count += 1
+
+ model.linear.register_forward_hook(hook)
+
+ corda_config = CordaConfig(
+ corda_method=corda_method,
+ )
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=corda_config,
+ )
+ preprocess_corda(
+ model,
+ config,
+ run_model=lambda: model(data),
+ hooked_model=model,
+ )
+
+ # after preprocessing, external and internal hook should be run once
+ assert hook_call_count == 1
+ assert model.linear.sample_count == 1
+
+ # run preprocessed model once
+ model(data)[0]
+
+ # the external hook should be kept, but the internal hook should be gone
+ assert hook_call_count == 2
+ assert model.linear.sample_count == 1
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_linear_init_default(self, data, tmp_path, corda_method):
+ original_model = self.get_model()
+ model = deepcopy(original_model)
+ output_base = model(data)[0]
+
+ corda_config = CordaConfig(
+ cache_file=tmp_path / "corda_cache.pt",
+ covariance_file=tmp_path / "covariance_cache.pt",
+ corda_method=corda_method,
+ )
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=corda_config,
+ )
+ preprocess_corda(
+ model,
+ config,
+ run_model=lambda: model(data),
+ hooked_model=model,
+ )
+ peft_model = get_peft_model(model, config)
+
+ # check if the adapter performs an identity transformation
+ assert torch.allclose(output_base, peft_model(data)[0], atol=1e-06)
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # if the SVD result is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(cache_file=tmp_path / "corda_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ # if the covariance is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(covariance_file=tmp_path / "covariance_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_hooked_model_linear_init_default(self, data, tmp_path, corda_method):
+ original_model = self.get_model()
+ model = deepcopy(original_model)
+ hooked_model = deepcopy(model)
+ output_base = model(data)[0]
+
+ corda_config = CordaConfig(
+ cache_file=tmp_path / "corda_cache.pt",
+ covariance_file=tmp_path / "covariance_cache.pt",
+ corda_method=corda_method,
+ )
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=corda_config,
+ )
+
+ # difference from the above test: this test uses a copied model as hooked model
+ preprocess_corda(
+ model,
+ config,
+ run_model=lambda: hooked_model(data),
+ hooked_model=hooked_model,
+ )
+ peft_model = get_peft_model(model, config)
+
+ # check if the adapter performs an identity transformation
+ assert torch.allclose(output_base, peft_model(data)[0], atol=1e-06)
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # if the SVD result is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(cache_file=tmp_path / "corda_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ # if the covariance is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(covariance_file=tmp_path / "covariance_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_linear_init_default_with_rank_pattern(self, data, tmp_path, corda_method):
+ original_model = self.get_model()
+ model = deepcopy(original_model)
+ output_base = model(data)[0]
+
+ corda_config = CordaConfig(
+ cache_file=tmp_path / "corda_cache.pt",
+ covariance_file=tmp_path / "covariance_cache.pt",
+ corda_method=corda_method,
+ )
+ config = LoraConfig(
+ rank_pattern={"linear": 8, "embed": 16, "conv2d": 32},
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=corda_config,
+ )
+ preprocess_corda(
+ model,
+ config,
+ run_model=lambda: model(data),
+ )
+ peft_model = get_peft_model(model, config)
+
+ # check if the adapter performs an identity transformation
+ assert torch.allclose(output_base, peft_model(data)[0], atol=1e-06)
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # if the SVD result is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ rank_pattern={"linear": 8, "embed": 16, "conv2d": 32},
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(cache_file=tmp_path / "corda_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ # if the covariance is loaded from cache, the output should be the same
+ model = deepcopy(original_model)
+ config = LoraConfig(
+ rank_pattern={"linear": 8, "embed": 16, "conv2d": 32},
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ corda_config=CordaConfig(covariance_file=tmp_path / "covariance_cache.pt", corda_method=corda_method),
+ )
+ preprocess_corda(model, config)
+ peft_model = get_peft_model(model, config)
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ assert torch.allclose(output_corda, peft_model(data)[0], atol=1e-06)
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_conversion_same_output_after_loading(self, data, tmp_path, corda_method):
+ model = self.get_model()
+ output_base = model(data)[0]
+
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(init_lora_weights="corda", target_modules=["linear"], r=8, corda_config=corda_config)
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(deepcopy(model), config)
+ # save the initial model
+ peft_model.peft_config["default"].init_lora_weights = True
+ peft_model.save_pretrained(tmp_path / "init-model")
+ peft_model.peft_config["default"].init_lora_weights = "corda"
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # save the model normally
+ peft_model.save_pretrained(tmp_path / "corda-model")
+ model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model")
+ output_loaded = model_loaded(data)[0]
+
+ assert torch.allclose(output_corda, output_loaded, atol=tol, rtol=tol)
+ # sanity check: ranks should still be 8 as initially
+ assert model_loaded.peft_config["default"].r == 8
+ assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+ # sanity check: the base model weights were indeed changed
+ assert not torch.allclose(
+ model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ # save the model with conversion
+ peft_config_keys_before = list(peft_model.peft_config.keys())
+ peft_config_dict_before = peft_model.peft_config["default"].to_dict()
+ peft_model.save_pretrained(
+ tmp_path / "corda-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+ peft_config_keys_after = list(peft_model.peft_config.keys())
+ peft_config_dict_after = peft_model.peft_config["default"].to_dict()
+ assert peft_config_keys_before == peft_config_keys_after
+ assert peft_config_dict_before == peft_config_dict_after
+
+ model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model-converted")
+ output_converted = model_converted(data)[0]
+
+ assert torch.allclose(output_corda, output_converted, atol=tol, rtol=tol)
+ # rank should be double of what it was initially
+ assert model_converted.peft_config["default"].r == 16
+ assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+ # base model weights should be the same as the initial model
+ assert torch.allclose(
+ model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_conversion_same_output_after_loading_with_rank_pattern(self, data, tmp_path, corda_method):
+ # same as above, but using rank_pattern
+ model = self.get_model()
+ output_base = model(data)[0]
+
+ # use rank_pattern here; note that since there is only a single linear layer, r is completely overridden
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ r=8,
+ rank_pattern={"linear": 32},
+ corda_config=corda_config,
+ )
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(deepcopy(model), config)
+ # save the initial model
+ peft_model.peft_config["default"].init_lora_weights = True
+ peft_model.save_pretrained(tmp_path / "init-model")
+ peft_model.peft_config["default"].init_lora_weights = "corda"
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # save the model normally
+ peft_model.save_pretrained(tmp_path / "corda-model")
+ model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model")
+ output_loaded = model_loaded(data)[0]
+
+ assert torch.allclose(output_corda, output_loaded, atol=tol, rtol=tol)
+ # sanity check: ranks should still be 8 as initially
+ assert model_loaded.peft_config["default"].r == 8
+ assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 32
+ # sanity check: the base model weights were indeed changed
+ assert not torch.allclose(
+ model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ # save the model with conversion
+ peft_model.save_pretrained(
+ tmp_path / "corda-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+ model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model-converted")
+ output_converted = model_converted(data)[0]
+
+ assert torch.allclose(output_corda, output_converted, atol=tol, rtol=tol)
+ # rank should be double of what it was initially
+ assert model_converted.peft_config["default"].r == 16
+ assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 64
+ # base model weights should be the same as the initial model
+ assert torch.allclose(
+ model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_conversion_same_output_after_loading_with_alpha_pattern(self, data, tmp_path, corda_method):
+ # same as above, but using alpha_pattern
+ model = self.get_model()
+ output_base = model(data)[0]
+
+ # use alpha_pattern here; note that since there is only a single linear layer, lora_alpha is completely
+ # overridden
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ alpha_pattern={"linear": 5},
+ corda_config=corda_config,
+ )
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(deepcopy(model), config)
+ # save the initial model
+ peft_model.peft_config["default"].init_lora_weights = True
+ peft_model.save_pretrained(tmp_path / "init-model")
+ peft_model.peft_config["default"].init_lora_weights = "corda"
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # save the model normally
+ peft_model.save_pretrained(tmp_path / "corda-model")
+ model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model")
+ output_loaded = model_loaded(data)[0]
+
+ assert torch.allclose(output_corda, output_loaded, atol=tol, rtol=tol)
+ # sanity check: ranks should still be 8 as initially
+ assert model_loaded.peft_config["default"].r == 8
+ assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+ assert model_loaded.base_model.model.linear.scaling["default"] == 5 / 8
+ # sanity check: the base model weights were indeed changed
+ assert not torch.allclose(
+ model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ # save the model with conversion
+ peft_model.save_pretrained(
+ tmp_path / "corda-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+ model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model-converted")
+ output_converted = model_converted(data)[0]
+
+ assert torch.allclose(output_corda, output_converted, atol=tol, rtol=tol)
+ # rank should be double of what it was initially
+ assert model_converted.peft_config["default"].r == 16
+ assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+ assert model_converted.base_model.model.linear.scaling["default"] == 10 / 16
+ # base model weights should be the same as the initial model
+ assert torch.allclose(
+ model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_conversion_same_output_after_loading_with_rslora(self, data, tmp_path, corda_method):
+ model = self.get_model()
+ output_base = model(data)[0]
+
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(
+ init_lora_weights="corda", target_modules=["linear"], r=8, use_rslora=True, corda_config=corda_config
+ )
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(deepcopy(model), config)
+ # save the initial model
+ peft_model.peft_config["default"].init_lora_weights = True
+ peft_model.save_pretrained(tmp_path / "init-model")
+ peft_model.peft_config["default"].init_lora_weights = "corda"
+
+ # modify the weights, or else the adapter performs an identity transformation
+ peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+ output_corda = peft_model(data)[0]
+
+ # sanity check
+ tol = 1e-06
+ assert not torch.allclose(output_base, output_corda, atol=tol, rtol=tol)
+
+ # save the model normally
+ peft_model.save_pretrained(tmp_path / "corda-model")
+ model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model")
+ output_loaded = model_loaded(data)[0]
+
+ assert torch.allclose(output_corda, output_loaded, atol=tol, rtol=tol)
+ # sanity check: ranks should still be 8 as initially
+ assert model_loaded.peft_config["default"].r == 8
+ assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+ assert model_loaded.base_model.model.linear.scaling["default"] == 8 / (8**0.5)
+ # sanity check: the base model weights were indeed changed
+ assert not torch.allclose(
+ model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ # save the model with conversion
+ peft_model.save_pretrained(
+ tmp_path / "corda-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+ model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "corda-model-converted")
+ output_converted = model_converted(data)[0]
+
+ assert torch.allclose(output_corda, output_converted, atol=tol, rtol=tol)
+ # rank should be double of what it was initially
+ assert model_converted.peft_config["default"].r == 16
+ assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+ # same scale as before with a little bit of floating point imprecision
+ assert model_converted.base_model.model.linear.scaling["default"] == pytest.approx(8 / (8**0.5))
+ # base model weights should be the same as the initial model
+ assert torch.allclose(
+ model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+ )
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_rank_pattern_and_rslora_raises(self, data, tmp_path, corda_method):
+ # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+ # scale is not stored in the state_dict
+ model = self.get_model()
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ r=8,
+ rank_pattern={"linear": 2},
+ use_rslora=True,
+ corda_config=corda_config,
+ )
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(model, config)
+ peft_model.save_pretrained(tmp_path / "init-model")
+
+ msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+ with pytest.raises(ValueError, match=msg):
+ peft_model.save_pretrained(
+ tmp_path / "corda-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+
+ @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+ def test_lora_corda_alpha_pattern_and_rslora_raises(self, data, tmp_path, corda_method):
+ # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+ # scale is not stored in the state_dict
+ model = self.get_model()
+ corda_config = CordaConfig(corda_method=corda_method)
+ config = LoraConfig(
+ init_lora_weights="corda",
+ target_modules=["linear"],
+ r=8,
+ alpha_pattern={"linear": 2},
+ use_rslora=True,
+ corda_config=corda_config,
+ )
+ preprocess_corda(model, config, run_model=lambda: model(data), hooked_model=model)
+ peft_model = get_peft_model(model, config)
+ peft_model.save_pretrained(tmp_path / "init-model")
+
+ msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+ with pytest.raises(ValueError, match=msg):
+ peft_model.save_pretrained(
+ tmp_path / "corda-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+ )
+
+
class TestEvaInitialization:
"""Tests for the EVA (Explained Variance Adaptation) initialization method.