You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.6 KiB

  1. // Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "grpc/grpc_client.h"
  15. #include "utils/log.h"
  16. namespace wenet {
  17. using grpc::Channel;
  18. using grpc::ClientContext;
  19. using grpc::ClientReaderWriter;
  20. using grpc::Status;
  21. using wenet::Request;
  22. using wenet::Response;
  23. GrpcClient::GrpcClient(const std::string& host, int port, int nbest,
  24. bool continuous_decoding)
  25. : host_(host),
  26. port_(port),
  27. nbest_(nbest),
  28. continuous_decoding_(continuous_decoding) {
  29. Connect();
  30. t_.reset(new std::thread(&GrpcClient::ReadLoopFunc, this));
  31. }
  32. void GrpcClient::Connect() {
  33. channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
  34. grpc::InsecureChannelCredentials());
  35. stub_ = ASR::NewStub(channel_);
  36. context_ = std::make_shared<ClientContext>();
  37. stream_ = stub_->Recognize(context_.get());
  38. request_ = std::make_shared<Request>();
  39. response_ = std::make_shared<Response>();
  40. request_->mutable_decode_config()->set_nbest_config(nbest_);
  41. request_->mutable_decode_config()->set_continuous_decoding_config(
  42. continuous_decoding_);
  43. stream_->Write(*request_);
  44. }
  45. void GrpcClient::SendBinaryData(const void* data, size_t size) {
  46. const int16_t* pdata = reinterpret_cast<const int16_t*>(data);
  47. request_->set_audio_data(pdata, size);
  48. stream_->Write(*request_);
  49. }
  50. void GrpcClient::ReadLoopFunc() {
  51. try {
  52. while (stream_->Read(response_.get())) {
  53. for (int i = 0; i < response_->nbest_size(); i++) {
  54. // you can also traverse wordpieces like demonstrated above
  55. LOG(INFO) << i + 1 << "best " << response_->nbest(i).sentence();
  56. }
  57. if (response_->status() != Response_Status_ok) {
  58. break;
  59. }
  60. if (response_->type() == Response_Type_speech_end) {
  61. done_ = true;
  62. break;
  63. }
  64. }
  65. } catch (std::exception const& e) {
  66. LOG(ERROR) << e.what();
  67. }
  68. }
  69. void GrpcClient::Join() {
  70. stream_->WritesDone();
  71. t_->join();
  72. Status status = stream_->Finish();
  73. if (!status.ok()) {
  74. LOG(INFO) << "Recognize rpc failed.";
  75. }
  76. }
  77. } // namespace wenet