From 461bdd036507823a3d5cafdfde753d7e0fdbd09f Mon Sep 17 00:00:00 2001 From: aj280192 Date: Thu, 17 Mar 2022 20:01:38 +0100 Subject: [PATCH] Adding layer deeplift shap and layer gradient shap explainers. --- configs/ag_news/albert/lds.jsonnet | 31 + configs/ag_news/albert/lgs.jsonnet | 32 + configs/ag_news/bert/lds.jsonnet | 31 + configs/ag_news/bert/lgs.jsonnet | 32 + configs/ag_news/roberta/lds.jsonnet | 31 + configs/ag_news/roberta/lgs.jsonnet | 32 + configs/imdb/albert/lds.jsonnet | 31 + configs/imdb/albert/lgs.jsonnet | 32 + configs/imdb/bert/lds.jsonnet | 31 + configs/imdb/bert/lgs.jsonnet | 32 + configs/imdb/electra/lds.jsonnet | 31 + configs/imdb/electra/lgs.jsonnet | 32 + configs/imdb/roberta/lds.jsonnet | 31 + configs/imdb/roberta/lgs.jsonnet | 32 + configs/imdb/xlnet/lds.jsonnet | 31 + configs/imdb/xlnet/lgs.jsonnet | 32 + configs/mnli/albert/lds.jsonnet | 32 + configs/mnli/albert/lgs.jsonnet | 33 + configs/mnli/bert/lds.jsonnet | 32 + configs/mnli/bert/lgs.jsonnet | 33 + configs/mnli/electra/lds.jsonnet | 32 + configs/mnli/electra/lgs.jsonnet | 33 + configs/mnli/roberta/lds.jsonnet | 32 + configs/mnli/roberta/lgs.jsonnet | 33 + configs/mnli/xlnet/lds.jsonnet | 32 + configs/mnli/xlnet/lgs.jsonnet | 33 + configs/xnli/albert/lds.jsonnet | 33 + configs/xnli/albert/lgs.jsonnet | 34 + configs/xnli/bert/lds.jsonnet | 33 + configs/xnli/bert/lgs.jsonnet | 34 + configs/xnli/electra/lds.jsonnet | 33 + configs/xnli/electra/lgs.jsonnet | 34 + configs/xnli/roberta/lds.jsonnet | 33 + configs/xnli/roberta/lgs.jsonnet | 34 + configs/xnli/xlnet/lds.jsonnet | 33 + configs/xnli/xlnet/lgs.jsonnet | 34 + demo.ipynb | 3366 +++++++++++++++++---- src/thermostat/data/thermostat_configs.py | 240 ++ src/thermostat/explainers/__init__.py | 5 + src/thermostat/explainers/shap.py | 106 + 40 files changed, 4353 insertions(+), 528 deletions(-) create mode 100644 configs/ag_news/albert/lds.jsonnet create mode 100644 configs/ag_news/albert/lgs.jsonnet create mode 100644 
configs/ag_news/bert/lds.jsonnet create mode 100644 configs/ag_news/bert/lgs.jsonnet create mode 100644 configs/ag_news/roberta/lds.jsonnet create mode 100644 configs/ag_news/roberta/lgs.jsonnet create mode 100644 configs/imdb/albert/lds.jsonnet create mode 100644 configs/imdb/albert/lgs.jsonnet create mode 100644 configs/imdb/bert/lds.jsonnet create mode 100644 configs/imdb/bert/lgs.jsonnet create mode 100644 configs/imdb/electra/lds.jsonnet create mode 100644 configs/imdb/electra/lgs.jsonnet create mode 100644 configs/imdb/roberta/lds.jsonnet create mode 100644 configs/imdb/roberta/lgs.jsonnet create mode 100644 configs/imdb/xlnet/lds.jsonnet create mode 100644 configs/imdb/xlnet/lgs.jsonnet create mode 100644 configs/mnli/albert/lds.jsonnet create mode 100644 configs/mnli/albert/lgs.jsonnet create mode 100644 configs/mnli/bert/lds.jsonnet create mode 100644 configs/mnli/bert/lgs.jsonnet create mode 100644 configs/mnli/electra/lds.jsonnet create mode 100644 configs/mnli/electra/lgs.jsonnet create mode 100644 configs/mnli/roberta/lds.jsonnet create mode 100644 configs/mnli/roberta/lgs.jsonnet create mode 100644 configs/mnli/xlnet/lds.jsonnet create mode 100644 configs/mnli/xlnet/lgs.jsonnet create mode 100644 configs/xnli/albert/lds.jsonnet create mode 100644 configs/xnli/albert/lgs.jsonnet create mode 100644 configs/xnli/bert/lds.jsonnet create mode 100644 configs/xnli/bert/lgs.jsonnet create mode 100644 configs/xnli/electra/lds.jsonnet create mode 100644 configs/xnli/electra/lgs.jsonnet create mode 100644 configs/xnli/roberta/lds.jsonnet create mode 100644 configs/xnli/roberta/lgs.jsonnet create mode 100644 configs/xnli/xlnet/lds.jsonnet create mode 100644 configs/xnli/xlnet/lgs.jsonnet create mode 100644 src/thermostat/explainers/shap.py diff --git a/configs/ag_news/albert/lds.jsonnet b/configs/ag_news/albert/lds.jsonnet new file mode 100644 index 0000000..29e5c7c --- /dev/null +++ b/configs/ag_news/albert/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": 
"$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/albert-base-v2-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/albert/lgs.jsonnet b/configs/ag_news/albert/lgs.jsonnet new file mode 100644 index 0000000..937ca1e --- /dev/null +++ b/configs/ag_news/albert/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/albert-base-v2-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/bert/lds.jsonnet b/configs/ag_news/bert/lds.jsonnet new file mode 100644 index 0000000..3f5e52e --- /dev/null +++ b/configs/ag_news/bert/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": 
"test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/bert-base-uncased-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/bert/lgs.jsonnet b/configs/ag_news/bert/lgs.jsonnet new file mode 100644 index 0000000..be0fa74 --- /dev/null +++ b/configs/ag_news/bert/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/bert-base-uncased-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/roberta/lds.jsonnet b/configs/ag_news/roberta/lds.jsonnet new file mode 100644 index 0000000..0b6f9f0 --- /dev/null +++ b/configs/ag_news/roberta/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + 
"batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/roberta-base-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/roberta/lgs.jsonnet b/configs/ag_news/roberta/lgs.jsonnet new file mode 100644 index 0000000..c641824 --- /dev/null +++ b/configs/ag_news/roberta/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/roberta-base-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/albert/lds.jsonnet b/configs/imdb/albert/lds.jsonnet new file mode 100644 index 0000000..d7e383b --- /dev/null +++ b/configs/imdb/albert/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + 
}, + "model": { + "name": "textattack/albert-base-v2-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/albert/lgs.jsonnet b/configs/imdb/albert/lgs.jsonnet new file mode 100644 index 0000000..c3406cd --- /dev/null +++ b/configs/imdb/albert/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/albert-base-v2-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/bert/lds.jsonnet b/configs/imdb/bert/lds.jsonnet new file mode 100644 index 0000000..e1a698b --- /dev/null +++ b/configs/imdb/bert/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/bert-base-uncased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + 
"max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/bert/lgs.jsonnet b/configs/imdb/bert/lgs.jsonnet new file mode 100644 index 0000000..c52f22d --- /dev/null +++ b/configs/imdb/bert/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/bert-base-uncased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/electra/lds.jsonnet b/configs/imdb/electra/lds.jsonnet new file mode 100644 index 0000000..03bc8a7 --- /dev/null +++ b/configs/imdb/electra/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "monologg/electra-small-finetuned-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": 
true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/electra/lgs.jsonnet b/configs/imdb/electra/lgs.jsonnet new file mode 100644 index 0000000..33d9669 --- /dev/null +++ b/configs/imdb/electra/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "monologg/electra-small-finetuned-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/roberta/lds.jsonnet b/configs/imdb/roberta/lds.jsonnet new file mode 100644 index 0000000..998bffd --- /dev/null +++ b/configs/imdb/roberta/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'special_tokens_mask', 'attention_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/roberta-base-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + 
"normalize": true, + } +} diff --git a/configs/imdb/roberta/lgs.jsonnet b/configs/imdb/roberta/lgs.jsonnet new file mode 100644 index 0000000..ce7721f --- /dev/null +++ b/configs/imdb/roberta/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'special_tokens_mask', 'attention_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/roberta-base-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/xlnet/lds.jsonnet b/configs/imdb/xlnet/lds.jsonnet new file mode 100644 index 0000000..474a0f7 --- /dev/null +++ b/configs/imdb/xlnet/lds.jsonnet @@ -0,0 +1,31 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/xlnet-base-cased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/xlnet/lgs.jsonnet b/configs/imdb/xlnet/lgs.jsonnet new file mode 100644 index 
0000000..f15b372 --- /dev/null +++ b/configs/imdb/xlnet/lgs.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/xlnet-base-cased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/albert/lds.jsonnet b/configs/mnli/albert/lds.jsonnet new file mode 100644 index 0000000..a0d7a46 --- /dev/null +++ b/configs/mnli/albert/lds.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/albert/lgs.jsonnet b/configs/mnli/albert/lgs.jsonnet new file mode 100644 index 0000000..dbd0727 --- /dev/null +++ 
b/configs/mnli/albert/lgs.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/bert/lds.jsonnet b/configs/mnli/bert/lds.jsonnet new file mode 100644 index 0000000..7997184 --- /dev/null +++ b/configs/mnli/bert/lds.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/bert/lgs.jsonnet b/configs/mnli/bert/lgs.jsonnet new file mode 100644 index 0000000..b3a51a4 --- /dev/null +++ 
b/configs/mnli/bert/lgs.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/electra/lds.jsonnet b/configs/mnli/electra/lds.jsonnet new file mode 100644 index 0000000..523b046 --- /dev/null +++ b/configs/mnli/electra/lds.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/electra/lgs.jsonnet b/configs/mnli/electra/lgs.jsonnet new file mode 100644 index 0000000..e4fc703 --- /dev/null +++ 
b/configs/mnli/electra/lgs.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/roberta/lds.jsonnet b/configs/mnli/roberta/lds.jsonnet new file mode 100644 index 0000000..ce9dd4c --- /dev/null +++ b/configs/mnli/roberta/lds.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'special_tokens_mask', 'attention_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/roberta/lgs.jsonnet b/configs/mnli/roberta/lgs.jsonnet new file mode 100644 index 0000000..87aef8d --- /dev/null +++ 
b/configs/mnli/roberta/lgs.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'special_tokens_mask', 'attention_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/xlnet/lds.jsonnet b/configs/mnli/xlnet/lds.jsonnet new file mode 100644 index 0000000..9e7a4ee --- /dev/null +++ b/configs/mnli/xlnet/lds.jsonnet @@ -0,0 +1,32 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/xlnet/lgs.jsonnet b/configs/mnli/xlnet/lgs.jsonnet new file mode 100644 index 0000000..b696de0 --- /dev/null +++ 
b/configs/mnli/xlnet/lgs.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/albert/lds.jsonnet b/configs/xnli/albert/lds.jsonnet new file mode 100644 index 0000000..6dff984 --- /dev/null +++ b/configs/xnli/albert/lds.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/albert/lgs.jsonnet b/configs/xnli/albert/lgs.jsonnet new file mode 100644 index 0000000..4f45f00 --- /dev/null +++ 
b/configs/xnli/albert/lgs.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/bert/lds.jsonnet b/configs/xnli/bert/lds.jsonnet new file mode 100644 index 0000000..a7209a2 --- /dev/null +++ b/configs/xnli/bert/lds.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/bert/lgs.jsonnet b/configs/xnli/bert/lgs.jsonnet new file mode 100644 index 0000000..f67379d --- /dev/null +++ 
b/configs/xnli/bert/lgs.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/electra/lds.jsonnet b/configs/xnli/electra/lds.jsonnet new file mode 100644 index 0000000..c6d3fad --- /dev/null +++ b/configs/xnli/electra/lds.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/electra/lgs.jsonnet b/configs/xnli/electra/lgs.jsonnet new file mode 100644 index 0000000..90e7c67 --- /dev/null +++ 
b/configs/xnli/electra/lgs.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/roberta/lds.jsonnet b/configs/xnli/roberta/lds.jsonnet new file mode 100644 index 0000000..293e342 --- /dev/null +++ b/configs/xnli/roberta/lds.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/roberta/lgs.jsonnet b/configs/xnli/roberta/lgs.jsonnet new file mode 100644 index 0000000..5e4fa6e --- /dev/null +++ 
b/configs/xnli/roberta/lgs.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/xlnet/lds.jsonnet b/configs/xnli/xlnet/lds.jsonnet new file mode 100644 index 0000000..c025a07 --- /dev/null +++ b/configs/xnli/xlnet/lds.jsonnet @@ -0,0 +1,33 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerDeepLiftShap", + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/xlnet/lgs.jsonnet b/configs/xnli/xlnet/lgs.jsonnet new file mode 100644 index 0000000..3334b92 --- /dev/null +++ 
b/configs/xnli/xlnet/lgs.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LayerGradientShap", + "n_samples": 5, + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/demo.ipynb b/demo.ipynb index 38487d9..142495d 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -17,142 +17,372 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: pip in /home/nfel/.local/lib/python3.6/site-packages (21.1.2)\n", + "Requirement already satisfied: pip in c:\\programdata\\anaconda3\\lib\\site-packages (21.2.4)\n", "Collecting pip\n", - " Downloading pip-21.1.3-py3-none-any.whl (1.5 MB)\n", - "\u001b[K |████████████████████████████████| 1.5 MB 2.1 MB/s eta 0:00:01\n", - "\u001b[?25hInstalling collected packages: pip\n", + " Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)\n", + "Installing collected packages: pip\n", " Attempting uninstall: pip\n", - " Found existing installation: pip 21.1.2\n", - " Uninstalling pip-21.1.2:\n", - " Successfully uninstalled pip-21.1.2\n", - "Successfully installed pip-21.1.3\n", - "Defaulting to user installation because normal 
site-packages is not writeable\n", - "Requirement already satisfied: cmake in /home/nfel/.local/lib/python3.6/site-packages (3.20.3)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: cython in /home/nfel/.local/lib/python3.6/site-packages (0.29.23)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: numpy in /home/nfel/.local/lib/python3.6/site-packages (1.19.5)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: torch in /home/nfel/.local/lib/python3.6/site-packages (1.9.0)\n", - "Requirement already satisfied: dataclasses in /home/nfel/.local/lib/python3.6/site-packages (from torch) (0.8)\n", - "Requirement already satisfied: typing-extensions in /home/nfel/.local/lib/python3.6/site-packages (from torch) (3.10.0.0)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: datasets in /home/nfel/.local/lib/python3.6/site-packages (1.8.0)\n", - "Requirement already satisfied: dataclasses in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (0.8)\n", - "Requirement already satisfied: importlib-metadata in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (4.5.0)\n", - "Requirement already satisfied: xxhash in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (2.0.2)\n", - "Requirement already satisfied: requests>=2.19.0 in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (2.25.1)\n", - "Requirement already satisfied: pyarrow<4.0.0,>=1.0.0 in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (3.0.0)\n", - "Requirement already satisfied: pandas in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (1.1.5)\n", - "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /home/nfel/.local/lib/python3.6/site-packages (from 
datasets) (4.49.0)\n", - "Requirement already satisfied: multiprocess in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (0.70.12.2)\n", - "Requirement already satisfied: dill in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (0.3.4)\n", - "Requirement already satisfied: packaging in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (20.9)\n", - "Requirement already satisfied: numpy>=1.17 in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (1.19.5)\n", - "Requirement already satisfied: huggingface-hub<0.1.0 in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (0.0.8)\n", - "Requirement already satisfied: fsspec in /home/nfel/.local/lib/python3.6/site-packages (from datasets) (2021.6.0)\n", - "Requirement already satisfied: filelock in /home/nfel/.local/lib/python3.6/site-packages (from huggingface-hub<0.1.0->datasets) (3.0.12)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from requests>=2.19.0->datasets) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/nfel/.local/lib/python3.6/site-packages (from requests>=2.19.0->datasets) (1.26.5)\n", - "Requirement already satisfied: idna<3,>=2.5 in /home/nfel/.local/lib/python3.6/site-packages (from requests>=2.19.0->datasets) (2.10)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/nfel/.local/lib/python3.6/site-packages (from requests>=2.19.0->datasets) (2021.5.30)\n", - "Requirement already satisfied: typing-extensions>=3.6.4 in /home/nfel/.local/lib/python3.6/site-packages (from importlib-metadata->datasets) (3.10.0.0)\n", - "Requirement already satisfied: zipp>=0.5 in /home/nfel/.local/lib/python3.6/site-packages (from importlib-metadata->datasets) (3.4.1)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from packaging->datasets) (2.4.7)\n", - "Requirement already satisfied: pytz>=2017.2 in 
/home/nfel/.local/lib/python3.6/site-packages (from pandas->datasets) (2021.1)\n", - "Requirement already satisfied: python-dateutil>=2.7.3 in /home/nfel/.local/lib/python3.6/site-packages (from pandas->datasets) (2.8.1)\n", - "Requirement already satisfied: six>=1.5 in /home/nfel/.local/lib/python3.6/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: spacy in /home/nfel/.local/lib/python3.6/site-packages (3.0.6)\n", - "Requirement already satisfied: jinja2 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (3.0.1)\n", - "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (0.8.2)\n", - "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (1.0.5)\n", - "Requirement already satisfied: pydantic<1.8.0,>=1.7.1 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (1.7.4)\n", - "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (2.25.1)\n", - "Requirement already satisfied: pathy>=0.3.5 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (0.5.2)\n", - "Requirement already satisfied: thinc<8.1.0,>=8.0.3 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (8.0.4)\n", - "Requirement already satisfied: numpy>=1.15.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (1.19.5)\n", - "Requirement already satisfied: packaging>=20.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (20.9)\n", - "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (2.0.5)\n", - "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (2.4.1)\n", - "Requirement already satisfied: 
typing-extensions<4.0.0.0,>=3.7.4 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (3.10.0.0)\n", - "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.4 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (3.0.6)\n", - "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (3.0.5)\n", - "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (4.49.0)\n", - "Requirement already satisfied: catalogue<2.1.0,>=2.0.3 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (2.0.4)\n", - "Requirement already satisfied: typer<0.4.0,>=0.3.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (0.3.2)\n", - "Requirement already satisfied: setuptools in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (57.0.0)\n", - "Requirement already satisfied: blis<0.8.0,>=0.4.0 in /home/nfel/.local/lib/python3.6/site-packages (from spacy) (0.7.4)\n", - "Requirement already satisfied: zipp>=0.5 in /home/nfel/.local/lib/python3.6/site-packages (from catalogue<2.1.0,>=2.0.3->spacy) (3.4.1)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from packaging>=20.0->spacy) (2.4.7)\n", - "Requirement already satisfied: smart-open<4.0.0,>=2.2.0 in /home/nfel/.local/lib/python3.6/site-packages (from pathy>=0.3.5->spacy) (3.0.0)\n", - "Requirement already satisfied: dataclasses<1.0,>=0.6 in /home/nfel/.local/lib/python3.6/site-packages (from pathy>=0.3.5->spacy) (0.8)\n", - "Requirement already satisfied: idna<3,>=2.5 in /home/nfel/.local/lib/python3.6/site-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/nfel/.local/lib/python3.6/site-packages (from requests<3.0.0,>=2.13.0->spacy) (1.26.5)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/nfel/.local/lib/python3.6/site-packages (from 
requests<3.0.0,>=2.13.0->spacy) (2021.5.30)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from requests<3.0.0,>=2.13.0->spacy) (4.0.0)\n" + " Found existing installation: pip 21.2.4\n", + " Uninstalling pip-21.2.4:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + " WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "ERROR: Could not install packages due to an OSError: [WinError 5] Access is denied: 'c:\\\\programdata\\\\anaconda3\\\\lib\\\\site-packages\\\\pip\\\\py.typed'\n", + "Consider using the `--user` option or check the permissions.\n", + "\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: cmake in c:\\programdata\\anaconda3\\lib\\site-packages (3.22.2)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid 
distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: cython in c:\\programdata\\anaconda3\\lib\\site-packages (0.29.23)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: numpy in c:\\programdata\\anaconda3\\lib\\site-packages (1.20.3)\n", + "Requirement already satisfied: torch in c:\\programdata\\anaconda3\\lib\\site-packages (1.10.1)\n", + "Requirement already satisfied: typing_extensions in 
c:\\programdata\\anaconda3\\lib\\site-packages (from torch) (3.10.0.2)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: datasets in c:\\programdata\\anaconda3\\lib\\site-packages (1.18.3)\n", + "Requirement already satisfied: pandas in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (1.3.4)\n", + "Requirement already satisfied: xxhash in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: packaging in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (21.0)\n", + "Requirement already satisfied: multiprocess in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (0.70.12.2)\n", + "Requirement already satisfied: tqdm>=4.62.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (4.62.3)\n", + "Requirement already satisfied: aiohttp in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (3.8.1)\n", + "Requirement already satisfied: fsspec[http]>=2021.05.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (2021.10.1)\n", + "Requirement already satisfied: dill in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (0.3.4)\n", + "Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in 
c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (0.4.0)\n", + "Requirement already satisfied: pyarrow!=4.0.0,>=3.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: numpy>=1.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (1.20.3)\n", + "Requirement already satisfied: requests>=2.19.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from datasets) (2.26.0)\n", + "Requirement already satisfied: pyyaml in c:\\programdata\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n", + "Requirement already satisfied: filelock in c:\\programdata\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.3.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.10.0.2)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from packaging->datasets) (3.0.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (1.26.7)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (3.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2021.10.8)\n", + "Requirement already satisfied: colorama in c:\\programdata\\anaconda3\\lib\\site-packages (from tqdm>=4.62.1->datasets) (0.4.4)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.7.2)\n", + "Requirement already satisfied: 
async-timeout<5.0,>=4.0.0a3 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (4.0.2)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.2.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (21.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.3.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (6.0.2)\n", + "Requirement already satisfied: pytz>=2017.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pandas->datasets) (2021.3)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: six>=1.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: spacy in c:\\programdata\\anaconda3\\lib\\site-packages (3.2.1)\n", + "Requirement already satisfied: jinja2 in 
c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (2.11.3)\n", + "Requirement already satisfied: blis<0.8.0,>=0.4.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (0.7.5)\n", + "Requirement already satisfied: packaging>=20.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (21.0)\n", + "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (3.3.0)\n", + "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (3.0.6)\n", + "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (1.0.6)\n", + "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (1.0.1)\n", + "Requirement already satisfied: pathy>=0.3.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (0.6.1)\n", + "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (3.0.8)\n", + "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (2.4.2)\n", + "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (2.0.6)\n", + "Requirement already satisfied: thinc<8.1.0,>=8.0.12 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (8.0.13)\n", + "Requirement already satisfied: typer<0.5.0,>=0.3.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (0.4.0)\n", + "Requirement already satisfied: numpy>=1.15.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (1.20.3)\n", + "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (0.9.0)\n", + "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in c:\\programdata\\anaconda3\\lib\\site-packages 
(from spacy) (2.0.6)\n", + "Requirement already satisfied: requests<3.0.0,>=2.13.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (2.26.0)\n", + "Requirement already satisfied: setuptools in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (58.0.4)\n", + "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (1.8.2)\n", + "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from spacy) (4.62.3)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from packaging>=20.0->spacy) (3.0.4)\n", + "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from pathy>=0.3.5->spacy) (5.2.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4->spacy) (3.10.0.2)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy) (1.26.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy) (3.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy) (2021.10.8)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy) (2.0.4)\n", + "Requirement already satisfied: colorama in c:\\programdata\\anaconda3\\lib\\site-packages (from tqdm<5.0.0,>=4.38.0->spacy) (0.4.4)\n", + "Requirement already satisfied: click<9.0.0,>=7.1.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from typer<0.5.0,>=0.3.0->spacy) (8.0.3)\n", + "Requirement already satisfied: MarkupSafe>=0.23 in 
c:\\programdata\\anaconda3\\lib\\site-packages (from jinja2->spacy) (1.1.1)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sentencepiece in c:\\programdata\\anaconda3\\lib\\site-packages (0.1.96)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: transformers in c:\\programdata\\anaconda3\\lib\\site-packages (4.16.2)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (0.4.0)\n", + "Requirement already satisfied: sacremoses 
in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (0.0.47)\n", + "Requirement already satisfied: packaging>=20.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (21.0)\n", + "Requirement already satisfied: numpy>=1.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (1.20.3)\n", + "Requirement already satisfied: tokenizers!=0.11.3,>=0.10.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (0.11.4)\n", + "Requirement already satisfied: requests in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (2.26.0)\n", + "Requirement already satisfied: filelock in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (3.3.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (6.0)\n", + "Requirement already satisfied: regex!=2019.12.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (2021.8.3)\n", + "Requirement already satisfied: tqdm>=4.27 in c:\\programdata\\anaconda3\\lib\\site-packages (from transformers) (4.62.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in c:\\programdata\\anaconda3\\lib\\site-packages (from packaging>=20.0->transformers) (3.0.4)\n", + "Requirement already satisfied: colorama in c:\\programdata\\anaconda3\\lib\\site-packages (from tqdm>=4.27->transformers) (0.4.4)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests->transformers) (2.0.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests->transformers) (2021.10.8)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\programdata\\anaconda3\\lib\\site-packages 
(from requests->transformers) (1.26.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from requests->transformers) (3.2)\n", + "Requirement already satisfied: six in c:\\programdata\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (1.16.0)\n", + "Requirement already satisfied: click in c:\\programdata\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (8.0.3)\n", + "Requirement already satisfied: joblib in c:\\programdata\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (1.1.0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: overrides in c:\\programdata\\anaconda3\\lib\\site-packages (6.1.0)\n", + "Requirement already satisfied: typing-utils>=0.0.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from overrides) (0.1.0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution 
-andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting jsonnet\n", + " Using cached jsonnet-0.18.0.tar.gz (592 kB)\n", + "Building wheels for collected packages: jsonnet\n", + " Building wheel for jsonnet (setup.py): started\n", + " Building wheel for jsonnet (setup.py): finished with status 'error'\n", + " Running setup.py clean for jsonnet\n", + "Failed to build jsonnet\n", + "Installing collected packages: jsonnet\n", + " Running setup.py install for jsonnet: started\n", + " Running setup.py install for jsonnet: finished with status 'error'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + " ERROR: Command errored out with exit status 1:\n", + " command: 'C:\\ProgramData\\Anaconda3\\python.exe' -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"'; __file__='\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"';f = getattr(tokenize, '\"'\"'open'\"'\"', open)(__file__) if os.path.exists(__file__) else io.StringIO('\"'\"'from setuptools import setup; setup()'\"'\"');code = f.read().replace('\"'\"'\\r\\n'\"'\"', '\"'\"'\\n'\"'\"');f.close();exec(compile(code, __file__, '\"'\"'exec'\"'\"'))' bdist_wheel -d 'C:\\Users\\49176\\AppData\\Local\\Temp\\pip-wheel-qsb54w8v'\n", + " cwd: 
C:\\Users\\49176\\AppData\\Local\\Temp\\pip-install-vrsb96ad\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\n", + " Complete output (4 lines):\n", + " running bdist_wheel\n", + " running build\n", + " running build_ext\n", + " error: [WinError 2] The system cannot find the file specified\n", + " ----------------------------------------\n", + " ERROR: Failed building wheel for jsonnet\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + " ERROR: Command errored out with exit status 1:\n", + " command: 'C:\\ProgramData\\Anaconda3\\python.exe' -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"'; __file__='\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"';f = getattr(tokenize, '\"'\"'open'\"'\"', open)(__file__) if os.path.exists(__file__) else io.StringIO('\"'\"'from setuptools import setup; setup()'\"'\"');code = f.read().replace('\"'\"'\\r\\n'\"'\"', '\"'\"'\\n'\"'\"');f.close();exec(compile(code, __file__, '\"'\"'exec'\"'\"'))' install --record 'C:\\Users\\49176\\AppData\\Local\\Temp\\pip-record-m2_9diqj\\install-record.txt' --single-version-externally-managed --compile --install-headers 'C:\\ProgramData\\Anaconda3\\Include\\jsonnet'\n", + " cwd: C:\\Users\\49176\\AppData\\Local\\Temp\\pip-install-vrsb96ad\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\n", + " Complete output (4 lines):\n", + " running install\n", + " running build\n", + " running build_ext\n", + " error: [WinError 2] The system cannot find the file specified\n", + " ----------------------------------------\n", + "ERROR: Command errored out with exit status 1: 'C:\\ProgramData\\Anaconda3\\python.exe' -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = 
'\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"'; __file__='\"'\"'C:\\\\Users\\\\49176\\\\AppData\\\\Local\\\\Temp\\\\pip-install-vrsb96ad\\\\jsonnet_8f800cc73699425d8babbbf5b9340802\\\\setup.py'\"'\"';f = getattr(tokenize, '\"'\"'open'\"'\"', open)(__file__) if os.path.exists(__file__) else io.StringIO('\"'\"'from setuptools import setup; setup()'\"'\"');code = f.read().replace('\"'\"'\\r\\n'\"'\"', '\"'\"'\\n'\"'\"');f.close();exec(compile(code, __file__, '\"'\"'exec'\"'\"'))' install --record 'C:\\Users\\49176\\AppData\\Local\\Temp\\pip-record-m2_9diqj\\install-record.txt' --single-version-externally-managed --compile --install-headers 'C:\\ProgramData\\Anaconda3\\Include\\jsonnet' Check the logs for full command output.\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sklearn in c:\\programdata\\anaconda3\\lib\\site-packages (0.0)\n", + "Requirement already satisfied: scikit-learn in c:\\programdata\\anaconda3\\lib\\site-packages (from sklearn) (0.24.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn) (2.2.0)\n", + "Requirement already satisfied: scipy>=0.19.1 in c:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn) (1.7.1)\n", + "Requirement already satisfied: numpy>=1.13.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn) (1.20.3)\n", + "Requirement already satisfied: joblib>=0.11 in c:\\programdata\\anaconda3\\lib\\site-packages (from scikit-learn->sklearn) (1.1.0)\n" + 
] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: contextvars<3,>=2.4 in /home/nfel/.local/lib/python3.6/site-packages (from thinc<8.1.0,>=8.0.3->spacy) (2.4)\n", - "Requirement already satisfied: immutables>=0.9 in /home/nfel/.local/lib/python3.6/site-packages (from contextvars<3,>=2.4->thinc<8.1.0,>=8.0.3->spacy) (0.15)\n", - "Requirement already satisfied: click<7.2.0,>=7.1.1 in /home/nfel/.local/lib/python3.6/site-packages (from typer<0.4.0,>=0.3.0->spacy) (7.1.2)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /home/nfel/.local/lib/python3.6/site-packages (from jinja2->spacy) (2.0.1)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: sentencepiece in /home/nfel/.local/lib/python3.6/site-packages (0.1.95)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: transformers in /home/nfel/.local/lib/python3.6/site-packages (4.6.1)\n", - "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (0.10.3)\n", - "Requirement already satisfied: huggingface-hub==0.0.8 in /home/nfel/.local/lib/python3.6/site-packages (from transformers) 
(0.0.8)\n", - "Requirement already satisfied: tqdm>=4.27 in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (4.49.0)\n", - "Requirement already satisfied: importlib-metadata in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (4.5.0)\n", - "Requirement already satisfied: packaging in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (20.9)\n", - "Requirement already satisfied: sacremoses in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (0.0.45)\n", - "Requirement already satisfied: dataclasses in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (0.8)\n", - "Requirement already satisfied: filelock in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (3.0.12)\n", - "Requirement already satisfied: numpy>=1.17 in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (1.19.5)\n", - "Requirement already satisfied: regex!=2019.12.17 in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (2021.4.4)\n", - "Requirement already satisfied: requests in /home/nfel/.local/lib/python3.6/site-packages (from transformers) (2.25.1)\n", - "Requirement already satisfied: typing-extensions>=3.6.4 in /home/nfel/.local/lib/python3.6/site-packages (from importlib-metadata->transformers) (3.10.0.0)\n", - "Requirement already satisfied: zipp>=0.5 in /home/nfel/.local/lib/python3.6/site-packages (from importlib-metadata->transformers) (3.4.1)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from packaging->transformers) (2.4.7)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/nfel/.local/lib/python3.6/site-packages (from requests->transformers) (2021.5.30)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /home/nfel/.local/lib/python3.6/site-packages (from requests->transformers) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in 
/home/nfel/.local/lib/python3.6/site-packages (from requests->transformers) (1.26.5)\n", - "Requirement already satisfied: idna<3,>=2.5 in /home/nfel/.local/lib/python3.6/site-packages (from requests->transformers) (2.10)\n", - "Requirement already satisfied: six in /home/nfel/.local/lib/python3.6/site-packages (from sacremoses->transformers) (1.16.0)\n", - "Requirement already satisfied: click in /home/nfel/.local/lib/python3.6/site-packages (from sacremoses->transformers) (7.1.2)\n", - "Requirement already satisfied: joblib in /home/nfel/.local/lib/python3.6/site-packages (from sacremoses->transformers) (1.0.1)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: overrides in /home/nfel/.local/lib/python3.6/site-packages (6.1.0)\n", - "Requirement already satisfied: typing-utils>=0.0.3 in /home/nfel/.local/lib/python3.6/site-packages (from overrides) (0.1.0)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: jsonnet in /home/nfel/.local/lib/python3.6/site-packages (0.17.0)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: sklearn in /home/nfel/.local/lib/python3.6/site-packages (0.0)\n", - "Requirement already satisfied: scikit-learn in /home/nfel/.local/lib/python3.6/site-packages (from sklearn) (0.24.2)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/nfel/.local/lib/python3.6/site-packages (from scikit-learn->sklearn) (2.1.0)\n", - "Requirement already satisfied: numpy>=1.13.3 in /home/nfel/.local/lib/python3.6/site-packages (from scikit-learn->sklearn) (1.19.5)\n", - "Requirement already satisfied: scipy>=0.19.1 in /home/nfel/.local/lib/python3.6/site-packages (from scikit-learn->sklearn) (1.5.4)\n", - "Requirement already satisfied: joblib>=0.11 in /home/nfel/.local/lib/python3.6/site-packages (from scikit-learn->sklearn) 
(1.0.1)\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Requirement already satisfied: pandas in /home/nfel/.local/lib/python3.6/site-packages (1.1.5)\n", - "Requirement already satisfied: numpy>=1.15.4 in /home/nfel/.local/lib/python3.6/site-packages (from pandas) (1.19.5)\n", - "Requirement already satisfied: pytz>=2017.2 in /home/nfel/.local/lib/python3.6/site-packages (from pandas) (2021.1)\n", - "Requirement already satisfied: python-dateutil>=2.7.3 in /home/nfel/.local/lib/python3.6/site-packages (from pandas) (2.8.1)\n", - "Requirement already satisfied: six>=1.5 in /home/nfel/.local/lib/python3.6/site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\n" + "Requirement already satisfied: pandas in c:\\programdata\\anaconda3\\lib\\site-packages (1.3.4)\n", + "Requirement already satisfied: pytz>=2017.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pandas) (2021.3)\n", + "Requirement already satisfied: numpy>=1.17.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pandas) (1.20.3)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\programdata\\anaconda3\\lib\\site-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: six>=1.5 in c:\\programdata\\anaconda3\\lib\\site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas (c:\\programdata\\anaconda3\\lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution -andas 
(c:\\programdata\\anaconda3\\lib\\site-packages)\n" ] } ], @@ -181,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -189,6 +419,7 @@ "# Suppress warnings\n", "warnings.filterwarnings('ignore')\n", "\n", + "\n", "import sys\n", "# Include root directory in module path\n", "sys.path.append('src')\n", @@ -205,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +459,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "metadata": { "pycharm": { "name": "#%%\n" @@ -239,19 +470,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loading Thermostat configuration: ag_news-bert-lime-100\n", - "Downloading and preparing dataset thermostat/ag_news-bert-lime-100 to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\ag_news-bert-lime-100\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n" + "Loading Thermostat configuration: imdb-bert-lig\n", + "Downloading and preparing dataset thermostat/imdb-bert-lig to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\imdb-bert-lig\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "10bcd13dc63844e4b73e7f382f9e2a7c", + "model_id": "207ec97028ba4563b4b29eedad4525a9", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Downloading: 0%| | 0.00/48.2M [00:00', -0.00799628160893917, 300),\n", + " ('<', -0.0044715143740177155, 301),\n", + " ('br', 0.014373987913131714, 302),\n", + " ('/', 0.016613293439149857, 303),\n", + " ('>', 0.029748301953077316, 304),\n", + " ('spoil', 0.016072934493422508, 305),\n", + " ('##er', 0.05765294283628464, 306),\n", + " (':', 0.005897448863834143, 307),\n", + " ('this', -0.08212670683860779, 308),\n", + " ('movie', -0.03927135467529297, 309),\n", + " 
('doesn', -0.00897288415580988, 310),\n", + " (\"'\", 0.005554314237087965, 311),\n", + " ('t', 0.018459515646100044, 312),\n", + " ('have', 0.004178288858383894, 313),\n", + " ('a', 0.006657324731349945, 314),\n", + " ('goo', -0.004324286710470915, 315),\n", + " ('##fs', -0.009721309877932072, 316),\n", + " ('section', 0.010842953808605671, 317),\n", + " ('.', 0.007669608108699322, 318),\n", + " ('wonder', 0.019997509196400642, 319),\n", + " (',', 0.02238672412931919, 320),\n", + " ('didn', 0.026670066639780998, 321),\n", + " (\"'\", 0.004090417176485062, 322),\n", + " ('t', -0.02376842498779297, 323),\n", + " ('anybody', 0.015019877813756466, 324),\n", + " ('notice', -0.032073259353637695, 325),\n", + " ('that', -0.027400804683566093, 326),\n", + " ('hand', 0.013895555399358273, 327),\n", + " ('in', -0.01581115648150444, 328),\n", + " ('the', -0.0005177874118089676, 329),\n", + " ('2', -0.013833531178534031, 330),\n", + " ('part', 0.003973441198468208, 331),\n", + " ('when', -0.0276736319065094, 332),\n", + " ('the', 0.03735414519906044, 333),\n", + " ('kidnap', -0.0038885336834937334, 334),\n", + " ('##pers', 0.007577804382890463, 335),\n", + " ('decided', -0.007950885221362114, 336),\n", + " ('to', 0.008233290165662766, 337),\n", + " ('go', -0.026257596909999847, 338),\n", + " ('home', 0.0024092162493616343, 339),\n", + " ('?', 0.04507692903280258, 340),\n", + " ('looks', -0.02298576943576336, 341),\n", + " ('like', -0.0147428372874856, 342),\n", + " ('a', 0.041484564542770386, 343),\n", + " ('part', 0.04004620015621185, 344),\n", + " ('of', 0.02201233059167862, 345),\n", + " ('crew', 0.004088917281478643, 346),\n", + " (',', 0.0054010068997740746, 347),\n", + " ('he', -0.014832447282969952, 348),\n", + " ('##he', 0.0015261276857927442, 349),\n", + " ('.', -0.006007165182381868, 350),\n", + " ('i', 0.005043786019086838, 351),\n", + " ('know', 0.0077549186535179615, 352),\n", + " ('i', 0.027806663885712624, 353),\n", + " ('should', -0.018970897421240807, 
354),\n", + " ('better', -0.010785568505525589, 355),\n", + " ('post', 0.05738293007016182, 356),\n", + " ('this', -0.012203543446958065, 357),\n", + " ('in', 0.04099973291158676, 358),\n", + " ('forums', 0.06110705807805061, 359),\n", + " (',', 0.015236682258546352, 360),\n", + " ('but', 0.03359401598572731, 361),\n", + " ('i', -0.005674791056662798, 362),\n", + " ('don', -0.011217826046049595, 363),\n", + " (\"'\", 0.004538937471807003, 364),\n", + " ('t', 0.017671801149845123, 365),\n", + " ('agree', 0.01695604808628559, 366),\n", + " ('with', 0.010900290682911873, 367),\n", + " ('some', 0.11742840707302094, 368),\n", + " ('policies', 0.056090448051691055, 369),\n", + " ('here', 0.026037804782390594, 370),\n", + " ('.', 0.01010459940880537, 371),\n", + " ('[SEP]', 0.0, 372)]\n" ] } ], @@ -426,52 +981,312 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "token_index 0 1 2 3 4 7 8 \\\n", - "token [CLS] stunt pilots to snag a falling \n", - "attribution 0.0 -0.022761 0.099706 0.250304 0.485399 -0.62276 0.003058 \n", - "text_field text text text text text text text \n", - "\n", - "token_index 9 10 11 12 13 14 15 \\\n", - "token nasa craft nasa # 39 ; s \n", - "attribution 0.291342 0.181943 0.274672 1.0 0.119171 0.03814 0.807925 \n", - "text_field text text text text text text text \n", - "\n", - "token_index 16 17 18 19 20 21 \\\n", - "token three - year effort to bring \n", - "attribution -0.027412 -0.098508 0.258135 -0.057648 -0.045056 -0.110808 \n", + "token_index 0 1 2 3 4 5 \\\n", + "token [CLS] i first saw it at \n", + "attribution 0.0 0.170457 -0.090006 0.107485 -0.069431 0.039387 \n", + "text_field text text text text text text \n", + "\n", + "token_index 6 8 9 10 11 12 \\\n", + "token 5am january 1 , 2009 , \n", + "attribution 0.023045 0.415474 0.384542 0.672036 0.261633 -0.006347 \n", "text_field text text text text text text \n", "\n", - 
"token_index 22 23 24 25 26 27 \\\n", - "token some genuine star dust back to \n", - "attribution 0.136781 0.017323 -0.003065 0.051997 0.131099 0.013059 \n", + "token_index 13 14 15 16 17 18 \\\n", + "token and after a day i watched \n", + "attribution 0.080754 -0.114537 -0.139005 0.174024 0.008026 -0.149438 \n", "text_field text text text text text text \n", "\n", - "token_index 28 29 30 31 32 33 \\\n", - "token earth is set for a dramatic \n", - "attribution 0.116943 0.01814 -0.005572 0.019079 0.403523 -0.022874 \n", - "text_field text text text text text text \n", + "token_index 19 20 21 22 23 24 \\\n", + "token it again and i want to \n", + "attribution 0.141589 0.042472 0.078582 0.173841 0.059821 0.27329 \n", + "text_field text text text text text text \n", + "\n", + "token_index 25 26 27 28 29 30 \\\n", + "token watch it again . love everything \n", + "attribution 0.085501 0.119676 -0.051532 0.032687 0.349016 0.175231 \n", + "text_field text text text text text text \n", + "\n", + "token_index 31 32 33 34 35 36 \\\n", + "token ( well , almost , so \n", + "attribution 0.152014 -0.081542 0.005581 -0.074633 -0.133133 0.033487 \n", + "text_field text text text text text text \n", + "\n", + "token_index 37 38 39 40 41 42 \\\n", + "token 9 stars ) about it . \n", + "attribution -0.057861 -0.131787 0.061106 -0.073951 -0.195056 0.072654 \n", + "text_field text text text text text text \n", + "\n", + "token_index 43 44 45 46 47 48 \\\n", + "token no color , beautiful naive stories \n", + "attribution -0.070579 -0.090318 0.027063 0.183721 0.023685 0.076976 \n", + "text_field text text text text text text \n", + "\n", + "token_index 49 50 51 53 54 55 \\\n", + "token , funny gangsters , anna , \n", + "attribution 0.093026 -0.099573 0.111415 0.033609 -0.018002 -0.031873 \n", + "text_field text text text text text text \n", + "\n", + "token_index 56 57 58 59 60 61 \\\n", + "token camera work , music . 
well \n", + "attribution -0.168191 0.005432 0.016994 -0.166483 -0.003029 0.000305 \n", + "text_field text text text text text text \n", + "\n", + "token_index 62 63 64 65 66 67 \\\n", + "token , sometimes you just want to \n", + "attribution 0.044889 0.296482 0.164041 -0.074963 -0.053668 -0.133754 \n", + "text_field text text text text text text \n", + "\n", + "token_index 68 69 70 71 72 73 \\\n", + "token listen little bit longer and the \n", + "attribution -0.078899 -0.02474 0.037653 0.00921 -0.050143 0.025857 \n", + "text_field text text text text text text \n", + "\n", + "token_index 74 75 76 77 78 79 \\\n", + "token music just stops . but this \n", + "attribution -0.122204 -0.058685 -0.067941 -0.023277 0.191971 -0.035146 \n", + "text_field text text text text text text \n", "\n", - "token_index 34 35 36 37 38 39 \\\n", - "token finale sept . 8 when hollywood \n", - "attribution -0.027122 -0.032759 -0.118329 -0.193103 -0.134687 0.064931 \n", + "token_index 80 81 82 83 84 85 \\\n", + "token is not a musical after all \n", + "attribution 0.007563 -0.002562 0.122107 -1.0 -0.180618 -0.109668 \n", + "text_field text text text text text text \n", + "\n", + "token_index 86 87 88 89 90 91 \\\n", + "token . 
i like anna ' s \n", + "attribution 0.071741 0.011276 0.159978 0.008217 0.175088 0.379829 \n", + "text_field text text text text text text \n", + "\n", + "token_index 92 93 94 95 96 98 \\\n", + "token acting , this naive wannabe gangster \n", + "attribution -0.615353 -0.147787 -0.492506 -0.481686 -0.317616 0.193298 \n", + "text_field text text text text text text \n", + "\n", + "token_index 99 100 101 102 103 104 \\\n", + "token girl , how she speaks , \n", + "attribution 0.034738 0.0362 0.155725 -0.04965 -0.028578 0.027849 \n", + "text_field text text text text text text \n", + "\n", + "token_index 105 106 107 108 109 110 \\\n", + "token holds the gun , everything makes \n", + "attribution 0.106511 -0.098679 -0.117951 -0.088747 0.189717 0.073041 \n", + "text_field text text text text text text \n", + "\n", + "token_index 111 112 113 114 115 116 \\\n", + "token me smile . no , it \n", + "attribution 0.091698 0.149203 0.029522 -0.038933 0.088369 -0.021045 \n", + "text_field text text text text text text \n", + "\n", + "token_index 117 118 119 120 121 122 \\\n", + "token ' s not that funny , \n", + "attribution 0.08587 0.022358 -0.009932 -0.002558 -0.06864 0.059714 \n", + "text_field text text text text text text \n", + "\n", + "token_index 123 124 125 126 127 128 \\\n", + "token though i have laughed a bit \n", + "attribution 0.009371 0.073627 -0.05037 0.012842 -0.011386 0.027237 \n", + "text_field text text text text text text \n", + "\n", + "token_index 129 130 131 132 133 134 \\\n", + "token at some moments , it ' \n", + "attribution 0.06302 -0.007297 0.27467 0.151535 -0.031748 0.119639 \n", + "text_field text text text text text text \n", + "\n", + "token_index 135 136 137 138 139 140 \\\n", + "token s just so subtle . 
excellent \n", + "attribution 0.111179 0.013132 0.111981 0.250567 0.113439 0.652182 \n", "text_field text text text text text text \n", "\n", - "token_index 40 41 42 43 44 45 \\\n", - "token helicopter pilots will attempt a midair \n", - "attribution -0.012046 0.020483 0.149106 0.164166 -0.379799 0.073093 \n", - "text_field text text text text text text \n", + "token_index 141 142 143 144 147 148 \\\n", + "token work by samuel benchetrit . though \n", + "attribution 0.130436 0.0702 0.082449 0.060754 0.04657 0.109886 \n", + "text_field text text text text text text \n", + "\n", + "token_index 149 150 151 152 153 154 \\\n", + "token 3d nouvelle seems weaker , but \n", + "attribution -0.10006 -0.025344 -0.059655 0.053869 0.010111 0.047435 \n", + "text_field text text text text text text \n", + "\n", + "token_index 155 156 157 158 160 161 \\\n", + "token they are also gangsters , maybe \n", + "attribution -0.087616 -0.079145 -0.067326 0.164802 0.057904 0.070503 \n", + "text_field text text text text text text \n", + "\n", + "token_index 162 163 164 165 166 167 \\\n", + "token even worse , cause they are \n", + "attribution 0.011633 0.054106 -0.090237 -0.056379 -0.067337 -0.051897 \n", + "text_field text text text text text text \n", + "\n", + "token_index 168 169 170 171 172 173 \\\n", + "token stealing ideas . and the last \n", + "attribution 0.024267 -0.073004 -0.054746 0.023162 -0.104921 -0.073514 \n", + "text_field text text text text text text \n", + "\n", + "token_index 174 175 176 177 178 179 \\\n", + "token scene is my favorite . makes \n", + "attribution -0.191759 0.083886 0.277159 0.410594 -0.144259 -0.102936 \n", + "text_field text text text text text text \n", + "\n", + "token_index 180 181 182 183 184 185 \\\n", + "token me feel so warm and . \n", + "attribution 0.10451 0.198935 0.178571 0.10484 0.059488 -0.13345 \n", + "text_field text text text text text text \n", + "\n", + "token_index 186 187 188 189 190 191 \\\n", + "token . romantic . 
yes , i \n", + "attribution -0.033263 0.102611 -0.023939 0.086991 -0.006159 -0.016589 \n", + "text_field text text text text text text \n", + "\n", + "token_index 192 193 194 195 196 197 \\\n", + "token would recommend this movie for the \n", + "attribution -0.016266 0.05755 0.018799 0.077491 0.019652 0.03683 \n", + "text_field text text text text text text \n", + "\n", + "token_index 198 199 200 201 202 203 \\\n", + "token romantic souls with a taste for \n", + "attribution 0.066298 0.036036 0.097791 0.03978 0.01561 0.028333 \n", + "text_field text text text text text text \n", + "\n", + "token_index 204 205 206 207 210 211 \\\n", + "token such art - housish movies . \n", + "attribution 0.044567 0.005998 0.05708 0.019486 0.120159 0.042933 \n", + "text_field text text text text text text \n", + "\n", + "token_index 212 213 214 215 216 217 \\\n", + "token and i don ' t agree \n", + "attribution 0.029758 0.066354 0.046331 0.03605 0.024658 -0.078732 \n", + "text_field text text text text text text \n", + "\n", + "token_index 218 219 220 221 222 223 \\\n", + "token with those comparing it to pulp \n", + "attribution 0.023078 0.139114 -0.000608 0.004622 -0.016126 0.016344 \n", + "text_field text text text text text text \n", + "\n", + "token_index 224 225 226 227 228 229 \\\n", + "token fiction . it ' s not \n", + "attribution 0.033682 0.039149 0.068615 0.068419 0.067978 0.16568 \n", + "text_field text text text text text text \n", + "\n", + "token_index 230 231 232 233 234 235 \\\n", + "token about action and twisted story , \n", + "attribution -0.018121 -0.030201 0.007424 0.04382 -0.021345 -0.055491 \n", + "text_field text text text text text text \n", + "\n", + "token_index 236 237 238 241 242 243 \\\n", + "token though all vignettes intersect . 
it \n", + "attribution -0.020809 0.052888 0.10707 0.069352 0.126138 0.116565 \n", + "text_field text text text text text text \n", + "\n", + "token_index 244 245 246 247 248 249 \\\n", + "token ' s calm , and maybe \n", + "attribution 0.155122 0.223995 0.810752 0.191691 0.136151 0.297176 \n", + "text_field text text text text text text \n", + "\n", + "token_index 250 251 252 253 254 255 \\\n", + "token too slow movie for most of \n", + "attribution 0.074983 0.034951 0.046917 -0.633489 -0.049592 -0.020356 \n", + "text_field text text text text text text \n", + "\n", + "token_index 256 257 258 259 260 261 \\\n", + "token the people . it ' s \n", + "attribution 0.005256 0.06978 -0.031699 -0.074841 0.137725 0.039087 \n", + "text_field text text text text text text \n", + "\n", + "token_index 262 263 264 265 266 267 \\\n", + "token about characters , their feelings , \n", + "attribution 0.161196 -0.346557 0.42938 0.237538 0.217716 0.394794 \n", + "text_field text text text text text text \n", + "\n", + "token_index 268 269 270 271 272 273 \\\n", + "token very subtle . anyway , probably \n", + "attribution 0.407511 0.217422 -0.229159 -0.179182 0.084117 0.122595 \n", + "text_field text text text text text text \n", + "\n", + "token_index 274 275 276 277 278 279 \\\n", + "token this review won ' t be \n", + "attribution -0.253398 0.031705 0.105428 0.031298 0.034757 -0.06585 \n", + "text_field text text text text text text \n", + "\n", + "token_index 280 281 282 283 284 285 \\\n", + "token of much help to anyone ( \n", + "attribution -0.114305 -0.097195 -0.042372 -0.079064 -0.12552 0.049177 \n", + "text_field text text text text text text \n", + "\n", + "token_index 286 287 288 289 290 291 \\\n", + "token my first ) , just wanted \n", + "attribution 0.025227 0.01541 -0.090277 0.126098 -0.088128 0.212069 \n", + "text_field text text text text text text \n", + "\n", + "token_index 292 293 294 295 296 297 \\\n", + "token to express my appreciation . 
< \n", + "attribution -0.046134 0.134626 -0.040483 0.093308 -0.158916 0.012593 \n", + "text_field text text text text text text \n", + "\n", + "token_index 298 299 300 301 302 303 \\\n", + "token br / > < br / \n", + "attribution -0.017436 0.004035 -0.014979 -0.008376 0.026926 0.031121 \n", + "text_field text text text text text text \n", + "\n", + "token_index 304 305 307 308 309 310 \\\n", + "token > spoiler : this movie doesn \n", + "attribution 0.055725 0.107997 0.011047 -0.153842 -0.073564 -0.016808 \n", + "text_field text text text text text text \n", + "\n", + "token_index 311 312 313 314 315 317 \\\n", + "token ' t have a goofs section \n", + "attribution 0.010405 0.034579 0.007827 0.012471 -0.01821 0.020311 \n", + "text_field text text text text text text \n", + "\n", + "token_index 318 319 320 321 322 323 \\\n", + "token . wonder , didn ' t \n", + "attribution 0.014367 0.03746 0.041935 0.049959 0.007662 -0.044524 \n", + "text_field text text text text text text \n", + "\n", + "token_index 324 325 326 327 328 329 \\\n", + "token anybody notice that hand in the \n", + "attribution 0.028136 -0.060081 -0.051328 0.02603 -0.029618 -0.00097 \n", + "text_field text text text text text text \n", + "\n", + "token_index 330 331 332 333 334 336 \\\n", + "token 2 part when the kidnappers decided \n", + "attribution -0.025913 0.007443 -0.051839 0.069973 0.014195 -0.014894 \n", + "text_field text text text text text text \n", + "\n", + "token_index 337 338 339 340 341 342 \\\n", + "token to go home ? 
looks like \n", + "attribution 0.015423 -0.049187 0.004513 0.084439 -0.043058 -0.027617 \n", + "text_field text text text text text text \n", "\n", - "token_index 47 48 \n", - "token retrieval [SEP] \n", - "attribution -0.012136 0.0 \n", - "text_field text text \n" + "token_index 343 344 345 346 347 348 \\\n", + "token a part of crew , hehe \n", + "attribution 0.07771 0.075016 0.041234 0.007659 0.010117 -0.027785 \n", + "text_field text text text text text text \n", + "\n", + "token_index 350 351 352 353 354 355 \\\n", + "token . i know i should better \n", + "attribution -0.011253 0.009448 0.014527 0.052088 -0.035537 -0.020204 \n", + "text_field text text text text text text \n", + "\n", + "token_index 356 357 358 359 360 361 \\\n", + "token post this in forums , but \n", + "attribution 0.107491 -0.02286 0.076802 0.114468 0.028542 0.062929 \n", + "text_field text text text text text text \n", + "\n", + "token_index 362 363 364 365 366 367 \\\n", + "token i don ' t agree with \n", + "attribution -0.01063 -0.021014 0.008502 0.033103 0.031763 0.020419 \n", + "text_field text text text text text text \n", + "\n", + "token_index 368 369 370 371 372 \n", + "token some policies here . 
[SEP] \n", + "attribution 0.21997 0.10507 0.048775 0.018928 0.0 \n", + "text_field text text text text text \n" ] } ], @@ -492,7 +1307,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -504,124 +1319,1619 @@ " [CLS]\n", " \n", " \n", - " \n", - " stunt\n", + " i\n", " \n", " \n", - " \n", - " pilots\n", + " first\n", " \n", " \n", - " \n", - " to\n", + " saw\n", " \n", " \n", - " \n", - " snag\n", + " it\n", " \n", " \n", - " \n", - " a\n", + " at\n", " \n", " \n", - " \n", - " falling\n", + " 5am\n", " \n", " \n", - " \n", - " nasa\n", + " january\n", " \n", " \n", - " \n", - " craft\n", + " 1\n", " \n", " \n", - " \n", - " nasa\n", + " ,\n", " \n", " \n", - " \n", - " #\n", + " 2009\n", " \n", " \n", - " \n", - " 39\n", + " ,\n", " \n", " \n", - " \n", - " ;\n", + " and\n", " \n", " \n", - " \n", - " s\n", + " after\n", " \n", " \n", - " \n", - " three\n", + " a\n", " \n", " \n", - " \n", - " -\n", + " day\n", " \n", " \n", - " \n", - " year\n", + " i\n", " \n", " \n", - " \n", - " effort\n", + " watched\n", " \n", " \n", - " \n", - " to\n", + " it\n", " \n", " \n", - " \n", - " bring\n", + " again\n", " \n", " \n", - " \n", + " and\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " want\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " watch\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " again\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " love\n", + " \n", + " \n", + " \n", + " everything\n", + " \n", + " \n", + " \n", + " (\n", + " \n", + " \n", + " \n", + " well\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " almost\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " so\n", + " \n", + " \n", + " \n", + " 9\n", + " \n", + " \n", + " \n", + " stars\n", + " \n", + " \n", + " \n", + " )\n", + " \n", + " \n", + " \n", + " about\n", + " \n", + " \n", + " \n", + " 
it\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " no\n", + " \n", + " \n", + " \n", + " color\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " beautiful\n", + " \n", + " \n", + " \n", + " naive\n", + " \n", + " \n", + " \n", + " stories\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " funny\n", + " \n", + " \n", + " \n", + " gangsters\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " anna\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " camera\n", + " \n", + " \n", + " \n", + " work\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " music\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " well\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " sometimes\n", + " \n", + " \n", + " \n", + " you\n", + " \n", + " \n", + " \n", + " just\n", + " \n", + " \n", + " \n", + " want\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " listen\n", + " \n", + " \n", + " \n", + " little\n", + " \n", + " \n", + " \n", + " bit\n", + " \n", + " \n", + " \n", + " longer\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " the\n", + " \n", + " \n", + " \n", + " music\n", + " \n", + " \n", + " \n", + " just\n", + " \n", + " \n", + " \n", + " stops\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " but\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " is\n", + " \n", + " \n", + " \n", + " not\n", + " \n", + " \n", + " \n", + " a\n", + " \n", + " \n", + " \n", + " musical\n", + " \n", + " \n", + " \n", + " after\n", + " \n", + " \n", + " \n", + " all\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " like\n", + " \n", + " \n", + " \n", + " anna\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " acting\n", + " \n", 
+ " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " naive\n", + " \n", + " \n", + " \n", + " wannabe\n", + " \n", + " \n", + " \n", + " gangster\n", + " \n", + " \n", + " \n", + " girl\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " how\n", + " \n", + " \n", + " \n", + " she\n", + " \n", + " \n", + " \n", + " speaks\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " holds\n", + " \n", + " \n", + " \n", + " the\n", + " \n", + " \n", + " \n", + " gun\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " everything\n", + " \n", + " \n", + " \n", + " makes\n", + " \n", + " \n", + " \n", + " me\n", + " \n", + " \n", + " \n", + " smile\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " no\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " not\n", + " \n", + " \n", + " \n", + " that\n", + " \n", + " \n", + " \n", + " funny\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " though\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " have\n", + " \n", + " \n", + " \n", + " laughed\n", + " \n", + " \n", + " \n", + " a\n", + " \n", + " \n", + " \n", + " bit\n", + " \n", + " \n", + " \n", + " at\n", + " \n", + " \n", + " \n", + " some\n", + " \n", + " \n", + " \n", + " moments\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " just\n", + " \n", + " \n", + " \n", + " so\n", + " \n", + " \n", + " \n", + " subtle\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " excellent\n", + " \n", + " \n", + " \n", + " work\n", + " \n", + " \n", + " \n", + " by\n", + " \n", + " \n", + " \n", + " samuel\n", + " \n", + " \n", + " \n", + " 
benchetrit\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " though\n", + " \n", + " \n", + " \n", + " 3d\n", + " \n", + " \n", + " \n", + " nouvelle\n", + " \n", + " \n", + " \n", + " seems\n", + " \n", + " \n", + " \n", + " weaker\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " but\n", + " \n", + " \n", + " \n", + " they\n", + " \n", + " \n", + " \n", + " are\n", + " \n", + " \n", + " \n", + " also\n", + " \n", + " \n", + " \n", + " gangsters\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " maybe\n", + " \n", + " \n", + " \n", + " even\n", + " \n", + " \n", + " \n", + " worse\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " cause\n", + " \n", + " \n", + " \n", + " they\n", + " \n", + " \n", + " \n", + " are\n", + " \n", + " \n", + " \n", + " stealing\n", + " \n", + " \n", + " \n", + " ideas\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " the\n", + " \n", + " \n", + " \n", + " last\n", + " \n", + " \n", + " \n", + " scene\n", + " \n", + " \n", + " \n", + " is\n", + " \n", + " \n", + " \n", + " my\n", + " \n", + " \n", + " \n", + " favorite\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " makes\n", + " \n", + " \n", + " \n", + " me\n", + " \n", + " \n", + " \n", + " feel\n", + " \n", + " \n", + " \n", + " so\n", + " \n", + " \n", + " \n", + " warm\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " romantic\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " yes\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " would\n", + " \n", + " \n", + " \n", + " recommend\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " movie\n", + " \n", + " \n", + " \n", + " for\n", + " \n", + " \n", + " \n", + " 
the\n", + " \n", + " \n", + " \n", + " romantic\n", + " \n", + " \n", + " \n", + " souls\n", + " \n", + " \n", + " \n", + " with\n", + " \n", + " \n", + " \n", + " a\n", + " \n", + " \n", + " \n", + " taste\n", + " \n", + " \n", + " \n", + " for\n", + " \n", + " \n", + " \n", + " such\n", + " \n", + " \n", + " \n", + " art\n", + " \n", + " \n", + " \n", + " -\n", + " \n", + " \n", + " \n", + " housish\n", + " \n", + " \n", + " \n", + " movies\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " don\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " t\n", + " \n", + " \n", + " \n", + " agree\n", + " \n", + " \n", + " \n", + " with\n", + " \n", + " \n", + " \n", + " those\n", + " \n", + " \n", + " \n", + " comparing\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " pulp\n", + " \n", + " \n", + " \n", + " fiction\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " not\n", + " \n", + " \n", + " \n", + " about\n", + " \n", + " \n", + " \n", + " action\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " twisted\n", + " \n", + " \n", + " \n", + " story\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " though\n", + " \n", + " \n", + " \n", + " all\n", + " \n", + " \n", + " \n", + " vignettes\n", + " \n", + " \n", + " \n", + " intersect\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " calm\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " and\n", + " \n", + " \n", + " \n", + " maybe\n", + " \n", + " \n", + " \n", + " too\n", + " \n", + " \n", + " \n", + " slow\n", + " \n", 
+ " \n", + " \n", + " movie\n", + " \n", + " \n", + " \n", + " for\n", + " \n", + " \n", + " \n", + " most\n", + " \n", + " \n", + " \n", + " of\n", + " \n", + " \n", + " \n", + " the\n", + " \n", + " \n", + " \n", + " people\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " it\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " s\n", + " \n", + " \n", + " \n", + " about\n", + " \n", + " \n", + " \n", + " characters\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " their\n", + " \n", + " \n", + " \n", + " feelings\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " very\n", + " \n", + " \n", + " \n", + " subtle\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " anyway\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " probably\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " review\n", + " \n", + " \n", + " \n", + " won\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " t\n", + " \n", + " \n", + " \n", + " be\n", + " \n", + " \n", + " \n", + " of\n", + " \n", + " \n", + " \n", + " much\n", + " \n", + " \n", + " \n", + " help\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " anyone\n", + " \n", + " \n", + " \n", + " (\n", + " \n", + " \n", + " \n", - " some\n", + " my\n", " \n", " \n", - " \n", + " first\n", + " \n", + " \n", + " \n", + " )\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " just\n", + " \n", + " \n", + " \n", + " wanted\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " express\n", + " \n", + " \n", + " \n", + " my\n", + " \n", + " \n", + " \n", + " appreciation\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " <\n", + " \n", + " \n", + " \n", + " br\n", + " \n", + " \n", + " \n", + " /\n", + " \n", + " \n", + " \n", + " >\n", + " \n", + " \n", + " \n", + " <\n", + " \n", + " \n", + " \n", + " 
br\n", + " \n", + " \n", + " \n", + " /\n", + " \n", + " \n", + " \n", + " >\n", + " \n", + " \n", + " \n", + " spoiler\n", + " \n", + " \n", + " \n", + " :\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " movie\n", + " \n", + " \n", + " \n", + " doesn\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " t\n", + " \n", + " \n", + " \n", + " have\n", + " \n", + " \n", + " \n", + " a\n", + " \n", + " \n", + " \n", + " goofs\n", + " \n", + " \n", + " \n", + " section\n", + " \n", + " \n", + " \n", + " .\n", + " \n", + " \n", + " \n", + " wonder\n", + " \n", + " \n", + " \n", + " ,\n", + " \n", + " \n", + " \n", + " didn\n", + " \n", + " \n", + " \n", + " '\n", + " \n", + " \n", + " \n", + " t\n", + " \n", + " \n", + " \n", + " anybody\n", + " \n", + " \n", + " \n", + " notice\n", + " \n", + " \n", + " \n", + " that\n", + " \n", + " \n", + " \n", + " hand\n", + " \n", + " \n", + " \n", - " genuine\n", + " in\n", " \n", " \n", " \n", - " star\n", + " the\n", " \n", " \n", - " \n", - " dust\n", + " 2\n", " \n", " \n", - " \n", + " part\n", + " \n", + " \n", + " \n", + " when\n", + " \n", + " \n", + " \n", + " the\n", + " \n", + " \n", + " \n", + " kidnappers\n", + " \n", + " \n", + " \n", - " back\n", + " decided\n", " \n", " \n", " \n", " \n", - " \n", - " earth\n", + " go\n", " \n", " \n", - " \n", - " is\n", + " home\n", " \n", " \n", - " \n", - " set\n", + " ?\n", " \n", " \n", - " \n", - " for\n", + " looks\n", + " \n", + " \n", + " \n", + " like\n", " \n", " \n", - " \n", " a\n", " \n", " \n", - " \n", - " dramatic\n", + " part\n", " \n", " \n", - " \n", - " finale\n", + " of\n", " \n", " \n", - " \n", - " sept\n", + " crew\n", " \n", " \n", - " \n", + " ,\n", + " \n", + " \n", + " \n", + " hehe\n", + " \n", + " \n", + " \n", " .\n", " \n", " \n", - " \n", - " 8\n", + " i\n", " \n", " \n", - " \n", - " when\n", + " know\n", + " \n", + " \n", + " \n", + " i\n", + " \n", + " \n", + " \n", + " should\n", + " \n", + " \n", + 
" \n", + " better\n", + " \n", + " \n", + " \n", + " post\n", + " \n", + " \n", + " \n", + " this\n", + " \n", + " \n", + " \n", + " in\n", + " \n", + " \n", + " \n", + " forums\n", + " \n", + " \n", + " \n", + " ,\n", " \n", " \n", " \n", - " hollywood\n", + " but\n", " \n", " \n", - " \n", - " helicopter\n", + " i\n", " \n", " \n", - " \n", - " pilots\n", + " don\n", " \n", " \n", - " \n", + " '\n", + " \n", + " \n", + " \n", - " will\n", + " t\n", " \n", " \n", - " \n", - " attempt\n", + " agree\n", + " \n", + " \n", + " \n", + " with\n", + " \n", + " \n", + " \n", + " some\n", " \n", " \n", - " \n", - " a\n", + " policies\n", " \n", " \n", - " \n", - " midair\n", + " here\n", " \n", " \n", - " \n", - " retrieval\n", + " .\n", " \n", " \n", "
\n", - " \n", " [CLS]\n", " \n", " \n", - " \n", - " california\n", - " \n", - " \n", - " \n", - " group\n", - " \n", - " \n", - " \n", - " sues\n", - " \n", - " \n", - " \n", - " albertson\n", + " as\n", " \n", " \n", - " \n", - " '\n", + " recent\n", " \n", " \n", - " \n", - " s\n", + " events\n", " \n", " \n", - " \n", - " over\n", + " illustrate\n", " \n", " \n", - " \n", - " privacy\n", + " ,\n", " \n", " \n", - " \n", - " concerns\n", + " trust\n", " \n", " \n", - " \n", - " a\n", + " takes\n", " \n", " \n", - " \n", - " california\n", + " years\n", " \n", " \n", - " \n", - " -\n", + " to\n", " \n", " \n", - " \n", - " based\n", + " gain\n", " \n", " \n", - " \n", - " privacy\n", + " but\n", " \n", " \n", - " \n", - " advocacy\n", + " can\n", " \n", " \n", - " \n", - " group\n", + " be\n", " \n", " \n", - " \n", - " is\n", + " lost\n", " \n", " \n", - " \n", - " suing\n", + " in\n", " \n", " \n", - " \n", - " supermarket\n", + " an\n", " \n", " \n", - " \n", - " giant\n", + " instant\n", " \n", " \n", - " \n", - " albertson\n", + " .\n", " \n", " \n", - " \n", - " '\n", + " [SEP]\n", " \n", " \n", - " \n", - " s\n", + " trust\n", " \n", " \n", - " \n", - " over\n", + " ,\n", " \n", " \n", - " \n", - " alleged\n", + " once\n", " \n", " \n", - " \n", - " privacy\n", + " built\n", " \n", " \n", - " \n", - " violations\n", + " ,\n", " \n", " \n", - " \n", - " involving\n", + " is\n", " \n", " \n", - " \n", - " its\n", + " hard\n", " \n", " \n", - " \n", - " pharmacy\n", + " to\n", " \n", " \n", - " \n", - " customers\n", + " lose\n", " \n", " \n", - " \n", " .\n", " \n", " \n", - " \n", " [SEP]\n", " \n", @@ -1091,141 +3376,166 @@ "name": "stdout", "output_type": "stream", "text": [ - "Model: textattack/albert-base-v2-ag-news | Pred: Business | True: Sci/Tech\n" + "Model: howey/electra-base-mnli | Pred: contradiction | True: contradiction\n" ] }, { "data": { "text/html": [ "
\n", - " \n", " [CLS]\n", " \n", " \n", - " \n", - " california\n", + " as\n", " \n", " \n", - " \n", - " group\n", + " recent\n", " \n", " \n", - " \n", - " sues\n", + " events\n", " \n", " \n", - " \n", - " albertson's\n", + " illustrate\n", " \n", " \n", - " \n", - " over\n", + " ,\n", " \n", " \n", - " \n", - " privacy\n", + " trust\n", " \n", " \n", - " \n", - " concerns\n", + " takes\n", " \n", " \n", - " \n", - " a\n", + " years\n", + " \n", + " \n", + " \n", + " to\n", + " \n", + " \n", + " \n", + " gain\n", + " \n", + " \n", + " \n", + " but\n", + " \n", + " \n", + " \n", + " can\n", + " \n", + " \n", + " \n", + " be\n", " \n", " \n", " \n", - " california-based\n", + " lost\n", " \n", " \n", - " \n", - " privacy\n", + " in\n", " \n", " \n", - " \n", - " advocacy\n", + " an\n", " \n", " \n", - " \n", - " group\n", + " instant\n", " \n", " \n", - " \n", - " is\n", + " .\n", " \n", " \n", - " \n", - " suing\n", + " [SEP]\n", " \n", " \n", - " \n", - " supermarket\n", + " trust\n", " \n", " \n", - " \n", - " giant\n", + " ,\n", " \n", " \n", - " \n", - " albertson's\n", + " once\n", " \n", " \n", - " \n", - " over\n", + " built\n", " \n", " \n", - " \n", - " alleged\n", + " ,\n", " \n", " \n", - " \n", - " privacy\n", + " is\n", " \n", " \n", - " \n", - " violations\n", + " hard\n", " \n", " \n", - " \n", - " involving\n", + " to\n", " \n", " \n", - " \n", - " its\n", + " lose\n", " \n", " \n", - " \n", - " pharmacy\n", + " .\n", " \n", " \n", - " \n", - " customers.[SEP]\n", + " [SEP]\n", " \n", "
" ], @@ -1264,7 +3574,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 24, "metadata": {}, "outputs": [ { diff --git a/src/thermostat/data/thermostat_configs.py b/src/thermostat/data/thermostat_configs.py index b7b5690..bd50c58 100644 --- a/src/thermostat/data/thermostat_configs.py +++ b/src/thermostat/data/thermostat_configs.py @@ -200,6 +200,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/GLppwQjeBTsLtTC/download", **_AGNEWS_ALBERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="ag_news-albert-lds", + description="AG News dataset, ALBERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/nzriaZHxniNtJBN/download", + **_AGNEWS_ALBERT_KWARGS, + ), + ThermostatConfig( + name="ag_news-albert-lgs", + description="AG News dataset, ALBERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/EjCXxWCETQH9onj/download", + **_AGNEWS_ALBERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="ag_news-bert-lgxa", description="AG News dataset, BERT model, Layer Gradient x Activation explanations", @@ -244,6 +260,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/dCbgsjdW6b9pzo3/download", **_AGNEWS_BERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="ag_news-bert-lds", + description="AG News dataset, BERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/XQzBnRniEyC8NEF/download", + **_AGNEWS_BERT_KWARGS, + ), + ThermostatConfig( + name="ag_news-bert-lgs", + description="AG News dataset, BERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/nHt4Ld4AKfbG2ft/download", + **_AGNEWS_BERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( 
name="ag_news-roberta-lgxa", description="AG News dataset, RoBERTa model, Layer Gradient x Activation explanations", @@ -288,6 +320,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/yabEAY5sLpjxKkW/download", **_AGNEWS_ROBERTA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="ag_news-roberta-lds", + description="AG News dataset, RoBERTa model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/ZWoxq27s36For98/download", + **_AGNEWS_ROBERTA_KWARGS, + ), + ThermostatConfig( + name="ag_news-roberta-lgs", + description="AG News dataset, RoBERTa model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/HgHjNFcMQbCC2Nb/download", + **_AGNEWS_ROBERTA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="imdb-albert-lgxa", description="IMDb dataset, ALBERT model, Layer Gradient x Activation explanations", @@ -332,6 +380,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/sQMK2XsknbzK23a/download", **_IMDB_ALBERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="imdb-albert-lds", + description="IMDb dataset, ALBERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/SYg2GjRkewW8fn7/download", + **_IMDB_ALBERT_KWARGS, + ), + ThermostatConfig( + name="imdb-albert-lgs", + description="IMDb dataset, ALBERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/rWZfSzPN7Gm3Cko/download", + **_IMDB_ALBERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="imdb-bert-lgxa", description="IMDb dataset, BERT model, Layer Gradient x Activation explanations", @@ -376,6 +440,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/DjmCKdBoWHt8jbX/download", **_IMDB_BERT_KWARGS, ), + # shap 
value inclusion + ThermostatConfig( + name="imdb-bert-lds", + description="IMDb dataset, BERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/dnqSfrgGPYcYsKt/download", + **_IMDB_BERT_KWARGS, + ), + ThermostatConfig( + name="imdb-bert-lgs", + description="IMDb dataset, BERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/fnBqKxEjfsScqwg/download", + **_IMDB_BERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="imdb-electra-lgxa", description="IMDb dataset, ELECTRA model, Layer Gradient x Activation explanations", @@ -420,6 +500,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/MPHqZwJCP97sA4D/download", **_IMDB_ELECTRA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="imdb-electra-lds", + description="IMDb dataset, ELECTRA model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/cHdECMFjHacAnFk/download", + **_IMDB_ELECTRA_KWARGS, + ), + ThermostatConfig( + name="imdb-electra-lgs", + description="IMDb dataset, ELECTRA model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/dRM6RKSwD5fKteG/download", + **_IMDB_ELECTRA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="imdb-roberta-lgxa", description="IMDb dataset, RoBERTa model, Layer Gradient x Activation explanations", @@ -464,6 +560,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/339zLEttF6djtBR/download", **_IMDB_ROBERTA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="imdb-roberta-lds", + description="IMDb dataset, RoBERTa model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/EzD7brEosFx4iW2/download", + **_IMDB_ROBERTA_KWARGS, + ), 
+ ThermostatConfig( + name="imdb-roberta-lgs", + description="IMDb dataset, RoBERTa model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/ptcfks7sTnpm85M/download", + **_IMDB_ROBERTA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="imdb-xlnet-lgxa", description="IMDb dataset, XLNet model, Layer Gradient x Activation explanations", @@ -552,6 +664,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/fffM7w64CnTSzHA/download", **_MNLI_ALBERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="multi_nli-albert-lds", + description="MultiNLI dataset, ALBERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/LiMzjexGXc5PdAm/download", + **_MNLI_ALBERT_KWARGS, + ), + ThermostatConfig( + name="multi_nli-albert-lgs", + description="MultiNLI dataset, ALBERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/ZL9fN8QKy8Di58B/download", + **_MNLI_ALBERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="multi_nli-bert-lgxa", description="MultiNLI dataset, BERT model, Layer Gradient x Activation explanations", @@ -596,6 +724,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/d5TTHCkAb5TJmbg/download", **_MNLI_BERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="multi_nli-bert-lds", + description="MultiNLI dataset, BERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/yBtF9ybEm5kCzeg/download", + **_MNLI_BERT_KWARGS, + ), + ThermostatConfig( + name="multi_nli-bert-lgs", + description="MultiNLI dataset, BERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/XMAgeYenwS7MSWP/download", + 
**_MNLI_BERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="multi_nli-electra-lgxa", description="MultiNLI dataset, ELECTRA model, Layer Gradient x Activation explanations", @@ -640,6 +784,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/zx3rGTpMkRT68tk/download", **_MNLI_ELECTRA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="multi_nli-electra-lds", + description="MultiNLI dataset, ELECTRA model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/PD578t8fzsR2Q6k/download", + **_MNLI_ELECTRA_KWARGS, + ), + ThermostatConfig( + name="multi_nli-electra-lgs", + description="MultiNLI dataset, ELECTRA model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/iq2QNQRojpsoKwW/download", + **_MNLI_ELECTRA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="multi_nli-roberta-lgxa", description="MultiNLI dataset, RoBERTa model, Layer Gradient x Activation explanations", @@ -684,6 +844,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/3aPeTawM8cbAsEg/download", **_MNLI_ROBERTA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="multi_nli-roberta-lds", + description="MultiNLI dataset, RoBERTa model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/Kg4eLx2SZWrf8MF/download", + **_MNLI_ROBERTA_KWARGS, + ), + ThermostatConfig( + name="multi_nli-roberta-lgs", + description="MultiNLI dataset, RoBERTa model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/7Fxtawd3t8WXTWC/download", + **_MNLI_ROBERTA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="multi_nli-xlnet-lgxa", description="MultiNLI dataset, XLNet model, Layer Gradient x Activation explanations", @@ -772,6 +948,22 @@ def 
__init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/wekiPq7ijzsCQK4/download", **_XNLI_ALBERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="xnli-albert-lds", + description="XNLI dataset, ALBERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/AzKmCQfEP6CmwTH/download", + **_XNLI_ALBERT_KWARGS, + ), + ThermostatConfig( + name="xnli-albert-lgs", + description="XNLI dataset, ALBERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/qWdK5YkSBxmMPKp/download", + **_XNLI_ALBERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="xnli-bert-lgxa", description="XNLI dataset, BERT model, Layer Gradient x Activation explanations", @@ -816,6 +1008,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/D4ctEijzerMoNT8/download", **_XNLI_BERT_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="xnli-bert-lds", + description="XNLI dataset, BERT model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/LtARP8wA4mLCcL2/download", + **_XNLI_BERT_KWARGS, + ), + ThermostatConfig( + name="xnli-bert-lgs", + description="XNLI dataset, BERT model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/AHw6WdfNbYXP9zD/download", + **_XNLI_BERT_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="xnli-electra-lgxa", description="XNLI dataset, ELECTRA model, Layer Gradient x Activation explanations", @@ -860,6 +1068,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/T3KKsM5TtsHyCAL/download", **_XNLI_ELECTRA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="xnli-electra-lds", + description="XNLI dataset, ELECTRA model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", 
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/MmHwGXa6EpzT3ns/download", + **_XNLI_ELECTRA_KWARGS, + ), + ThermostatConfig( + name="xnli-electra-lgs", + description="XNLI dataset, ELECTRA model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/gLwPwAaJDCB956T/download", + **_XNLI_ELECTRA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="xnli-roberta-lgxa", description="XNLI dataset, RoBERTa model, Layer Gradient x Activation explanations", @@ -904,6 +1128,22 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/opYTzjSeWWL7eYg/download", **_XNLI_ROBERTA_KWARGS, ), + # shap value inclusion + ThermostatConfig( + name="xnli-roberta-lds", + description="XNLI dataset, RoBERTa model, Layer DeepLift Shap explanations", + explainer="LayerDeepLiftShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/zsrrBirPHY4gp2p/download", + **_XNLI_ROBERTA_KWARGS, + ), + ThermostatConfig( + name="xnli-roberta-lgs", + description="XNLI dataset, RoBERTa model, Layer Gradient Shap explanations", + explainer="LayerGradientShap", + data_url="https://cloud.dfki.de/owncloud/index.php/s/fxAMzqGzgBZb9b4/download", + **_XNLI_ROBERTA_KWARGS, + ), + # shap value inclusion ThermostatConfig( name="xnli-xlnet-lgxa", description="XNLI dataset, XLNet model, Layer Gradient x Activation explanations", diff --git a/src/thermostat/explainers/__init__.py b/src/thermostat/explainers/__init__.py index f622235..182ccff 100644 --- a/src/thermostat/explainers/__init__.py +++ b/src/thermostat/explainers/__init__.py @@ -20,3 +20,8 @@ from .svs import ( ExplainerShapleyValueSampling ) + +from .shap import ( + ExplainerLayerDeepLiftShap, + ExplainerLayerGradientShap, +) \ No newline at end of file diff --git a/src/thermostat/explainers/shap.py b/src/thermostat/explainers/shap.py new file mode 100644 index 0000000..761af46 --- /dev/null +++ b/src/thermostat/explainers/shap.py @@ -0,0 +1,106 @@ +import 
import torch
from typing import Dict

from captum.attr import (
    LayerDeepLiftShap,
    LayerGradientShap,
)
from transformers import XLNetForSequenceClassification

from thermostat.explain import ExplainerAutoModelInitializer
from thermostat.utils import HookableModelWrapper


class ExplainerLayerGradientShap(ExplainerAutoModelInitializer):
    """Explainer wrapping Captum's ``LayerGradientShap`` on the embedding layer.

    The Captum explainer is created in :meth:`from_config` with the instance's
    ``forward_func`` and the layer returned by ``get_embedding_layer``.
    """

    def __init__(self):
        super().__init__()
        # Neither attribute is assigned anywhere else in this class; they are
        # kept so that existing attribute access keeps working.
        self.name_layer: str = None
        self.layer = None
        # Passed as ``n_samples`` to ``LayerGradientShap.attribute``; filled in
        # from the config by ``from_config``.
        self.n_samples: int = None

    def validate_config(self, config: Dict) -> bool:
        """Validate the parts of ``config`` this explainer depends on.

        Runs the parent validation first, then requires
        ``config['explainer']['n_samples']`` to be present.

        Returns:
            ``True`` once all checks passed.

        Raises:
            AssertionError: if ``'n_samples'`` is missing from
                ``config['explainer']``.
        """
        super().validate_config(config)
        assert 'n_samples' in config['explainer'], \
            'Define how many samples to take along the straight line path from the baseline.'
        # The signature promises a bool; previously the method fell through
        # and implicitly returned None.
        return True

    @classmethod
    def from_config(cls, config):
        """Create a configured explainer instance from ``config``.

        Delegates construction to the parent ``from_config``, validates the
        config, stores ``n_samples`` and attaches a ``LayerGradientShap``
        explainer.
        """
        res = super().from_config(config)
        res.validate_config(config)
        res.n_samples = config['explainer']['n_samples']

        # Put the model into eval mode and clear gradients once here, so that
        # ``explain`` does not have to repeat it for every batch.
        res.model.eval()
        res.model.zero_grad()

        res.explainer = LayerGradientShap(
            forward_func=res.forward_func,
            layer=res.get_embedding_layer(res.model),
        )
        return res

    def explain(self, batch):
        """Compute attributions and predictions for one batch.

        Args:
            batch: mapping whose values support ``.to(device)``
                (presumably tensors produced by the data loader - confirm).

        Returns:
            Tuple ``(attributions, predictions)``. ``attributions`` is the
            Captum output summed over dimension 2 (and transposed for XLNet);
            ``predictions`` is the raw output of ``forward_func``.
        """
        batch = {k: v.to(self.device) for k, v in batch.items()}

        inputs, additional_forward_args = self.get_inputs_and_additional_args(
            base_model=type(self.model.base_model),
            batch=batch,
        )

        # Attribute with respect to the class the model itself predicts.
        predictions = self.forward_func(inputs, *additional_forward_args)
        target = torch.argmax(predictions, dim=1)

        base_line = self.get_baseline(batch=batch)

        attributions = self.explainer.attribute(
            inputs=inputs,
            baselines=base_line,
            n_samples=self.n_samples,
            target=target,
            additional_forward_args=additional_forward_args,
        )

        # Collapse dimension 2 (presumably the embedding dimension - confirm)
        # to obtain one score per position.
        attributions = torch.sum(attributions, dim=2)

        if isinstance(self.model, XLNetForSequenceClassification):
            # for xlnet, attributions.shape = [seq_len, batch_dim]
            # but [batch_dim, seq_len] is assumed
            attributions = attributions.T

        return attributions, predictions  # xlnet: [130, 1]


class ExplainerLayerDeepLiftShap(ExplainerAutoModelInitializer):
    """Explainer wrapping Captum's ``LayerDeepLiftShap`` on the embedding layer.

    The Captum explainer is created in :meth:`from_config` around a
    ``HookableModelWrapper`` built from the explainer instance.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def from_config(cls, config):
        """Create a configured explainer instance from ``config``.

        Delegates construction to the parent ``from_config`` and attaches a
        ``LayerDeepLiftShap`` explainer.
        """
        res = super().from_config(config)

        # Put the model into eval mode and clear gradients once here, so that
        # ``explain`` does not have to repeat it for every batch.
        res.model.eval()
        res.model.zero_grad()

        res.explainer = LayerDeepLiftShap(
            model=HookableModelWrapper(res),
            layer=res.get_embedding_layer(res.model),
        )

        return res

    def explain(self, batch):
        """Compute attributions and predictions for one batch.

        Args:
            batch: mapping whose values support ``.to(device)``
                (presumably tensors produced by the data loader - confirm).

        Returns:
            Tuple ``(attributions, predictions)``. ``attributions`` is the
            Captum output summed over dimension 2 (and transposed for XLNet);
            ``predictions`` is the raw output of ``forward_func``.
        """
        batch = {k: v.to(self.device) for k, v in batch.items()}

        inputs, additional_forward_args = self.get_inputs_and_additional_args(
            base_model=type(self.model.base_model),
            batch=batch,
        )
        base_line = self.get_baseline(batch)
        # Stack two rows (an all-zeros row and the first baseline row) to
        # provide more than one example for the baseline.
        # NOTE(review): only index 0 of ``inputs`` and ``base_line`` is used, so
        # for batch sizes > 1 the remaining baseline rows are ignored - confirm
        # this is intended.
        base_line = torch.stack((torch.zeros_like(inputs[0]), base_line[0]))

        # Attribute with respect to the class the model itself predicts.
        predictions = self.forward_func(inputs, *additional_forward_args)
        target = torch.argmax(predictions, dim=1)

        # Inputs and baselines are cast to float before being handed to Captum.
        attributions = self.explainer.attribute(
            inputs=inputs.float(),
            additional_forward_args=additional_forward_args,
            target=target,
            baselines=base_line.float(),
        )

        # Collapse dimension 2 (presumably the embedding dimension - confirm)
        # to obtain one score per position.
        attributions = torch.sum(attributions, dim=2)

        if isinstance(self.model, XLNetForSequenceClassification):
            # for xlnet, attributions.shape = [seq_len, batch_dim]
            # but [batch_dim, seq_len] is assumed
            attributions = attributions.T

        return attributions, predictions