You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

176 lines
5.6 KiB

  1. // Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "grpc/grpc_server.h"
  15. namespace wenet {
  16. using grpc::ServerReaderWriter;
  17. using wenet::Request;
  18. using wenet::Response;
  19. GrpcConnectionHandler::GrpcConnectionHandler(
  20. ServerReaderWriter<Response, Request>* stream,
  21. std::shared_ptr<Request> request, std::shared_ptr<Response> response,
  22. std::shared_ptr<FeaturePipelineConfig> feature_config,
  23. std::shared_ptr<DecodeOptions> decode_config,
  24. std::shared_ptr<DecodeResource> decode_resource)
  25. : stream_(std::move(stream)),
  26. request_(std::move(request)),
  27. response_(std::move(response)),
  28. feature_config_(std::move(feature_config)),
  29. decode_config_(std::move(decode_config)),
  30. decode_resource_(std::move(decode_resource)) {}
  31. void GrpcConnectionHandler::OnSpeechStart() {
  32. LOG(INFO) << "Received speech start signal, start reading speech";
  33. got_start_tag_ = true;
  34. response_->set_status(Response::ok);
  35. response_->set_type(Response::server_ready);
  36. stream_->Write(*response_);
  37. feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
  38. decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
  39. *decode_config_);
  40. // Start decoder thread
  41. decode_thread_ = std::make_shared<std::thread>(
  42. &GrpcConnectionHandler::DecodeThreadFunc, this);
  43. }
  44. void GrpcConnectionHandler::OnSpeechEnd() {
  45. LOG(INFO) << "Received speech end signal";
  46. CHECK(feature_pipeline_ != nullptr);
  47. feature_pipeline_->set_input_finished();
  48. got_end_tag_ = true;
  49. }
  50. void GrpcConnectionHandler::OnPartialResult() {
  51. LOG(INFO) << "Partial result";
  52. response_->set_status(Response::ok);
  53. response_->set_type(Response::partial_result);
  54. stream_->Write(*response_);
  55. }
  56. void GrpcConnectionHandler::OnFinalResult() {
  57. LOG(INFO) << "Final result";
  58. response_->set_status(Response::ok);
  59. response_->set_type(Response::final_result);
  60. stream_->Write(*response_);
  61. }
  62. void GrpcConnectionHandler::OnFinish() {
  63. // Send finish tag
  64. response_->set_status(Response::ok);
  65. response_->set_type(Response::speech_end);
  66. stream_->Write(*response_);
  67. }
  68. void GrpcConnectionHandler::OnSpeechData() {
  69. // Read binary PCM data
  70. const int16_t* pcm_data =
  71. reinterpret_cast<const int16_t*>(request_->audio_data().c_str());
  72. int num_samples = request_->audio_data().length() / sizeof(int16_t);
  73. VLOG(2) << "Received " << num_samples << " samples";
  74. CHECK(feature_pipeline_ != nullptr);
  75. CHECK(decoder_ != nullptr);
  76. feature_pipeline_->AcceptWaveform(pcm_data, num_samples);
  77. }
  78. void GrpcConnectionHandler::SerializeResult(bool finish) {
  79. for (const DecodeResult& path : decoder_->result()) {
  80. Response_OneBest* one_best_ = response_->add_nbest();
  81. one_best_->set_sentence(path.sentence);
  82. if (finish) {
  83. for (const WordPiece& word_piece : path.word_pieces) {
  84. Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
  85. one_piece_->set_word(word_piece.word);
  86. one_piece_->set_start(word_piece.start);
  87. one_piece_->set_end(word_piece.end);
  88. }
  89. }
  90. if (response_->nbest_size() == nbest_) {
  91. break;
  92. }
  93. }
  94. return;
  95. }
  96. void GrpcConnectionHandler::DecodeThreadFunc() {
  97. while (true) {
  98. DecodeState state = decoder_->Decode();
  99. response_->clear_status();
  100. response_->clear_type();
  101. response_->clear_nbest();
  102. if (state == DecodeState::kEndFeats) {
  103. decoder_->Rescoring();
  104. SerializeResult(true);
  105. OnFinalResult();
  106. OnFinish();
  107. stop_recognition_ = true;
  108. break;
  109. } else if (state == DecodeState::kEndpoint) {
  110. decoder_->Rescoring();
  111. SerializeResult(true);
  112. OnFinalResult();
  113. // If it's not continuous decoding, continue to do next recognition
  114. // otherwise stop the recognition
  115. if (continuous_decoding_) {
  116. decoder_->ResetContinuousDecoding();
  117. } else {
  118. OnFinish();
  119. stop_recognition_ = true;
  120. break;
  121. }
  122. } else {
  123. if (decoder_->DecodedSomething()) {
  124. SerializeResult(false);
  125. OnPartialResult();
  126. }
  127. }
  128. }
  129. }
  130. void GrpcConnectionHandler::operator()() {
  131. try {
  132. while (stream_->Read(request_.get())) {
  133. if (!got_start_tag_) {
  134. nbest_ = request_->decode_config().nbest_config();
  135. continuous_decoding_ =
  136. request_->decode_config().continuous_decoding_config();
  137. OnSpeechStart();
  138. } else {
  139. OnSpeechData();
  140. }
  141. }
  142. OnSpeechEnd();
  143. LOG(INFO) << "Read all pcm data, wait for decoding thread";
  144. if (decode_thread_ != nullptr) {
  145. decode_thread_->join();
  146. }
  147. } catch (std::exception const& e) {
  148. LOG(ERROR) << e.what();
  149. }
  150. }
  151. Status GrpcServer::Recognize(ServerContext* context,
  152. ServerReaderWriter<Response, Request>* stream) {
  153. LOG(INFO) << "Get Recognize request" << std::endl;
  154. auto request = std::make_shared<Request>();
  155. auto response = std::make_shared<Response>();
  156. GrpcConnectionHandler handler(stream, request, response, feature_config_,
  157. decode_config_, decode_resource_);
  158. std::thread t(std::move(handler));
  159. t.join();
  160. return Status::OK;
  161. }
  162. } // namespace wenet