You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.6 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "decoder/ctc_endpoint.h"
  15. #include <math.h>
  16. #include <string>
  17. #include <vector>
  18. #include "glog/logging.h"
  19. namespace wenet {
  20. CtcEndpoint::CtcEndpoint(const CtcEndpointConfig& config) : config_(config) {
  21. Reset();
  22. }
  23. void CtcEndpoint::Reset() {
  24. num_frames_decoded_ = 0;
  25. num_frames_trailing_blank_ = 0;
  26. }
  27. static bool RuleActivated(const CtcEndpointRule& rule,
  28. const std::string& rule_name, bool decoded_sth,
  29. int trailing_silence, int utterance_length) {
  30. bool ans = (decoded_sth || !rule.must_decoded_sth) &&
  31. trailing_silence >= rule.min_trailing_silence &&
  32. utterance_length >= rule.min_utterance_length;
  33. if (ans) {
  34. VLOG(2) << "Endpointing rule " << rule_name
  35. << " activated: " << (decoded_sth ? "true" : "false") << ','
  36. << trailing_silence << ',' << utterance_length;
  37. }
  38. return ans;
  39. }
  40. bool CtcEndpoint::IsEndpoint(
  41. const std::vector<std::vector<float>>& ctc_log_probs,
  42. bool decoded_something) {
  43. for (int t = 0; t < ctc_log_probs.size(); ++t) {
  44. const auto& logp_t = ctc_log_probs[t];
  45. float blank_prob = expf(logp_t[config_.blank]);
  46. num_frames_decoded_++;
  47. if (blank_prob > config_.blank_threshold) {
  48. num_frames_trailing_blank_++;
  49. } else {
  50. num_frames_trailing_blank_ = 0;
  51. }
  52. }
  53. CHECK_GE(num_frames_decoded_, num_frames_trailing_blank_);
  54. CHECK_GT(frame_shift_in_ms_, 0);
  55. int utterance_length = num_frames_decoded_ * frame_shift_in_ms_;
  56. int trailing_silence = num_frames_trailing_blank_ * frame_shift_in_ms_;
  57. if (RuleActivated(config_.rule1, "rule1", decoded_something, trailing_silence,
  58. utterance_length))
  59. return true;
  60. if (RuleActivated(config_.rule2, "rule2", decoded_something, trailing_silence,
  61. utterance_length))
  62. return true;
  63. if (RuleActivated(config_.rule3, "rule3", decoded_something, trailing_silence,
  64. utterance_length))
  65. return true;
  66. return false;
  67. }
  68. } // namespace wenet