/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <sys/stat.h>
#include <unistd.h>

#include <ostream>
#include <thread>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {

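// Name of the attribute that carries the server-side optimize block.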
constexpr char kOptimizeBlock[] = "OptimizeBlock";

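// Entry point for the server thread; RunSyncUpdate() blocks until the
// service is shut down.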
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  service->RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}

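// Creates the variable's underlying payload to match the type carried in the
// incoming VariableMessage; only LoDTensor and SelectedRows are supported.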
static void CreateTensorFromMessageType(framework::Variable *var,
                                        sendrecv::VarType var_type) {
  if (var_type == sendrecv::VarType::LOD_TENSOR) {
    var->GetMutable<framework::LoDTensor>();
  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
    var->GetMutable<framework::SelectedRows>();
  } else {
    PADDLE_THROW(
        "VariableMessage type %d is not in "
        "[LoDTensor, SelectedRows]",
        var_type);
  }
}

class ListenAndServOp : public framework::OperatorBase {
 public:
  ListenAndServOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
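    // Create the gRPC service once and serve it from a background thread.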
    if (!rpc_service_) {
      std::string endpoint = Attr<std::string>("endpoint");
      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
      server_thread_.reset(new std::thread(RunServer, rpc_service_));
    }
  }

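  // Unblocks the server loop with a terminate message, then shuts down the
  // RPC service and joins the server thread.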
  void Stop() override {
    detail::MessageWithName term_msg;
    term_msg.first = LISTEN_TERMINATE_MESSAGE;
    rpc_service_->Push(term_msg);
    rpc_service_->ShutDown();
    server_thread_->join();
  }

  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    framework::Scope &recv_scope = scope.NewScope();

    // FIXME(Yancey1989): initialize rpc server with lazy mode.
    rpc_service_->SetScope(&recv_scope);
    rpc_service_->SetDevCtx(&dev_ctx);
    auto ins = Inputs("X");
    auto fan_in = Attr<int>("Fanin");

    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
    auto *program = block->Program();
    framework::Executor executor(dev_place);

    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
    bool exit_flag = false;
    // Record received sparse variables, so that
    // we could reset those after execute optimize program
    std::vector<framework::Variable *> sparse_vars;
    while (!exit_flag) {
      // Get gradients from multiple trainers; we don't care about the order
      // in which they arrive, just add suffix 0~n and merge them.
      rpc_service_->SetCond(0);
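      // recv_var_cnt counts gradients received in this batch; batch_barrier
      // counts BATCH_BARRIER_MESSAGEs, one per trainer.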
      size_t recv_var_cnt = 0;
      int batch_barrier = 0;
      while (batch_barrier != fan_in) {
        const detail::MessageWithName &v = rpc_service_->Get();
        auto recv_var_name = v.first;
        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
          VLOG(3) << "recv batch barrier message";
          batch_barrier++;
          continue;
        } else {
          VLOG(3) << "received grad: " << recv_var_name;
          recv_var_cnt++;
          auto *var = recv_scope.FindVar(recv_var_name);
          if (var == nullptr) {
            LOG(ERROR) << "Cannot find server-side var: " << recv_var_name;
            PADDLE_THROW("Cannot find server-side var");
          }
          detail::DeserializeFromMessage(v.second, dev_ctx, var);
          if (var->IsType<framework::SelectedRows>()) {
            sparse_vars.push_back(var);
          }
        }
      }
      if (exit_flag) {
        rpc_service_->ShutDown();
      }
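      // Execute the optimize block on the gradients received in this batch.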
      try {
        executor.Run(*program, &recv_scope, block->ID() /*global_block*/,
                     false /*create_local_scope*/, false /*create_vars*/);
      } catch (std::exception &e) {
        LOG(ERROR) << "run sub program error: " << e.what();
      }
      // Reset the received sparse variables so that the sum operator does not
      // add in sparse variables whose rows are empty at the next mini-batch.
      // TODO(Yancey1989): move the reset action into an operator; we
      // shouldn't have any hidden logic inside the operator.
      for (auto &var : sparse_vars) {
        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
      }
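
      // Switch the service to the "get" phase so trainers can fetch the
      // updated parameters.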
      rpc_service_->SetCond(1);
      // FIXME(typhoonzero): use another condition to wait until all clients
      // have fetched the updated parameters.
      rpc_service_->WaitClientGet(ins.size());
      sparse_vars.clear();
    }  // while (!exit_flag)
  }

 protected:
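  // Shared with the server thread spawned in the constructor.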
  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
  std::shared_ptr<std::thread> server_thread_;
};

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor) Variables that the server receives.")
        .AsDuplicable();
    AddComment(R"DOC(
ListenAndServ operator

This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164) "
                         "The endpoint (IP:port) to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<framework::BlockDesc *>(
        kOptimizeBlock, "The optimize block to run on the server side.");
    AddAttr<int>("Fanin", "How many clients send gradients to this server.")
        .SetDefault(1);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
                  ops::ListenAndServOpMaker);