/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>

#include <unistd.h>

Y
Yi Wang 已提交
22 23 24 25 26
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
T
typhoonzero 已提交
27
#include "paddle/fluid/framework/threadpool.h"
Y
Yi Wang 已提交
28 29 30
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
31
#include "paddle/fluid/string/printf.h"
T
typhoonzero 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56

namespace paddle {
namespace operators {

// Name of the operator attribute that holds the framework::BlockDesc* used
// to look up the server-side program; read in ListenAndServOp::RunImpl and
// declared in ListenAndServOpMaker.
constexpr char kOptimizeBlock[] = "OptimizeBlock";

// Thread entry function for the RPC server: invokes RunSyncUpdate() on the
// given service and writes a verbose log line once that call returns.
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  detail::AsyncGRPCServer &server = *service;
  server.RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}

// Calls GetMutable<T>() on `var` with the container type that matches
// `var_type`: framework::LoDTensor for LOD_TENSOR, framework::SelectedRows
// for SELECTED_ROWS. Any other value is reported through PADDLE_THROW.
static void CreateTensorFromMessageType(framework::Variable *var,
                                        sendrecv::VarType var_type) {
  switch (var_type) {
    case sendrecv::VarType::LOD_TENSOR:
      var->GetMutable<framework::LoDTensor>();
      break;
    case sendrecv::VarType::SELECTED_ROWS:
      var->GetMutable<framework::SelectedRows>();
      break;
    default:
      PADDLE_THROW(
          "VariableMessage type %d is not in "
          "[LoDTensor, SelectedRows]",
          var_type);
  }
}

57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
// Runs each block id in `parallel_blkids` as its own framework::Async task
// via executor->Run(*program, scope, block_id, false, false), then waits for
// every task to finish before returning.
//
// A std::exception escaping executor->Run inside a task is logged and
// swallowed, so one failing block does not prevent the others from being
// waited on.
static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
                                  framework::Executor *executor,
                                  framework::ProgramDesc *program,
                                  framework::Scope *scope) {
  std::vector<std::future<void>> fs;
  fs.reserve(parallel_blkids.size());
  for (size_t idx : parallel_blkids) {
    // Copy the raw pointers into the closure instead of capturing references
    // to this function's parameters, so the task never depends on this
    // stack frame.
    fs.push_back(framework::Async([executor, program, scope, idx]() {
      const int run_block = static_cast<int>(idx);  // thread local
      try {
        executor->Run(*program, scope, run_block, false, false);
      } catch (const std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
    }));
  }
  for (auto &f : fs) f.wait();
}

T
typhoonzero 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
class ListenAndServOp : public framework::OperatorBase {
 public:
  ListenAndServOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
    if (!rpc_service_) {
      std::string endpoint = Attr<std::string>("endpoint");
      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
      server_thread_.reset(new std::thread(RunServer, rpc_service_));
    }
  }

  void Stop() override {
90
    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
T
typhoonzero 已提交
91 92 93
    server_thread_->join();
  }

T
typhoonzero 已提交
94 95
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
T
typhoonzero 已提交
96 97 98 99 100 101 102
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    framework::Scope &recv_scope = scope.NewScope();

    // FIXME(Yancey1989): initialize rpc server with lazy mode.
    rpc_service_->SetScope(&recv_scope);
    rpc_service_->SetDevCtx(&dev_ctx);
T
typhoonzero 已提交
103
    auto ins = Inputs("X");
104
    auto fan_in = Attr<int>("Fanin");
T
typhoonzero 已提交
105 106 107

    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
    auto *program = block->Program();
T
typhoonzero 已提交
108
    int num_blocks = program->Size();
T
typhoonzero 已提交
109 110 111
    PADDLE_ENFORCE_GE(num_blocks, 2,
                      "server program should have at least 2 blocks");

T
typhoonzero 已提交
112 113 114 115
    framework::Executor executor(dev_place);

    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
    bool exit_flag = false;
116 117 118
    // Record received sparse variables, so that
    // we could reset those after execute optimize program
    std::vector<framework::Variable *> sparse_vars;
T
typhoonzero 已提交
119 120 121 122 123 124 125
    while (!exit_flag) {
      // Get from multiple trainers, we don't care about the order in which
      // the gradients arrives, just add suffix 0~n and merge the gradient.
      rpc_service_->SetCond(0);
      size_t recv_var_cnt = 0;
      int batch_barrier = 0;
      while (batch_barrier != fan_in) {
126
        const detail::ReceivedMessage v = rpc_service_->Get();
T
typhoonzero 已提交
127 128
        auto recv_var_name = v.first;
        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
T
typhoonzero 已提交
129 130 131
          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
T
typhoonzero 已提交
132
        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
T
typhoonzero 已提交
133 134 135 136
          VLOG(3) << "recv batch barrier message";
          batch_barrier++;
          continue;
        } else {
T
typhoonzero 已提交
137
          VLOG(3) << "received grad: " << recv_var_name;
T
typhoonzero 已提交
138
          recv_var_cnt++;
139
          auto var = v.second->GetVar();
T
typhoonzero 已提交
140
          if (var == nullptr) {
T
typhoonzero 已提交
141
            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
T
typhoonzero 已提交
142 143
            PADDLE_THROW("Can not find server side var");
          }
144 145 146
          if (var->IsType<framework::SelectedRows>()) {
            sparse_vars.push_back(var);
          }
T
typhoonzero 已提交
147 148 149
        }
      }
      if (exit_flag) {
Y
Yancey1989 已提交
150
        rpc_service_->SetCond(1);
151
        rpc_service_->ShutDown();
Y
Yancey1989 已提交
152
        break;
T
typhoonzero 已提交
153
      }
T
typhoonzero 已提交
154 155 156

      // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
      // and this will still work.
T
typhoonzero 已提交
157

158 159 160 161 162
      // The optimize blocks which have the same parent ID would run parallel
      // TODO(Yancey1989): need to use ParallelExecutor for future
      size_t last_parent_blkid = program->Block(1).Parent();
      std::vector<size_t> parallel_blkids;
      parallel_blkids.push_back(1);
T
typhoonzero 已提交
163
      double ts = detail::GetTimestamp();
164 165 166 167 168 169 170
      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
        if (program->Block(blkid).Parent() != last_parent_blkid) {
          for (size_t idx : parallel_blkids) VLOG(3) << idx;
          ParallelExecuteBlocks(parallel_blkids, &executor, program,
                                &recv_scope);
          parallel_blkids.clear();
          last_parent_blkid = program->Block(blkid).Parent();
T
typhoonzero 已提交
171
        }
172
        parallel_blkids.push_back(blkid);
T
typhoonzero 已提交
173
      }
174 175
      ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);

T
typhoonzero 已提交
176
      VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
T
typhoonzero 已提交
177

178 179 180
      // Reset the received sparse variables, the sum operator would not
      // sum the input sparse variables which rows is empty at the next
      // mini-batch.
T
typhoonzero 已提交
181
      // TODO(Yancey1989): move the reset action into an operator, we couldn't
182 183 184 185
      // have any hide logic in the operator.
      for (auto &var : sparse_vars) {
        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
      }
T
typhoonzero 已提交
186
      rpc_service_->SetCond(1);
T
typhoonzero 已提交
187
      // FIXME(typhoonzero): use another condition to sync wait clients get.
188
      rpc_service_->WaitClientGet(fan_in);
189
      sparse_vars.clear();
T
typhoonzero 已提交
190 191 192 193 194 195 196 197 198 199 200 201
    }  // while(true)
  }

 protected:
  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
  std::shared_ptr<std::thread> server_thread_;
};

// Declares the proto of the listen_and_serv operator: the duplicable input
// "X", and the attributes "endpoint", kOptimizeBlock and "Fanin".
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
    AddComment(R"DOC(
ListenAndServ operator

This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164)"
                         "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) {
          // NOTE(review): presumably the attribute checker ignores a custom
          // checker's return value (confirm against AddCustomChecker), so
          // reject an empty endpoint explicitly instead of only returning
          // a bool.
          PADDLE_ENFORCE(!ip.empty(), "endpoint should not be empty");
          return true;
        });
    AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                    "BlockID to run on server side.");
    AddAttr<int>("Fanin", "How many clients send to this server.")
        .SetDefault(1);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers the operator under the name "listen_and_serv" together with its
// proto maker.
REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
                  ops::ListenAndServOpMaker);