You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

83 lines
2.6 KiB

  1. // Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "grpc/grpc_client.h"
  15. namespace wenet {
  16. using grpc::Channel;
  17. using grpc::ClientContext;
  18. using grpc::ClientReaderWriter;
  19. using grpc::Status;
  20. using wenet::Request;
  21. using wenet::Response;
  22. GrpcClient::GrpcClient(const std::string& host, int port, int nbest,
  23. bool continuous_decoding)
  24. : host_(host),
  25. port_(port),
  26. nbest_(nbest),
  27. continuous_decoding_(continuous_decoding) {
  28. Connect();
  29. t_.reset(new std::thread(&GrpcClient::ReadLoopFunc, this));
  30. }
  31. void GrpcClient::Connect() {
  32. channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
  33. grpc::InsecureChannelCredentials());
  34. stub_ = ASR::NewStub(channel_);
  35. context_ = std::make_shared<ClientContext>();
  36. stream_ = stub_->Recognize(context_.get());
  37. request_ = std::make_shared<Request>();
  38. response_ = std::make_shared<Response>();
  39. request_->mutable_decode_config()->set_nbest_config(nbest_);
  40. request_->mutable_decode_config()->set_continuous_decoding_config(
  41. continuous_decoding_);
  42. stream_->Write(*request_);
  43. }
  44. void GrpcClient::SendBinaryData(const void* data, size_t size) {
  45. const int16_t* pdata = reinterpret_cast<const int16_t*>(data);
  46. request_->set_audio_data(pdata, size);
  47. stream_->Write(*request_);
  48. }
  49. void GrpcClient::ReadLoopFunc() {
  50. try {
  51. while (stream_->Read(response_.get())) {
  52. for (int i = 0; i < response_->nbest_size(); i++) {
  53. // you can also traverse wordpieces like demonstrated above
  54. LOG(INFO) << i + 1 << "best " << response_->nbest(i).sentence();
  55. }
  56. if (response_->status() != Response_Status_ok) {
  57. break;
  58. }
  59. if (response_->type() == Response_Type_speech_end) {
  60. done_ = true;
  61. break;
  62. }
  63. }
  64. } catch (std::exception const& e) {
  65. LOG(ERROR) << e.what();
  66. }
  67. }
  68. void GrpcClient::Join() {
  69. stream_->WritesDone();
  70. t_->join();
  71. Status status = stream_->Finish();
  72. if (!status.ok()) {
  73. LOG(INFO) << "Recognize rpc failed.";
  74. }
  75. }
  76. } // namespace wenet