You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

168 lines
5.4 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. // 2022 Binbin Zhang (binbzha@qq.com)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. #ifndef DECODER_ASR_DECODER_H_
  16. #define DECODER_ASR_DECODER_H_
  17. #include <memory>
  18. #include <string>
  19. #include <unordered_set>
  20. #include <utility>
  21. #include <vector>
  22. #include "fst/fstlib.h"
  23. #include "fst/symbol-table.h"
  24. #include "decoder/asr_model.h"
  25. #include "decoder/context_graph.h"
  26. #include "decoder/ctc_endpoint.h"
  27. #include "decoder/ctc_prefix_beam_search.h"
  28. #include "decoder/ctc_wfst_beam_search.h"
  29. #include "decoder/search_interface.h"
  30. #include "frontend/feature_pipeline.h"
  31. #include "../post_processor/processor/post_processor.h"
  32. #include "utils/utils.h"
  33. namespace wenet {
  34. struct DecodeOptions {
  35. // chunk_size is the frame number of one chunk after subsampling.
  36. // e.g. if subsample rate is 4 and chunk_size = 16, the frames in
  37. // one chunk are 64 = 16*4
  38. int chunk_size = 16;
  39. int num_left_chunks = -1;
  40. // final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
  41. // rescoring_score = left_to_right_score * (1 - reverse_weight) +
  42. // right_to_left_score * reverse_weight
  43. // Please note the concept of ctc_scores in the following two search
  44. // methods are different.
  45. // For CtcPrefixBeamSearch, it's a sum(prefix) score + context score
  46. // For CtcWfstBeamSearch, it's a max(viterbi) path score + context score
  47. // So we should carefully set ctc_weight according to the search methods.
  48. float ctc_weight = 0.5;
  49. float rescoring_weight = 1.0;
  50. float reverse_weight = 0.0;
  51. CtcEndpointConfig ctc_endpoint_config;
  52. CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
  53. CtcWfstBeamSearchOptions ctc_wfst_search_opts;
  54. };
  // A decoded token paired with its start and end positions.
  // NOTE(review): the unit of start/end is not stated in this header
  // (presumably milliseconds) - confirm against the code that fills it in.
  struct WordPiece {
    std::string word;
    int start{-1};
    int end{-1};

    // Takes `word` by value and moves it into the member, so callers may
    // pass either an lvalue (copied once) or a temporary (moved).
    WordPiece(std::string word, int start, int end)
        : word{std::move(word)}, start{start}, end{end} {}
  };
  62. struct DecodeResult {
  63. float score = -kFloatMax;
  64. std::string sentence;
  65. std::unordered_set<std::string> contexts;
  66. std::vector<WordPiece> word_pieces;
  67. static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
  68. return a.score > b.score;
  69. }
  70. };
  // Status values describing how a decoding step ended.
  enum DecodeState {
    kEndBatch = 0,   // The current decoding batch finished (normal case)
    kEndpoint = 1,   // An endpoint was detected
    kEndFeats = 2,   // All features have been decoded
    kWaitFeats = 3,  // Not enough features for one chunk inference yet; wait
  };
  77. // DecodeResource is thread safe, which can be shared for multiple
  78. // decoding threads
  79. struct DecodeResource {
  80. std::shared_ptr<AsrModel> model = nullptr;
  81. std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
  82. std::shared_ptr<fst::VectorFst<fst::StdArc>> fst = nullptr;
  83. std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
  84. std::shared_ptr<ContextGraph> context_graph = nullptr;
  85. std::shared_ptr<PostProcessor> post_processor = nullptr;
  86. };
  87. // Torch ASR decoder
  88. class AsrDecoder {
  89. public:
  90. AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
  91. std::shared_ptr<DecodeResource> resource,
  92. const DecodeOptions& opts);
  93. // @param block: if true, block when feature is not enough for one chunk
  94. // inference. Otherwise, return kWaitFeats.
  95. DecodeState Decode(bool block = true);
  96. void Rescoring();
  97. void Reset();
  98. void ResetContinuousDecoding();
  99. bool DecodedSomething() const {
  100. return !result_.empty() && !result_[0].sentence.empty();
  101. }
  102. // This method is used for time benchmark
  103. int num_frames_in_current_chunk() const {
  104. return num_frames_in_current_chunk_;
  105. }
  106. int frame_shift_in_ms() const {
  107. return model_->subsampling_rate() *
  108. feature_pipeline_->config().frame_shift * 1000 /
  109. feature_pipeline_->config().sample_rate;
  110. }
  111. int feature_frame_shift_in_ms() const {
  112. return feature_pipeline_->config().frame_shift * 1000 /
  113. feature_pipeline_->config().sample_rate;
  114. }
  115. const std::vector<DecodeResult>& result() const { return result_; }
  116. private:
  117. DecodeState AdvanceDecoding(bool block = true);
  118. void AttentionRescoring();
  119. void UpdateResult(bool finish = false);
  120. std::shared_ptr<FeaturePipeline> feature_pipeline_;
  121. std::shared_ptr<AsrModel> model_;
  122. std::shared_ptr<PostProcessor> post_processor_;
  123. std::shared_ptr<ContextGraph> context_graph_;
  124. std::shared_ptr<fst::VectorFst<fst::StdArc>> fst_ = nullptr;
  125. // output symbol table
  126. std::shared_ptr<fst::SymbolTable> symbol_table_;
  127. // e2e unit symbol table
  128. std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
  129. const DecodeOptions& opts_;
  130. // cache feature
  131. bool start_ = false;
  132. // For continuous decoding
  133. int num_frames_ = 0;
  134. int global_frame_offset_ = 0;
  135. const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence
  136. std::unique_ptr<SearchInterface> searcher_;
  137. std::unique_ptr<CtcEndpoint> ctc_endpointer_;
  138. int num_frames_in_current_chunk_ = 0;
  139. std::vector<DecodeResult> result_;
  140. public:
  141. WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder);
  142. };
  143. } // namespace wenet
  144. #endif // DECODER_ASR_DECODER_H_