From 6fe84a4cdff60f288abe2b59a0fa97048315e24a Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 14 Aug 2024 22:57:18 -0400 Subject: [PATCH 1/3] refactor: Add filter-replace-utils for serializing and deserializing filter words replacements --- CMakeLists.txt | 6 +- src/tests/README.md | 2 +- src/tests/evaluate_output.py | 120 ++++++++++++++++++-------- src/tests/localvocal-offline-test.cpp | 7 ++ src/ui/filter-replace-dialog.cpp | 29 ------- src/ui/filter-replace-dialog.h | 5 -- src/ui/filter-replace-utils.cpp | 32 +++++++ src/ui/filter-replace-utils.h | 12 +++ 8 files changed, 139 insertions(+), 74 deletions(-) create mode 100644 src/ui/filter-replace-utils.cpp create mode 100644 src/ui/filter-replace-utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7eff873..880c4e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,7 +117,8 @@ target_sources( src/translation/language_codes.cpp src/translation/translation.cpp src/translation/translation-utils.cpp - src/ui/filter-replace-dialog.cpp) + src/ui/filter-replace-dialog.cpp + src/ui/filter-replace-utils.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) @@ -138,7 +139,8 @@ if(ENABLE_TESTS) src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp src/translation/language_codes.cpp - src/translation/translation.cpp) + src/translation/translation.cpp + src/ui/filter-replace-utils.cpp) find_libav(${CMAKE_PROJECT_NAME}-tests) diff --git a/src/tests/README.md b/src/tests/README.md index e47f730..ed9e52f 100644 --- a/src/tests/README.md +++ b/src/tests/README.md @@ -112,7 +112,7 @@ The JSON config file can look e.g. 
like "silero_vad_model_file": ".../obs-localvocal/data/models/silero-vad/silero_vad.onnx", "ct2_model_folder": ".../obs-localvocal/models/m2m-100-418M", "fix_utf8": true, - "suppress_sentences": "다음 영상에서 만나요!\nMBC 뉴스 김지경입니다\nMBC 뉴스 김성현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요는 저에게 아주 큰\n다음 영상에서 만나요\n끝까지 시청해주셔서 감사합니다\n구독과 좋아요 부탁드립니다!\nMBC 뉴스 이준범입니다\nMBC 뉴스 문재인입니다\nMBC 뉴스 김지연입니다\nMBC 뉴스 안영백입니다.\nMBC 뉴스 이덕영입니다\nMBC 뉴스 김상현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요 부탁드", + "filter_words_replace": "[{\"key\": \"다음 영상에서 만나요!\", \"value\":\"\"}]", "overlap_ms": 150, "log_level": "debug", "whisper_sampling_method": 0 diff --git a/src/tests/evaluate_output.py b/src/tests/evaluate_output.py index 873ce25..e4ae59e 100644 --- a/src/tests/evaluate_output.py +++ b/src/tests/evaluate_output.py @@ -1,33 +1,80 @@ import Levenshtein import argparse -from diff_match_patch import diff_match_patch +import unicodedata +import re +import difflib -def visualize_differences(ref_text, hyp_text): - dmp = diff_match_patch() - diffs = dmp.diff_main(hyp_text, ref_text, checklines=True) - html = dmp.diff_prettyHtml(diffs) - return html +def remove_accents(text): + return ''.join(c for c in unicodedata.normalize('NFD', text) + if unicodedata.category(c) != 'Mn') -def calculate_wer(ref_text, hyp_text): - ref_words = ref_text.split() - hyp_words = hyp_text.split() +def clean_text(text): + # Remove punctuation and special characters + text = re.sub(r'[^\w\s]', '', text) + # Remove extra whitespace + text = re.sub(r'\s+', ' ', text).strip() + return text - distance = Levenshtein.distance(ref_words, hyp_words) - wer = distance / len(ref_words) +def normalize_spanish_gender_postfixes(text): + # Normalize + text = re.sub(r'\b(\w+?)(a)\b', r'\1e', text) + return text + +def tokenize(text, should_remove_accents=False, remove_punctuation=False): + # Convert to lowercase, remove accents, clean text, and split + if should_remove_accents: + text = remove_accents(text) + text = normalize_spanish_gender_postfixes(text) + if 
remove_punctuation: + text = clean_text(text) + tokens = text.lower().split() + return tokens + +def calculate_wer(ref_text_tokens, hyp_text_tokens): + distance = Levenshtein.distance(ref_text_tokens, hyp_text_tokens, weights=(1, 1, 1)) + wer = distance / max(len(ref_text_tokens), len(hyp_text_tokens)) return wer -def calculate_cer(ref_text, hyp_text): +def calculate_cer(ref_text_tokens, hyp_text_tokens): + # Join tokens into a single string + ref_text = ' '.join(ref_text_tokens) + hyp_text = ' '.join(hyp_text_tokens) distance = Levenshtein.distance(ref_text, hyp_text) cer = distance / len(ref_text) return cer -def compare_tokens(ref_tokens, hyp_tokens): - comparisons = [] - for ref_token, hyp_token in zip(ref_tokens, hyp_tokens): - distance = Levenshtein.distance(ref_token, hyp_token) - comparison = {'ref_token': ref_token, 'hyp_token': hyp_token, 'error_rate': distance / len(ref_token)} - comparisons.append(comparison) - return comparisons +def print_alignment(ref_words, hyp_words): + d = difflib.Differ() + diff = list(d.compare(ref_words, hyp_words)) + + print("\nToken-by-token alignment:") + print("Reference | Hypothesis") + print("-" * 30) + + ref_token = hyp_token = "" + for token in diff: + if token.startswith(" "): # Common token + if ref_token or hyp_token: + print(f"{ref_token:<10} | {hyp_token:<10}") + ref_token = hyp_token = "" + print(f"{token[2:]:<10} | {token[2:]:<10}") + elif token.startswith("- "): # Token in reference, not in hypothesis + ref_token = token[2:] + elif token.startswith("+ "): # Token in hypothesis, not in reference + hyp_token = token[2:] + if ref_token: + print(f"{ref_token:<10} | {hyp_token:<10} (Substitution)") + ref_token = hyp_token = "" + else: + print(f"{'':10} | {hyp_token:<10} (Insertion)") + hyp_token = "" + + # Print any remaining tokens + if ref_token: + print(f"{ref_token:<10} | {'':10} (Deletion)") + elif hyp_token: + print(f"{'':10} | {hyp_token:<10} (Insertion)") + def read_text_from_file(file_path,
join_sentences=True): with open(file_path, 'r', encoding='utf-8', errors='ignore') as file: @@ -41,28 +88,27 @@ def read_text_from_file(file_path, join_sentences=True): parser = argparse.ArgumentParser(description='Evaluate output') parser.add_argument('ref_file_path', type=str, help='Path to the reference file') parser.add_argument('hyp_file_path', type=str, help='Path to the hypothesis file') +parser.add_argument('--remove_accents', action='store_true', help='Remove accents from text') +parser.add_argument('--remove_punctuation', action='store_true', help='Remove punctuation from text') +parser.add_argument('--print_alignment', action='store_true', help='Print the alignment to the console') +parser.add_argument('--write_tokens', action='store_true', help='Write the tokens to a file') args = parser.parse_args() -ref_text = read_text_from_file(args.ref_file_path) -hyp_text = read_text_from_file(args.hyp_file_path) -wer = calculate_wer(ref_text, hyp_text) -cer = calculate_cer(ref_text, hyp_text) -print("Word Error Rate (WER):", wer) -print("Character Error Rate (CER):", cer) - -ref_text = '\n'.join(read_text_from_file(args.ref_file_path, join_sentences=False)) -hyp_text = '\n'.join(read_text_from_file(args.hyp_file_path, join_sentences=False)) -html_diff = visualize_differences(ref_text, hyp_text) -with open("diff_visualization.html", "w", encoding="utf-8") as file: - file.write(html_diff) +ref_text = read_text_from_file(args.ref_file_path, join_sentences=True) +hyp_text = read_text_from_file(args.hyp_file_path, join_sentences=True) +ref_tokens = tokenize(ref_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation) +hyp_tokens = tokenize(hyp_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation) -from Bio.Align import PairwiseAligner +if args.print_alignment: + print_alignment(ref_tokens, hyp_tokens) -aligner = PairwiseAligner() +if args.write_tokens: + with open("ref_tokens.txt", "w", 
encoding="utf-8") as file: + file.write('\n'.join(ref_tokens)) + with open("hyp_tokens.txt", "w", encoding="utf-8") as file: + file.write('\n'.join(hyp_tokens)) -alignments = aligner.align(ref_text, hyp_text) +wer = calculate_wer(ref_tokens, hyp_tokens) -# write the first alignment to a file -with open("alignment.txt", "w", encoding="utf-8") as file: - file.write(alignments[0].format()) +print(f"\"{args.ref_file_path}\" WER: \"{wer:.2}\"") diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index 8fec08b..48876c4 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -19,6 +19,7 @@ #include "whisper-utils/whisper-utils.h" #include "audio-file-utils.h" #include "translation/language_codes.h" +#include "ui/filter-replace-utils.h" #include #include @@ -429,6 +430,12 @@ int wmain(int argc, wchar_t *argv[]) config["no_context"] ? "true" : "false"); gf->whisper_params.no_context = config["no_context"]; } + if (config.contains("filter_words_replace")) { + obs_log(LOG_INFO, "Setting filter_words_replace to %s", + config["filter_words_replace"].get<std::string>().c_str()); + gf->filter_words_replace = deserialize_filter_words_replace( + config["filter_words_replace"]); + } // set log level if (logLevelStr == "debug") { gf->log_level = LOG_DEBUG; } diff --git a/src/ui/filter-replace-dialog.cpp b/src/ui/filter-replace-dialog.cpp index 1464082..b491a31 100644 --- a/src/ui/filter-replace-dialog.cpp +++ b/src/ui/filter-replace-dialog.cpp @@ -73,32 +73,3 @@ void FilterReplaceDialog::editFilter(QTableWidgetItem *item) // use the row number to update the filter_words_replace map ctx->filter_words_replace[item->row()] = std::make_tuple(key, value); } - -std::string serialize_filter_words_replace( - const std::vector> &filter_words_replace) -{ - if (filter_words_replace.empty()) { - return "[]"; - } - // use JSON to serialize the filter_words_replace map - nlohmann::json j; - for (const auto &entry : filter_words_replace) { -
j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}}); - } - return j.dump(); -} - -std::vector> -deserialize_filter_words_replace(const std::string &filter_words_replace_str) -{ - if (filter_words_replace_str.empty()) { - return {}; - } - // use JSON to deserialize the filter_words_replace map - std::vector> filter_words_replace; - nlohmann::json j = nlohmann::json::parse(filter_words_replace_str); - for (const auto &entry : j) { - filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"])); - } - return filter_words_replace; -} diff --git a/src/ui/filter-replace-dialog.h b/src/ui/filter-replace-dialog.h index c605531..d392a80 100644 --- a/src/ui/filter-replace-dialog.h +++ b/src/ui/filter-replace-dialog.h @@ -27,9 +27,4 @@ private slots: void editFilter(QTableWidgetItem *item); }; -std::string serialize_filter_words_replace( - const std::vector> &filter_words_replace); -std::vector> -deserialize_filter_words_replace(const std::string &filter_words_replace_str); - #endif // FILTERREPLACEDIALOG_H diff --git a/src/ui/filter-replace-utils.cpp b/src/ui/filter-replace-utils.cpp new file mode 100644 index 0000000..14af016 --- /dev/null +++ b/src/ui/filter-replace-utils.cpp @@ -0,0 +1,32 @@ +#include "filter-replace-utils.h" + +#include + +std::string serialize_filter_words_replace( + const std::vector> &filter_words_replace) +{ + if (filter_words_replace.empty()) { + return "[]"; + } + // use JSON to serialize the filter_words_replace map + nlohmann::json j; + for (const auto &entry : filter_words_replace) { + j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}}); + } + return j.dump(); +} + +std::vector> +deserialize_filter_words_replace(const std::string &filter_words_replace_str) +{ + if (filter_words_replace_str.empty()) { + return {}; + } + // use JSON to deserialize the filter_words_replace map + std::vector> filter_words_replace; + nlohmann::json j = nlohmann::json::parse(filter_words_replace_str); + for 
(const auto &entry : j) { + filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"])); + } + return filter_words_replace; +} diff --git a/src/ui/filter-replace-utils.h b/src/ui/filter-replace-utils.h new file mode 100644 index 0000000..9d87376 --- /dev/null +++ b/src/ui/filter-replace-utils.h @@ -0,0 +1,12 @@ +#ifndef FILTER_REPLACE_UTILS_H +#define FILTER_REPLACE_UTILS_H + +#include +#include + +std::string serialize_filter_words_replace( + const std::vector> &filter_words_replace); +std::vector> +deserialize_filter_words_replace(const std::string &filter_words_replace_str); + +#endif /* FILTER_REPLACE_UTILS_H */ \ No newline at end of file From 747f3f069043ab5ef1ec52ec2323c1278dc52ea5 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 21 Aug 2024 08:12:34 -0400 Subject: [PATCH 2/3] refactor: Add filter-replace-utils for serializing and deserializing filter words replacements --- src/transcription-filter-properties.cpp | 1 + src/ui/filter-replace-utils.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 523bbf8..b15e93b 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -10,6 +10,7 @@ #include "model-utils/model-downloader-types.h" #include "translation/language_codes.h" #include "ui/filter-replace-dialog.h" +#include "ui/filter-replace-utils.h" #include #include diff --git a/src/ui/filter-replace-utils.h b/src/ui/filter-replace-utils.h index 9d87376..485b968 100644 --- a/src/ui/filter-replace-utils.h +++ b/src/ui/filter-replace-utils.h @@ -9,4 +9,4 @@ std::string serialize_filter_words_replace( std::vector> deserialize_filter_words_replace(const std::string &filter_words_replace_str); -#endif /* FILTER_REPLACE_UTILS_H */ \ No newline at end of file +#endif /* FILTER_REPLACE_UTILS_H */ From 13a8cc86152f660f97faf9d3a79f051740dd9985 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 
21 Aug 2024 09:31:30 -0400 Subject: [PATCH 3/3] refactor: Add filter-replace-utils for serializing and deserializing filter words replacements --- src/transcription-filter.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 3683c18..d6b2d7e 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -30,6 +30,7 @@ #include "translation/translation.h" #include "translation/translation-includes.h" #include "ui/filter-replace-dialog.h" +#include "ui/filter-replace-utils.h" void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) {