You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

176 lines
5.6 KiB

// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "grpc/grpc_server.h"
namespace wenet {
using grpc::ServerReaderWriter;
using wenet::Request;
using wenet::Response;
// Builds a handler for one streaming recognition call.
//
// stream:          gRPC stream the handler reads requests from and writes
//                  responses to.
// request:         message object reused as the read buffer for the stream.
// response:        message object reused for every response that is written.
// feature_config:  configuration used to construct the feature pipeline.
// decode_config:   options passed to the decoder when it is created.
// decode_resource: resources passed to the decoder when it is created.
GrpcConnectionHandler::GrpcConnectionHandler(
    ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<Request> request, std::shared_ptr<Response> response,
    std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<DecodeResource> decode_resource)
    // The stream is a plain pointer, so it is simply copied; the shared
    // pointers are moved into the members to avoid extra refcount traffic.
    : stream_(stream),
      request_(std::move(request)),
      response_(std::move(response)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      decode_resource_(std::move(decode_resource)) {
}
void GrpcConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Received speech start signal, start reading speech";
got_start_tag_ = true;
response_->set_status(Response::ok);
response_->set_type(Response::server_ready);
stream_->Write(*response_);
feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
*decode_config_);
// Start decoder thread
decode_thread_ = std::make_shared<std::thread>(
&GrpcConnectionHandler::DecodeThreadFunc, this);
}
// Called once no more audio will arrive: tells the feature pipeline that its
// input is complete and records that the end of speech was seen.
// Requires the feature pipeline to exist (enforced with CHECK).
void GrpcConnectionHandler::OnSpeechEnd() {
  LOG(INFO) << "Received speech end signal";

  CHECK(feature_pipeline_ != nullptr);
  feature_pipeline_->set_input_finished();

  got_end_tag_ = true;
}
void GrpcConnectionHandler::OnPartialResult() {
LOG(INFO) << "Partial result";
response_->set_status(Response::ok);
response_->set_type(Response::partial_result);
stream_->Write(*response_);
}
void GrpcConnectionHandler::OnFinalResult() {
LOG(INFO) << "Final result";
response_->set_status(Response::ok);
response_->set_type(Response::final_result);
stream_->Write(*response_);
}
void GrpcConnectionHandler::OnFinish() {
// Send finish tag
response_->set_status(Response::ok);
response_->set_type(Response::speech_end);
stream_->Write(*response_);
}
// Feeds the audio payload of the current request into the feature pipeline.
// The payload bytes are reinterpreted as 16-bit PCM samples; a trailing odd
// byte, if any, is dropped by the integer division below.
void GrpcConnectionHandler::OnSpeechData() {
  // View the raw bytes of the request as binary PCM data.
  const auto& payload = request_->audio_data();
  const int16_t* samples = reinterpret_cast<const int16_t*>(payload.c_str());
  int sample_count = payload.length() / sizeof(int16_t);
  VLOG(2) << "Received " << sample_count << " samples";

  // Both objects are created when the start of speech is handled.
  CHECK(feature_pipeline_ != nullptr);
  CHECK(decoder_ != nullptr);
  feature_pipeline_->AcceptWaveform(samples, sample_count);
}
void GrpcConnectionHandler::SerializeResult(bool finish) {
for (const DecodeResult& path : decoder_->result()) {
Response_OneBest* one_best_ = response_->add_nbest();
one_best_->set_sentence(path.sentence);
if (finish) {
for (const WordPiece& word_piece : path.word_pieces) {
Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
one_piece_->set_word(word_piece.word);
one_piece_->set_start(word_piece.start);
one_piece_->set_end(word_piece.end);
}
}
if (response_->nbest_size() == nbest_) {
break;
}
}
return;
}
void GrpcConnectionHandler::DecodeThreadFunc() {
while (true) {
DecodeState state = decoder_->Decode();
response_->clear_status();
response_->clear_type();
response_->clear_nbest();
if (state == DecodeState::kEndFeats) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
OnFinish();
stop_recognition_ = true;
break;
} else if (state == DecodeState::kEndpoint) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
// If it's not continuous decoding, continue to do next recognition
// otherwise stop the recognition
if (continuous_decoding_) {
decoder_->ResetContinuousDecoding();
} else {
OnFinish();
stop_recognition_ = true;
break;
}
} else {
if (decoder_->DecodedSomething()) {
SerializeResult(false);
OnPartialResult();
}
}
}
}
// Entry point of the connection handler: drives one streaming call.
//
// The first request read from the stream supplies the decoding configuration
// (n-best size and continuous-decoding flag) and starts the session; every
// later request is treated as audio data. When the stream has no more
// requests the end of speech is signalled and the decoding thread is joined.
void GrpcConnectionHandler::operator()() {
  try {
    while (stream_->Read(request_.get())) {
      if (!got_start_tag_) {
        nbest_ = request_->decode_config().nbest_config();
        continuous_decoding_ =
            request_->decode_config().continuous_decoding_config();
        OnSpeechStart();
      } else {
        OnSpeechData();
      }
    }
    if (got_start_tag_) {
      OnSpeechEnd();
      LOG(INFO) << "Read all pcm data, wait for decoding thread";
    } else {
      // The stream ended before any request arrived, so no pipeline or
      // decoder exists. Calling OnSpeechEnd() here would trip its CHECK and
      // abort the whole server.
      LOG(WARNING) << "Stream closed before a start request was received";
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
    // Mark the input as finished so that a decoding thread which may already
    // be running is not left waiting for audio that will never come.
    if (feature_pipeline_ != nullptr) {
      feature_pipeline_->set_input_finished();
    }
  }
  // Always join, including after an exception: the decoding thread holds a
  // pointer to this handler, and destroying a joinable std::thread calls
  // std::terminate.
  if (decode_thread_ != nullptr && decode_thread_->joinable()) {
    decode_thread_->join();
  }
}
// gRPC handler for the bidirectional-streaming Recognize call.
//
// Creates fresh request/response messages and a connection handler bound to
// `stream`, runs the handler on a dedicated thread, waits for it to finish,
// and then reports OK. `context` is not used here.
Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;

  std::shared_ptr<Request> req = std::make_shared<Request>();
  std::shared_ptr<Response> resp = std::make_shared<Response>();
  GrpcConnectionHandler connection(stream, req, resp, feature_config_,
                                   decode_config_, decode_resource_);

  // The handler is moved into the worker thread; this call blocks until the
  // whole session has been processed.
  std::thread worker(std::move(connection));
  worker.join();

  return Status::OK;
}
} // namespace wenet