// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "grpc/grpc_server.h"

namespace wenet {

using grpc::ServerReaderWriter;
using wenet::Request;
using wenet::Response;

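// Handles a single bidirectional Recognize stream; one instance is created
// per connection and keeps the stream, the protobuf messages, and the
// decoding configuration/resources together.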
GrpcConnectionHandler::GrpcConnectionHandler(
    ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<Request> request, std::shared_ptr<Response> response,
    std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<DecodeResource> decode_resource)
    : stream_(std::move(stream)),
      request_(std::move(request)),
      response_(std::move(response)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      decode_resource_(std::move(decode_resource)) {}

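// First message of a stream: acknowledge the client, build the feature
// pipeline and decoder, and start the background decoding thread.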
void GrpcConnectionHandler::OnSpeechStart() {
  LOG(INFO) << "Received speech start signal, start reading speech";
  got_start_tag_ = true;
  response_->set_status(Response::ok);
  response_->set_type(Response::server_ready);
  stream_->Write(*response_);
  feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
  decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
                                          *decode_config_);
  // Start decoder thread
  decode_thread_ = std::make_shared<std::thread>(
      &GrpcConnectionHandler::DecodeThreadFunc, this);
}

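// Client has stopped sending audio: mark the feature pipeline as finished so
// the decoding thread can drain the remaining frames.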
void GrpcConnectionHandler::OnSpeechEnd() {
  LOG(INFO) << "Received speech end signal";
  CHECK(feature_pipeline_ != nullptr);
  feature_pipeline_->set_input_finished();
  got_end_tag_ = true;
}

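// Send the current (non-final) n-best hypotheses to the client.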
void GrpcConnectionHandler::OnPartialResult() {
  LOG(INFO) << "Partial result";
  response_->set_status(Response::ok);
  response_->set_type(Response::partial_result);
  stream_->Write(*response_);
}

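// Send the rescored final n-best hypotheses to the client.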
void GrpcConnectionHandler::OnFinalResult() {
  LOG(INFO) << "Final result";
  response_->set_status(Response::ok);
  response_->set_type(Response::final_result);
  stream_->Write(*response_);
}

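// Tell the client that the server has finished this recognition.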
void GrpcConnectionHandler::OnFinish() {
  // Send finish tag
  response_->set_status(Response::ok);
  response_->set_type(Response::speech_end);
  stream_->Write(*response_);
}

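// Feed one chunk of 16-bit PCM audio from the request into the feature
// pipeline.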
void GrpcConnectionHandler::OnSpeechData() {
  // Read binary PCM data
  const int16_t* pcm_data =
      reinterpret_cast<const int16_t*>(request_->audio_data().c_str());
  int num_samples = request_->audio_data().length() / sizeof(int16_t);
  VLOG(2) << "Received " << num_samples << " samples";
  CHECK(feature_pipeline_ != nullptr);
  CHECK(decoder_ != nullptr);
  feature_pipeline_->AcceptWaveform(pcm_data, num_samples);
}

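// Copy up to nbest_ decoding hypotheses into the response; word-level
// timestamps are attached only for final results.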
void GrpcConnectionHandler::SerializeResult(bool finish) {
  for (const DecodeResult& path : decoder_->result()) {
    Response_OneBest* one_best_ = response_->add_nbest();
    one_best_->set_sentence(path.sentence);
    if (finish) {
      for (const WordPiece& word_piece : path.word_pieces) {
        Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
        one_piece_->set_word(word_piece.word);
        one_piece_->set_start(word_piece.start);
        one_piece_->set_end(word_piece.end);
      }
    }
    if (response_->nbest_size() == nbest_) {
      break;
    }
  }
  return;
}

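// Decoding loop running on its own thread: emit partial results while audio
// is streaming, final results at endpoints, and stop once all features have
// been consumed.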
void GrpcConnectionHandler::DecodeThreadFunc() {
  while (true) {
    DecodeState state = decoder_->Decode();
    response_->clear_status();
    response_->clear_type();
    response_->clear_nbest();
    if (state == DecodeState::kEndFeats) {
      decoder_->Rescoring();
      SerializeResult(true);
      OnFinalResult();
      OnFinish();
      stop_recognition_ = true;
      break;
    } else if (state == DecodeState::kEndpoint) {
      decoder_->Rescoring();
      SerializeResult(true);
      OnFinalResult();
      // In continuous decoding mode, reset the decoder and keep recognizing;
      // otherwise finish the response and stop the recognition.
      if (continuous_decoding_) {
        decoder_->ResetContinuousDecoding();
      } else {
        OnFinish();
        stop_recognition_ = true;
        break;
      }
    } else {
      if (decoder_->DecodedSomething()) {
        SerializeResult(false);
        OnPartialResult();
      }
    }
  }
}

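// Connection loop: the first message carries the decoding configuration and
// triggers OnSpeechStart(); every following message is treated as audio.
// When the client closes the stream, wait for the decoding thread to finish.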
void GrpcConnectionHandler::operator()() {
  try {
    while (stream_->Read(request_.get())) {
      if (!got_start_tag_) {
        nbest_ = request_->decode_config().nbest_config();
        continuous_decoding_ =
            request_->decode_config().continuous_decoding_config();
        OnSpeechStart();
      } else {
        OnSpeechData();
      }
    }
    OnSpeechEnd();
    LOG(INFO) << "Read all pcm data, wait for decoding thread";
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}

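// gRPC entry point: each Recognize call gets its own connection handler,
// which is run on a dedicated thread and joined before the RPC returns.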
Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request";
  auto request = std::make_shared<Request>();
  auto response = std::make_shared<Response>();
  GrpcConnectionHandler handler(stream, request, response, feature_config_,
                                decode_config_, decode_resource_);
  std::thread t(std::move(handler));
  t.join();
  return Status::OK;
}

}  // namespace wenet