You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

103 lines
3.5 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef DECODER_CTC_WFST_BEAM_SEARCH_H_
  15. #define DECODER_CTC_WFST_BEAM_SEARCH_H_
  16. #include <memory>
  17. #include <vector>
  18. #include "decoder/context_graph.h"
  19. #include "decoder/search_interface.h"
  20. #include "kaldi/decoder/lattice-faster-online-decoder.h"
  21. #include "../utils/wn_utils.h"
  22. namespace wenet {
  23. class DecodableTensorScaled : public kaldi::DecodableInterface {
  24. public:
  25. explicit DecodableTensorScaled(float scale = 1.0) : scale_(scale) { Reset(); }
  26. void Reset();
  27. int32 NumFramesReady() const override { return num_frames_ready_; }
  28. bool IsLastFrame(int32 frame) const override;
  29. float LogLikelihood(int32 frame, int32 index) override;
  30. int32 NumIndices() const override;
  31. void AcceptLoglikes(const std::vector<float>& logp);
  32. void SetFinish() { done_ = true; }
  33. private:
  34. int num_frames_ready_ = 0;
  35. float scale_ = 1.0;
  36. bool done_ = false;
  37. std::vector<float> logp_;
  38. };
  39. // LatticeFasterDecoderConfig has the following key members
  40. // beam: decoding beam
  41. // max_active: Decoder max active states
  42. // lattice_beam: Lattice generation beam
  43. struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig {
  44. float acoustic_scale = 1.0;
  45. float nbest = 10;
  46. // When blank score is greater than this thresh, skip the frame in viterbi
  47. // search
  48. float blank_skip_thresh = 0.98;
  49. float blank_scale = 1.0;
  50. int blank = 0;
  51. };
  52. class CtcWfstBeamSearch : public SearchInterface {
  53. public:
  54. explicit CtcWfstBeamSearch(
  55. const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
  56. const std::shared_ptr<ContextGraph>& context_graph);
  57. void Search(const std::vector<std::vector<float>>& logp) override;
  58. void Reset() override;
  59. void FinalizeSearch() override;
  60. SearchType Type() const override { return SearchType::kWfstBeamSearch; }
  61. // For CTC prefix beam search, both inputs and outputs are hypotheses_
  62. const std::vector<std::vector<int>>& Inputs() const override {
  63. return inputs_;
  64. }
  65. const std::vector<std::vector<int>>& Outputs() const override {
  66. return outputs_;
  67. }
  68. const std::vector<float>& Likelihood() const override { return likelihood_; }
  69. const std::vector<std::vector<int>>& Times() const override { return times_; }
  70. private:
  71. // Sub one and remove <blank>
  72. void ConvertToInputs(const std::vector<int>& alignment,
  73. std::vector<int>* input,
  74. std::vector<int>* time = nullptr);
  75. int num_frames_ = 0;
  76. std::vector<int> decoded_frames_mapping_;
  77. int last_best_ = 0; // last none blank best id
  78. std::vector<float> last_frame_prob_;
  79. bool is_last_frame_blank_ = false;
  80. std::vector<std::vector<int>> inputs_, outputs_;
  81. std::vector<float> likelihood_;
  82. std::vector<std::vector<int>> times_;
  83. DecodableTensorScaled decodable_;
  84. kaldi::LatticeFasterOnlineDecoder decoder_;
  85. std::shared_ptr<ContextGraph> context_graph_;
  86. const CtcWfstBeamSearchOptions& opts_;
  87. };
  88. } // namespace wenet
  89. #endif // DECODER_CTC_WFST_BEAM_SEARCH_H_