You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

78 lines
2.6 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef DECODER_CTC_ENDPOINT_H_
  15. #define DECODER_CTC_ENDPOINT_H_
  16. #include <vector>
  17. namespace wenet {
  18. struct CtcEndpointRule {
  19. bool must_decoded_sth;
  20. int min_trailing_silence;
  21. int min_utterance_length;
  22. CtcEndpointRule(bool must_decoded_sth = true, int min_trailing_silence = 1000,
  23. int min_utterance_length = 0)
  24. : must_decoded_sth(must_decoded_sth),
  25. min_trailing_silence(min_trailing_silence),
  26. min_utterance_length(min_utterance_length) {}
  27. };
  28. struct CtcEndpointConfig {
  29. /// We consider blank as silence for purposes of endpointing.
  30. int blank = 0; // blank id
  31. float blank_threshold = 0.8; // blank threshold to be silence
  32. /// We support three rules. We terminate decoding if ANY of these rules
  33. /// evaluates to "true". If you want to add more rules, do it by changing this
  34. /// code. If you want to disable a rule, you can set the silence-timeout for
  35. /// that rule to a very large number.
  36. /// rule1 times out after 5000 ms of silence, even if we decoded nothing.
  37. CtcEndpointRule rule1;
  38. /// rule2 times out after 1000 ms of silence after decoding something.
  39. CtcEndpointRule rule2;
  40. /// rule3 times out after the utterance is 20000 ms long, regardless of
  41. /// anything else.
  42. CtcEndpointRule rule3;
  43. CtcEndpointConfig()
  44. : rule1(false, 5000, 0), rule2(true, 1000, 0), rule3(false, 0, 20000) {}
  45. };
  46. class CtcEndpoint {
  47. public:
  48. explicit CtcEndpoint(const CtcEndpointConfig& config);
  49. void Reset();
  50. /// This function returns true if this set of endpointing rules thinks we
  51. /// should terminate decoding.
  52. bool IsEndpoint(const std::vector<std::vector<float>>& ctc_log_probs,
  53. bool decoded_something);
  54. void frame_shift_in_ms(int frame_shift_in_ms) {
  55. frame_shift_in_ms_ = frame_shift_in_ms;
  56. }
  57. private:
  58. CtcEndpointConfig config_;
  59. int frame_shift_in_ms_ = -1;
  60. int num_frames_decoded_ = 0;
  61. int num_frames_trailing_blank_ = 0;
  62. };
  63. } // namespace wenet
  64. #endif // DECODER_CTC_ENDPOINT_H_