//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <grpc++/impl/codegen/async_stream.h>
#include <grpc++/impl/codegen/async_unary_call.h>
#include <grpc++/impl/codegen/proto_utils.h>
#include <grpc++/impl/codegen/rpc_method.h>
#include <grpc++/impl/codegen/service_type.h>
#include <grpc++/impl/codegen/status.h>
#include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/grpc/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"

// NOTE: This method was originally created by TensorFlow
//       (https://github.com/tensorflow/tensorflow/). We borrowed it and
//       made some modifications so that we can parse gRPC requests
//       without too much copying of the tensor data.

namespace grpc {
// Forward declarations of classes in namespace grpc. None of these names is
// referenced again in the visible part of this header; presumably they are
// declared here for files that include it -- TODO confirm.
class CompletionQueue;
class Channel;
class RpcService;
class ServerCompletionQueue;
class ServerContext;

// Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse.
template <>
44 45
class SerializationTraits<
    paddle::operators::distributed::GRPCVariableResponse> {
46 47
 public:
  static Status Serialize(
48
      const paddle::operators::distributed::GRPCVariableResponse& msg,
49 50 51 52
      grpc_byte_buffer** bp, bool* own_buffer) {
    PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
    return Status();
  }
53 54
  static Status Deserialize(
      grpc_byte_buffer* buffer,
55
      paddle::operators::distributed::GRPCVariableResponse* msg,
56
      int max_message_size = INT_MAX) {
57 58 59 60 61 62
    if (buffer == nullptr) {
      return Status(StatusCode::INTERNAL, "No payload");
    }

    Status result = g_core_codegen_interface->ok();
    if (result.ok()) {
63
      paddle::operators::distributed::GrpcByteSource source(buffer);
64 65 66 67 68 69 70 71 72 73 74 75 76
      int ret = msg->Parse(&source);
      if (ret != 0) {
        result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
      }
    }
    g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
    return result;
  }
};
}  // namespace grpc

namespace paddle {
namespace operators {
namespace distributed {

// Identifiers of the RPC methods known to this header.
//
// The numeric values matter: this file casts loop indices in
// [0, kGrpcNumMethods) to GrpcMethod and passes the same index to
// ::grpc::Service::MarkMethodAsync(), and kGrpcNumMethods is computed from
// kGetMonomerBarrier. Keep the values contiguous starting at 0 and keep
// kGetMonomerBarrier last (or update kGrpcNumMethods accordingly).
// GrpcMethodName() maps each enumerator to its RPC path string.
enum class GrpcMethod {
  kSendVariable = 0,
  kGetVariable = 1,
  kPrefetchVariable = 2,
  kCheckpointNotify = 3,
  kGetVariableNoBarrier = 4,
  kGetMonomerVariable = 5,
  kGetMonomerBarrier = 6,
};

static const int kGrpcNumMethods =
90
    static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1;
91 92 93 94 95 96 97

// Returns the RPC path string ("/sendrecv.SendRecvService/<Method>") for `id`.
//
// Cases are listed in the declaration order of GrpcMethod. There is no
// `default:` label, so a compiler with -Wswitch enabled can flag an
// enumerator that is added later but not handled here.
//
// For a value that matches no case, control falls through to
// PADDLE_ENFORCE(false, ...); nullptr is returned only if that macro returns.
inline const char* GrpcMethodName(GrpcMethod id) {
  switch (id) {
    case GrpcMethod::kSendVariable:
      return "/sendrecv.SendRecvService/SendVariable";
    case GrpcMethod::kGetVariable:
      return "/sendrecv.SendRecvService/GetVariable";
    case GrpcMethod::kPrefetchVariable:
      return "/sendrecv.SendRecvService/PrefetchVariable";
    case GrpcMethod::kCheckpointNotify:
      return "/sendrecv.SendRecvService/CheckpointNotify";
    case GrpcMethod::kGetVariableNoBarrier:
      return "/sendrecv.SendRecvService/GetVariableNoBarrier";
    case GrpcMethod::kGetMonomerVariable:
      return "/sendrecv.SendRecvService/GetMonomerVariable";
    case GrpcMethod::kGetMonomerBarrier:
      return "/sendrecv.SendRecvService/GetMonomerBarrier";
  }

  // Shouldn't be reached.
  PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
  return nullptr;
}

// Wrapper class whose only member is the nested AsyncService type.
class GrpcService final {
 public:
  // A ::grpc::Service that registers one method per GrpcMethod value and
  // marks each of them as async.
  class AsyncService : public ::grpc::Service {
   public:
    // For every index in [0, kGrpcNumMethods): looks up the RPC path with
    // GrpcMethodName(), adds a NORMAL_RPC RpcServiceMethod constructed with
    // nullptr as its last argument, then marks that index as async.
    AsyncService() {
      for (int method_index = 0; method_index < kGrpcNumMethods;
           ++method_index) {
        const char* method_path =
            GrpcMethodName(static_cast<GrpcMethod>(method_index));
        AddMethod(new ::grpc::internal::RpcServiceMethod(
            method_path, ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
        ::grpc::Service::MarkMethodAsync(method_index);
      }
    }
    virtual ~AsyncService() = default;

    // Re-exported as public so that grpc_call.h can call it.
    using ::grpc::Service::RequestAsyncUnary;
  };
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle