grpc_service.h 4.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
W
Wu Yi 已提交
26
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
X
Xin Pan 已提交
27 28
#include "paddle/fluid/platform/profiler.h"

29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
// NOTE: This approach was originally created by TensorFlow
//       (https://github.com/tensorflow/tensorflow/). We borrowed it and
//       made some modifications so that we can parse gRPC requests
//       without too much copying of the tensor data.

namespace grpc {
// Forward declarations of classes in namespace grpc. They are declared only,
// not defined, in this header.
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;

// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
44 45
class SerializationTraits<
    paddle::operators::distributed::GRPCVariableResponse> {
46 47
 public:
  static Status Serialize(
48
      const paddle::operators::distributed::GRPCVariableResponse& msg,
49
      grpc_byte_buffer** bp, bool* own_buffer) {
M
MRXLT 已提交
50 51
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "SerializationTraits::Serialize not implemented!"));
52 53
    return Status();
  }
54 55
  static Status Deserialize(
      grpc_byte_buffer* buffer,
56
      paddle::operators::distributed::GRPCVariableResponse* msg,
57
      int max_message_size = INT_MAX) {
58 59 60 61 62 63
    if (buffer == nullptr) {
      return Status(StatusCode::INTERNAL, "No payload");
    }

    Status result = g_core_codegen_interface->ok();
    if (result.ok()) {
64
      paddle::operators::distributed::GrpcByteSource source(buffer);
65 66 67 68 69 70 71 72 73 74 75 76 77
      int ret = msg->Parse(&source);
      if (ret != 0) {
        result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
      }
    }
    g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
    return result;
  }
};
}  // namespace grpc

namespace paddle {
namespace operators {
78
namespace distributed {
79 80 81 82

// Identifiers of the RPC methods exposed by the service. No explicit values
// are given, so the enumerators are numbered consecutively from 0 in
// declaration order; that order is therefore significant.
enum class GrpcMethod {
  kSendVariable,
  kGetVariable,
  kPrefetchVariable,
  kCheckpointNotify,
  kGetVariableNoBarrier,
  kGetMonomerVariable,
  kGetMonomerBarrier,
  kRequestNotify,
  kRequestSendAndRecv,
  // kGrpcNumMethods is computed from the last enumerator above.
  // When you add a new handler, change kGrpcNumMethods at the same time!
};

static const int kGrpcNumMethods =
94
    static_cast<int>(GrpcMethod::kRequestSendAndRecv) + 1;
95 96 97 98 99 100 101

// Maps a GrpcMethod identifier to its fully qualified RPC path of the form
// "/sendrecv.SendRecvService/<Method>".
//
// id: the method identifier to look up.
//
// Returns a pointer to a string literal for every enumerator of GrpcMethod.
// For any other value, control falls out of the switch and PADDLE_THROW is
// invoked with an InvalidArgument error; the final `return nullptr` only
// satisfies the declared return type.
inline const char* GrpcMethodName(GrpcMethod id) {
  // Cases are listed in the enum's declaration order.
  switch (id) {
    case GrpcMethod::kSendVariable:
      return "/sendrecv.SendRecvService/SendVariable";
    case GrpcMethod::kGetVariable:
      return "/sendrecv.SendRecvService/GetVariable";
    case GrpcMethod::kPrefetchVariable:
      return "/sendrecv.SendRecvService/PrefetchVariable";
    case GrpcMethod::kCheckpointNotify:
      return "/sendrecv.SendRecvService/CheckpointNotify";
    case GrpcMethod::kGetVariableNoBarrier:
      return "/sendrecv.SendRecvService/GetVariableNoBarrier";
    case GrpcMethod::kGetMonomerVariable:
      return "/sendrecv.SendRecvService/GetMonomerVariable";
    case GrpcMethod::kGetMonomerBarrier:
      return "/sendrecv.SendRecvService/GetMonomerBarrier";
    case GrpcMethod::kRequestNotify:
      // Note: the path's method name differs from the enumerator name.
      return "/sendrecv.SendRecvService/DistributeNotify";
    case GrpcMethod::kRequestSendAndRecv:
      // Note: the path's method name differs from the enumerator name.
      return "/sendrecv.SendRecvService/SendAndRecvVariable";
  }

  // Shouldn't be reached.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Invalid id: not found valid method name"));
  return nullptr;
}

// Wrapper type whose only member is the nested AsyncService class.
class GrpcService final {
 public:
  // A ::grpc::Service that registers one method per GrpcMethod value and
  // marks each of them as asynchronous.
  class AsyncService : public ::grpc::Service {
   public:
    // Registers kGrpcNumMethods methods, indexed 0 .. kGrpcNumMethods - 1.
    // The index doubles as the GrpcMethod value and as the argument passed
    // to MarkMethodAsync.
    AsyncService() {
      for (int method_index = 0; method_index < kGrpcNumMethods;
           ++method_index) {
        const char* const rpc_path =
            GrpcMethodName(static_cast<GrpcMethod>(method_index));
        // Registered as a normal RPC with a null handler.
        // NOTE(review): presumably ::grpc::Service takes ownership of the
        // raw pointer handed to AddMethod — confirm against the gRPC API.
        AddMethod(new ::grpc::internal::RpcServiceMethod(
            rpc_path, ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
        ::grpc::Service::MarkMethodAsync(method_index);
      }
    }

    virtual ~AsyncService() {}

    // Expose RequestAsyncUnary as a public member for use by grpc_call.h.
    using ::grpc::Service::RequestAsyncUnary;
  };
};

143
}  // namespace distributed
144
}  // namespace operators
145
}  // namespace paddle