-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate_lightning_sequence_labeling.py
65 lines (57 loc) · 2.58 KB
/
evaluate_lightning_sequence_labeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import pprint
from typing import Optional
import typer
from embeddings.defaults import RESULTS_PATH
from embeddings.metric.sequence_labeling import EvaluationMode, TaggingScheme
from embeddings.pipeline.lightning_sequence_labeling import LightningSequenceLabelingPipeline
from embeddings.utils.loggers import LightningLoggingConfig
from embeddings.utils.utils import build_output_path, format_eval_results
def run(
    embedding_name_or_path: str = typer.Option(
        "allegro/herbert-base-cased", help="Hugging Face embedding model name or path."
    ),
    dataset_name: str = typer.Option(
        "clarin-pl/kpwr-ner", help="Hugging Face dataset name or path."
    ),
    input_column_name: str = typer.Option(
        "tokens", help="Column name that contains text to classify."
    ),
    target_column_name: str = typer.Option(
        "ner", help="Column name that contains tag labels for POS tagging."
    ),
    evaluation_mode: EvaluationMode = typer.Option(
        EvaluationMode.CONLL,
        help="Evaluation mode. Supported modes: [unit, conll, strict].",
    ),
    tagging_scheme: Optional[TaggingScheme] = typer.Option(
        None, help="Tagging scheme. Supported schemes: [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU]"
    ),
    # NOTE(review): the default directory is named "lightning_sequence_classification"
    # although this script runs sequence labeling; looks like a copy-paste leftover.
    # Kept unchanged so existing result locations do not move — confirm intent.
    root: str = typer.Option(RESULTS_PATH.joinpath("lightning_sequence_classification")),
    run_name: Optional[str] = typer.Option(None, help="Name of run used for logging."),
    wandb: bool = typer.Option(False, help="Flag for using wandb."),
    tensorboard: bool = typer.Option(False, help="Flag for using tensorboard."),
    csv: bool = typer.Option(False, help="Flag for using csv."),
    tracking_project_name: Optional[str] = typer.Option(None, help="Name of wandb project."),
    wandb_entity: Optional[str] = typer.Option(None, help="Name of wandb entity."),
) -> None:
    """Run a Lightning sequence labeling pipeline and print its evaluation results.

    The pipeline writes its outputs to the path returned by
    ``build_output_path(root, embedding_name_or_path, dataset_name)``, so runs for
    different embeddings/datasets are kept apart instead of sharing ``root``.
    Logging backends (wandb / tensorboard / csv) are enabled via the flags.
    """
    # Must stay the first statement: at this point locals() holds exactly the
    # CLI arguments, which are echoed so the run configuration is visible.
    typer.echo(pprint.pformat(locals()))
    output_path = build_output_path(root, embedding_name_or_path, dataset_name)
    logging_config = LightningLoggingConfig.from_flags(
        wandb=wandb,
        tensorboard=tensorboard,
        csv=csv,
        tracking_project_name=tracking_project_name,
        wandb_entity=wandb_entity,
    )
    pipeline = LightningSequenceLabelingPipeline(
        embedding_name_or_path=embedding_name_or_path,
        dataset_name_or_path=dataset_name,
        input_column_name=input_column_name,
        target_column_name=target_column_name,
        # Previously `root` was passed here, leaving the computed per-run
        # `output_path` unused and making every run write into the same directory.
        output_path=output_path,
        evaluation_mode=evaluation_mode,
        tagging_scheme=tagging_scheme,
        logging_config=logging_config,
    )
    result = pipeline.run()
    typer.echo(format_eval_results(result))
# Launch the CLI only when executed as a script, so importing this module
# (e.g. to reuse `run`) does not parse sys.argv and start a pipeline.
if __name__ == "__main__":
    typer.run(run)