-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from zenml-io/feature/OSSK-569-accelerated-template
Accelerated template
- Loading branch information
Showing 22 changed files with 388 additions and 130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# {% include 'template/license_header' %} | ||
|
||
from steps import ( | ||
evaluate_model, | ||
finetune, | ||
prepare_data, | ||
promote, | ||
log_metadata_from_step_artifact, | ||
) | ||
from zenml import pipeline | ||
from zenml.integrations.huggingface.steps import run_with_accelerate | ||
|
||
|
||
@pipeline | ||
def {{ product_name.replace("-","_") }}_full_finetune( | ||
system_prompt: str, | ||
base_model_id: str, | ||
use_fast: bool = True, | ||
load_in_8bit: bool = False, | ||
load_in_4bit: bool = False, | ||
): | ||
"""Pipeline for finetuning an LLM with PEFT powered by Accelerate. | ||
It will run the following steps: | ||
- prepare_data: prepare the datasets and tokenize them | ||
- finetune: finetune the model | ||
- evaluate_model: evaluate the base and finetuned model | ||
- promote: promote the model to the target stage, if evaluation was successful | ||
""" | ||
if not load_in_8bit and not load_in_4bit: | ||
raise ValueError( | ||
"At least one of `load_in_8bit` and `load_in_4bit` must be True." | ||
) | ||
if load_in_4bit and load_in_8bit: | ||
raise ValueError("Only one of `load_in_8bit` and `load_in_4bit` can be True.") | ||
|
||
datasets_dir = prepare_data( | ||
base_model_id=base_model_id, | ||
system_prompt=system_prompt, | ||
use_fast=use_fast, | ||
) | ||
|
||
evaluate_model( | ||
base_model_id, | ||
system_prompt, | ||
datasets_dir, | ||
None, | ||
use_fast=use_fast, | ||
load_in_8bit=load_in_8bit, | ||
load_in_4bit=load_in_4bit, | ||
id="evaluate_base", | ||
) | ||
log_metadata_from_step_artifact( | ||
"evaluate_base", | ||
"base_model_rouge_metrics", | ||
after=["evaluate_base"], | ||
id="log_metadata_evaluation_base" | ||
) | ||
|
||
ft_model_dir = run_with_accelerate(finetune)( | ||
base_model_id=base_model_id, | ||
dataset_dir=datasets_dir, | ||
use_fast=use_fast, | ||
load_in_8bit=load_in_8bit, | ||
load_in_4bit=load_in_4bit, | ||
) | ||
|
||
evaluate_model( | ||
base_model_id, | ||
system_prompt, | ||
datasets_dir, | ||
ft_model_dir, | ||
use_fast=use_fast, | ||
load_in_8bit=load_in_8bit, | ||
load_in_4bit=load_in_4bit, | ||
id="evaluate_finetuned", | ||
) | ||
log_metadata_from_step_artifact( | ||
"evaluate_finetuned", | ||
"finetuned_model_rouge_metrics", | ||
after=["evaluate_finetuned"], | ||
id="log_metadata_evaluation_finetuned" | ||
) | ||
|
||
promote(after=["log_metadata_evaluation_finetuned", "log_metadata_evaluation_base"]) |
Oops, something went wrong.