// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "grpc/grpc_server.h" namespace wenet { using grpc::ServerReaderWriter; using wenet::Request; using wenet::Response; GrpcConnectionHandler::GrpcConnectionHandler( ServerReaderWriter* stream, std::shared_ptr request, std::shared_ptr response, std::shared_ptr feature_config, std::shared_ptr decode_config, std::shared_ptr decode_resource) : stream_(std::move(stream)), request_(std::move(request)), response_(std::move(response)), feature_config_(std::move(feature_config)), decode_config_(std::move(decode_config)), decode_resource_(std::move(decode_resource)) {} void GrpcConnectionHandler::OnSpeechStart() { LOG(INFO) << "Received speech start signal, start reading speech"; got_start_tag_ = true; response_->set_status(Response::ok); response_->set_type(Response::server_ready); stream_->Write(*response_); feature_pipeline_ = std::make_shared(*feature_config_); decoder_ = std::make_shared(feature_pipeline_, decode_resource_, *decode_config_); // Start decoder thread decode_thread_ = std::make_shared( &GrpcConnectionHandler::DecodeThreadFunc, this); } void GrpcConnectionHandler::OnSpeechEnd() { LOG(INFO) << "Received speech end signal"; CHECK(feature_pipeline_ != nullptr); feature_pipeline_->set_input_finished(); got_end_tag_ = true; } void GrpcConnectionHandler::OnPartialResult() { LOG(INFO) << "Partial result"; response_->set_status(Response::ok); response_->set_type(Response::partial_result); stream_->Write(*response_); } void GrpcConnectionHandler::OnFinalResult() { LOG(INFO) << "Final result"; response_->set_status(Response::ok); response_->set_type(Response::final_result); stream_->Write(*response_); } void GrpcConnectionHandler::OnFinish() { // Send finish tag response_->set_status(Response::ok); response_->set_type(Response::speech_end); stream_->Write(*response_); } void GrpcConnectionHandler::OnSpeechData() { // Read binary PCM data const int16_t* pcm_data = reinterpret_cast(request_->audio_data().c_str()); int num_samples = request_->audio_data().length() / sizeof(int16_t); VLOG(2) << "Received " << num_samples << " samples"; CHECK(feature_pipeline_ != nullptr); CHECK(decoder_ != nullptr); feature_pipeline_->AcceptWaveform(pcm_data, num_samples); } void GrpcConnectionHandler::SerializeResult(bool finish) { for (const DecodeResult& path : decoder_->result()) { Response_OneBest* one_best_ = response_->add_nbest(); one_best_->set_sentence(path.sentence); if (finish) { for (const WordPiece& word_piece : path.word_pieces) { Response_OnePiece* one_piece_ = one_best_->add_wordpieces(); one_piece_->set_word(word_piece.word); one_piece_->set_start(word_piece.start); one_piece_->set_end(word_piece.end); } } if (response_->nbest_size() == nbest_) { break; } } return; } void GrpcConnectionHandler::DecodeThreadFunc() { while (true) { DecodeState state = decoder_->Decode(); response_->clear_status(); response_->clear_type(); response_->clear_nbest(); if (state == DecodeState::kEndFeats) { decoder_->Rescoring(); SerializeResult(true); OnFinalResult(); OnFinish(); stop_recognition_ = true; break; } else if (state == DecodeState::kEndpoint) { decoder_->Rescoring(); SerializeResult(true); OnFinalResult(); // If it's not continuous decoding, continue to do next recognition // otherwise stop the recognition if (continuous_decoding_) { decoder_->ResetContinuousDecoding(); } else { OnFinish(); stop_recognition_ = true; break; } } else { if (decoder_->DecodedSomething()) { SerializeResult(false); OnPartialResult(); } } } } void GrpcConnectionHandler::operator()() { try { while (stream_->Read(request_.get())) { if (!got_start_tag_) { nbest_ = request_->decode_config().nbest_config(); continuous_decoding_ = request_->decode_config().continuous_decoding_config(); OnSpeechStart(); } else { OnSpeechData(); } } OnSpeechEnd(); LOG(INFO) << "Read all pcm data, wait for decoding thread"; if (decode_thread_ != nullptr) { decode_thread_->join(); } } catch (std::exception const& e) { LOG(ERROR) << e.what(); } } Status GrpcServer::Recognize(ServerContext* context, ServerReaderWriter* stream) { LOG(INFO) << "Get Recognize request" << std::endl; auto request = std::make_shared(); auto response = std::make_shared(); GrpcConnectionHandler handler(stream, request, response, feature_config_, decode_config_, decode_resource_); std::thread t(std::move(handler)); t.join(); return Status::OK; } } // namespace wenet