// grpc_service.h
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <climits>

#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"

// NOTE: This method was originally created by TensorFlow
//       (https://github.com/tensorflow/tensorflow/). We borrowed it and
//       made some modifications so that we can parse gRPC requests
//       without excessive copying of the tensor data.

namespace grpc {
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;

// gRPC (de)serialization hooks for paddle::operators::detail::VariableResponse.
// Only deserialization is implemented: incoming payloads are parsed directly
// from the gRPC byte buffer by VariableResponse::Parse, which lets the server
// avoid extra copies of tensor data.
template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> {
 public:
  // Not supported: VariableResponse is a receive-side type only.
  // Always fails via PADDLE_ENFORCE.
  static Status Serialize(
      const paddle::operators::detail::VariableResponse& msg,
      grpc_byte_buffer** bp, bool* own_buffer) {
    PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
    return Status();
  }

  // Parses `buffer` into `*msg` and always releases `buffer`.
  //
  // @param buffer           Incoming payload; ownership is taken. A nullptr
  //                         yields an INTERNAL "No payload" status.
  // @param msg              Destination message.
  // @param max_message_size Accepted for interface compatibility; currently
  //                         ignored (no size limit is enforced here).
  // @return OK on success, INTERNAL if VariableResponse::Parse reports an
  //         error (non-zero return).
  static Status Deserialize(grpc_byte_buffer* buffer,
                            paddle::operators::detail::VariableResponse* msg,
                            int max_message_size = INT_MAX) {
    (void)max_message_size;
    if (buffer == nullptr) {
      return Status(StatusCode::INTERNAL, "No payload");
    }

    // Destroys the byte buffer on every exit path, including when Parse()
    // throws, so it can never leak. It is declared before `source` so that
    // `source`, which reads from the buffer, is destroyed first.
    struct ByteBufferGuard {
      grpc_byte_buffer* buf;
      ~ByteBufferGuard() {
        g_core_codegen_interface->grpc_byte_buffer_destroy(buf);
      }
    } buffer_guard{buffer};

    paddle::operators::detail::GrpcByteSource source(buffer);
    if (msg->Parse(&source) != 0) {
      return Status(StatusCode::INTERNAL, "VariableResponse parse error");
    }
    return g_core_codegen_interface->ok();
  }
};
}  // namespace grpc

namespace paddle {
namespace operators {
namespace detail {

// Identifies each RPC exposed by the send/recv gRPC service.
// GrpcService::AsyncService casts the integers 0..kGrpcNumMethods-1 back to
// GrpcMethod and uses them as method indices. The enumerators must therefore
// stay dense and start at 0. When appending one, also update kGrpcNumMethods
// and GrpcMethodName.
enum class GrpcMethod {
  kSendVariable,
  kGetVariable,
  kPrefetchVariable,
};

// Number of RPC methods registered by GrpcService::AsyncService. It must be
// derived from the LAST GrpcMethod enumerator; otherwise the trailing methods
// (here PrefetchVariable) are never registered or marked async.
static constexpr int kGrpcNumMethods =
    static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;

// Returns the fully-qualified gRPC method path ("/<package>.<Service>/<Method>")
// for `id`. gRPC routes calls by this exact, case-sensitive path. Every entry
// must therefore name the same service, "sendrecv.SendRecvService".
inline const char* GrpcMethodName(GrpcMethod id) {
  switch (id) {
    case GrpcMethod::kSendVariable:
      return "/sendrecv.SendRecvService/SendVariable";
    case GrpcMethod::kGetVariable:
      return "/sendrecv.SendRecvService/GetVariable";
    case GrpcMethod::kPrefetchVariable:
      return "/sendrecv.SendRecvService/PrefetchVariable";
  }

  // Only reachable when `id` holds a value outside the enumerators, e.g. from
  // an out-of-range static_cast.
  PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
  return nullptr;
}

// Namespace-like holder for the asynchronous gRPC service used by the
// distributed send/recv operators.
class GrpcService final {
 public:
  // Registers one unary (NORMAL_RPC) method per GrpcMethod value, in
  // enumerator order, and marks each method index as asynchronous. No
  // synchronous handler is attached (nullptr).
  class AsyncService : public ::grpc::Service {
   public:
    AsyncService() {
      for (int method_idx = 0; method_idx != kGrpcNumMethods; ++method_idx) {
        const GrpcMethod method = static_cast<GrpcMethod>(method_idx);
        auto* rpc_method = new ::grpc::internal::RpcServiceMethod(
            GrpcMethodName(method), ::grpc::internal::RpcMethod::NORMAL_RPC,
            nullptr);
        // NOTE(review): ownership of rpc_method is presumably taken by
        // ::grpc::Service::AddMethod; confirm against the gRPC version in use.
        AddMethod(rpc_method);
        MarkMethodAsync(method_idx);
      }
    }
    ~AsyncService() override {}

    // Re-export the protected base helper so grpc_call.h can request unary
    // calls on this service.
    using ::grpc::Service::RequestAsyncUnary;
  };
};

}  // namespace detail
}  // namespace operators
}  // namespace paddle