// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "decoder/params.h" #include "frontend/wav.h" #include "utils/flags.h" #include "utils/wn_string.h" DEFINE_string(text, "", "kaldi style text input file"); DEFINE_string(wav_scp, "", "kaldi style wav scp"); DEFINE_double(is_penalty, 1.0, "insertion/substitution penalty for align insertion"); DEFINE_double(del_penalty, 1.0, "deletion penalty for align insertion"); DEFINE_string(result, "", "result output file"); DEFINE_string(timestamp, "", "timestamp output file"); namespace wenet { const char* kDeletion = ""; // Is: Insertion and substitution const char* kIsStart = ""; const char* kIsEnd = ""; bool MapToLabel(const std::string& text, std::shared_ptr symbol_table, std::vector* labels) { labels->clear(); // Split label to char sequence std::vector chars; SplitUTF8StringToChars(text, &chars); for (size_t i = 0; i < chars.size(); i++) { // ▁ is special symbol for white space std::string label = chars[i] != " " ? chars[i] : "▁"; int id = symbol_table->Find(label); if (id != -1) { // fst::kNoSymbol // LOG(INFO) << label << " " << id; labels->push_back(id); } } return true; } std::shared_ptr MakeSymbolTableForFst( std::shared_ptr isymbol_table) { LOG(INFO) << isymbol_table; CHECK(isymbol_table != nullptr); auto osymbol_table = std::make_shared(); osymbol_table->AddSymbol("", 0); CHECK_EQ(isymbol_table->Find(""), 0); osymbol_table->AddSymbol("", 1); for (int i = 1; i < isymbol_table->NumSymbols(); i++) { std::string symbol = isymbol_table->Find(i); osymbol_table->AddSymbol(symbol, i + 1); } osymbol_table->AddSymbol(kDeletion, isymbol_table->NumSymbols() + 1); osymbol_table->AddSymbol(kIsStart, isymbol_table->NumSymbols() + 2); osymbol_table->AddSymbol(kIsEnd, isymbol_table->NumSymbols() + 3); return osymbol_table; } void CompileCtcFst(std::shared_ptr symbol_table, fst::StdVectorFst* ofst) { ofst->DeleteStates(); int start = ofst->AddState(); ofst->SetStart(start); CHECK_EQ(symbol_table->Find(""), 0); CHECK_EQ(symbol_table->Find(""), 1); ofst->AddArc(start, fst::StdArc(1, 0, 0.0, start)); // Exclude kDeletion and kInsertion for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) { int s = ofst->AddState(); ofst->AddArc(start, fst::StdArc(i, i, 0.0, s)); ofst->AddArc(s, fst::StdArc(i, 0, 0.0, s)); ofst->AddArc(s, fst::StdArc(0, 0, 0.0, start)); } ofst->SetFinal(start, fst::StdArc::Weight::One()); fst::ArcSort(ofst, fst::StdOLabelCompare()); } void CompileAlignFst(std::vector labels, std::shared_ptr symbol_table, fst::StdVectorFst* ofst) { ofst->DeleteStates(); int deletion = symbol_table->Find(kDeletion); int insertion_start = symbol_table->Find(kIsStart); int insertion_end = symbol_table->Find(kIsEnd); int start = ofst->AddState(); ofst->SetStart(start); // Filler State int filler_start = ofst->AddState(); int filler_end = ofst->AddState(); for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) { ofst->AddArc(filler_start, fst::StdArc(i, i, FLAGS_is_penalty, filler_end)); } ofst->AddArc(filler_end, fst::StdArc(0, 0, 0.0, filler_start)); int prev = start; // Alignment path and optional filler for (size_t i = 0; i < labels.size(); i++) { int cur = ofst->AddState(); // 1. Insertion or Substitution ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start)); ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev)); // 2. Correct ofst->AddArc(prev, fst::StdArc(labels[i], labels[i], 0.0, cur)); // 3. Deletion ofst->AddArc(prev, fst::StdArc(0, deletion, FLAGS_del_penalty, cur)); prev = cur; } // Optional add endding filler ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start)); ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev)); ofst->SetFinal(prev, fst::StdArc::Weight::One()); fst::ArcSort(ofst, fst::StdILabelCompare()); } } // namespace wenet int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); auto decode_config = wenet::InitDecodeOptionsFromFlags(); auto feature_config = wenet::InitFeaturePipelineConfigFromFlags(); auto decode_resource = wenet::InitDecodeResourceFromFlags(); CHECK(decode_resource->unit_table != nullptr); auto wfst_symbol_table = wenet::MakeSymbolTableForFst(decode_resource->unit_table); // wfst_symbol_table->WriteText("fst.txt"); // Reset symbol_table to on-the-fly generated wfst_symbol_table decode_resource->symbol_table = wfst_symbol_table; // Compile ctc FST fst::StdVectorFst ctc_fst; wenet::CompileCtcFst(wfst_symbol_table, &ctc_fst); // ctc_fst.Write("ctc.fst"); std::unordered_map wav_table; std::ifstream wav_is(FLAGS_wav_scp); std::string line; while (std::getline(wav_is, line)) { std::vector strs; wenet::SplitString(line, &strs); CHECK_EQ(strs.size(), 2); wav_table[strs[0]] = strs[1]; } std::ifstream text_is(FLAGS_text); std::ofstream result_os(FLAGS_result, std::ios::out); std::ofstream timestamp_out; if (!FLAGS_timestamp.empty()) { timestamp_out.open(FLAGS_timestamp, std::ios::out); } std::ostream& timestamp_os = FLAGS_timestamp.empty() ? std::cout : timestamp_out; while (std::getline(text_is, line)) { std::vector strs; wenet::SplitString(line, &strs); if (strs.size() < 2) continue; std::string key = strs[0]; LOG(INFO) << "Processing " << key; if (wav_table.find(key) != wav_table.end()) { strs.erase(strs.begin()); std::string text = wenet::JoinString(" ", strs); std::vector labels; wenet::MapToLabel(text, wfst_symbol_table, &labels); // Prepare FST for alignment decoding fst::StdVectorFst align_fst; wenet::CompileAlignFst(labels, wfst_symbol_table, &align_fst); // align_fst.Write("align.fst"); auto decoding_fst = std::make_shared(); fst::Compose(ctc_fst, align_fst, decoding_fst.get()); // decoding_fst->Write("decoding.fst"); // Preapre feature pipeline wenet::WavReader wav_reader; if (!wav_reader.Open(wav_table[key])) { LOG(WARNING) << "Error in reading " << wav_table[key]; continue; } int num_samples = wav_reader.num_samples(); CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate); auto feature_pipeline = std::make_shared(*feature_config); feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples); feature_pipeline->set_input_finished(); decode_resource->fst = decoding_fst; LOG(INFO) << "num frames " << feature_pipeline->num_frames(); wenet::AsrDecoder decoder(feature_pipeline, decode_resource, *decode_config); while (true) { wenet::DecodeState state = decoder.Decode(); if (state == wenet::DecodeState::kEndFeats) { decoder.Rescoring(); break; } } std::string final_result; std::string timestamp_str; if (decoder.DecodedSomething()) { const wenet::DecodeResult& result = decoder.result()[0]; final_result = result.sentence; std::stringstream ss; for (const auto& w : result.word_pieces) { ss << " " << w.word << " " << w.start << " " << w.end; } timestamp_str = ss.str(); } result_os << key << " " << final_result << std::endl; timestamp_os << key << " " << timestamp_str << std::endl; LOG(INFO) << key << " " << final_result; } else { LOG(WARNING) << "No wav file for " << key; } } return 0; }