You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

77 lines
2.6 KiB

// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_endpoint.h"
#include <math.h>
#include <string>
#include <vector>
#include "glog/logging.h"
namespace wenet {
CtcEndpoint::CtcEndpoint(const CtcEndpointConfig& config) : config_(config) {
Reset();
}
void CtcEndpoint::Reset() {
num_frames_decoded_ = 0;
num_frames_trailing_blank_ = 0;
}
static bool RuleActivated(const CtcEndpointRule& rule,
const std::string& rule_name, bool decoded_sth,
int trailing_silence, int utterance_length) {
bool ans = (decoded_sth || !rule.must_decoded_sth) &&
trailing_silence >= rule.min_trailing_silence &&
utterance_length >= rule.min_utterance_length;
if (ans) {
VLOG(2) << "Endpointing rule " << rule_name
<< " activated: " << (decoded_sth ? "true" : "false") << ','
<< trailing_silence << ',' << utterance_length;
}
return ans;
}
bool CtcEndpoint::IsEndpoint(
const std::vector<std::vector<float>>& ctc_log_probs,
bool decoded_something) {
for (int t = 0; t < ctc_log_probs.size(); ++t) {
const auto& logp_t = ctc_log_probs[t];
float blank_prob = expf(logp_t[config_.blank]);
num_frames_decoded_++;
if (blank_prob > config_.blank_threshold) {
num_frames_trailing_blank_++;
} else {
num_frames_trailing_blank_ = 0;
}
}
CHECK_GE(num_frames_decoded_, num_frames_trailing_blank_);
CHECK_GT(frame_shift_in_ms_, 0);
int utterance_length = num_frames_decoded_ * frame_shift_in_ms_;
int trailing_silence = num_frames_trailing_blank_ * frame_shift_in_ms_;
if (RuleActivated(config_.rule1, "rule1", decoded_something, trailing_silence,
utterance_length))
return true;
if (RuleActivated(config_.rule2, "rule2", decoded_something, trailing_silence,
utterance_length))
return true;
if (RuleActivated(config_.rule3, "rule3", decoded_something, trailing_silence,
utterance_length))
return true;
return false;
}
} // namespace wenet