/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>

#include <unistd.h>

#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

namespace paddle {
namespace operators {

constexpr char kOptimizeBlock[] = "OptimizeBlock";

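// Entry point for the server thread: blocks inside the gRPC service loop and
// returns once the service has been shut down.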
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  service->RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}

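// Ensure the destination variable holds the concrete type named in the
// incoming message (LoDTensor or SelectedRows) before deserializing into it.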
static void CreateTensorFromMessageType(framework::Variable *var,
                                        sendrecv::VarType var_type) {
  if (var_type == sendrecv::VarType::LOD_TENSOR) {
    var->GetMutable<framework::LoDTensor>();
  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
    var->GetMutable<framework::SelectedRows>();
  } else {
    PADDLE_THROW(
        "VariableMessage type %d is not in "
        "[LoDTensor, SelectedRows]",
        var_type);
  }
}

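// RecvOp runs on the parameter server: it launches an async gRPC service in a
// background thread and, in Run(), repeatedly collects gradients from
// trainers and executes the optimize block to update the parameters.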
class RecvOp : public framework::OperatorBase {
 public:
  RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {
    if (!rpc_service_) {
      std::string endpoint = Attr<std::string>("endpoint");
      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
      server_thread_.reset(new std::thread(RunServer, rpc_service_));
    }
  }

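  // Unblock the receive loop with a terminate message, shut the RPC service
  // down, and join the server thread.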
  void Stop() override {
    detail::MessageWithName term_msg;
    term_msg.first = LISTEN_TERMINATE_MESSAGE;
    rpc_service_->Push(term_msg);
    rpc_service_->ShutDown();
    server_thread_->join();
  }

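  // Every trainer sends its gradient under the same name; suffix it with a
  // per-variable counter ("<name>.trainer_<i>") so copies from different
  // trainers don't overwrite each other.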
  std::string GetGradVarNameForTrainer(const std::string &varname) const {
    if (grads_counter_.find(varname) == grads_counter_.end()) {
      grads_counter_[varname] = 0;
    }
    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
  }

  void Run(const framework::Scope &scope,
           const platform::Place &dev_place) const override {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    framework::Scope &recv_scope = scope.NewScope();

    // FIXME(Yancey1989): initialize rpc server with lazy mode.
    rpc_service_->SetScope(&recv_scope);
    rpc_service_->SetDevCtx(&dev_ctx);
    auto param_list = Attr<std::vector<std::string>>("ParamList");
    auto grad_list = Attr<std::vector<std::string>>("GradList");
    auto fan_in = Attr<int>("Fanin");

    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
    auto *program = block->Program();
    framework::Executor executor(dev_place);

    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
    bool exit_flag = false;
    while (!exit_flag) {
      // Get from multiple trainers; we don't care about the order in which
      // the gradients arrive, just add suffix 0~n and merge the gradients.
      rpc_service_->SetCond(0);
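      // Barrier protocol: each trainer sends its gradients followed by one
      // BATCH_BARRIER_MESSAGE; keep receiving until all fan_in barriers have
      // arrived.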
      size_t recv_var_cnt = 0;
      int batch_barrier = 0;
      while (batch_barrier != fan_in) {
        const detail::MessageWithName &v = rpc_service_->Get();
        auto grad_var_name = v.first;
        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
        } else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
          VLOG(3) << "recv batch barrier message";
          batch_barrier++;
          continue;
        } else {
          // receive a variable
          recv_var_cnt++;
          auto it =
              std::find(grad_list.begin(), grad_list.end(), grad_var_name);
          std::string param_var_name;
          if (it != grad_list.end()) {
            param_var_name = param_list[it - grad_list.begin()];
          } else {
            LOG(ERROR) << "grad has no paired param: " << grad_var_name;
          }
          VLOG(3) << "received grad: " << grad_var_name
                  << " updating param: " << param_var_name;

          if (fan_in > 1) {
            grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
          }
          auto *var = recv_scope.FindVar(grad_var_name);
          if (var == nullptr) {
            LOG(ERROR) << "Cannot find server-side var: " << grad_var_name;
            PADDLE_THROW("Cannot find server-side var");
          }
          detail::DeserializeFromMessage(v.second, dev_ctx, var);
        }
      }
      VLOG(3) << "recv " << recv_var_cnt << " parameters for one barrier.";
      // TODO(Yancey1989): merge SelectedRows variables here
      if (exit_flag) {
        break;
      }

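      // All gradients for this barrier are now in recv_scope; run the
      // optimize block once to apply the updates.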
      try {
        executor.Run(*program, &recv_scope, block->ID(),
                     false /*create_local_scope*/, false /*create_vars*/);
      } catch (std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
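      // Switch the service to the get phase so clients can fetch the updated
      // parameters, and wait until all received vars have been fetched.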
      rpc_service_->SetCond(1);
      rpc_service_->WaitClientGet(recv_var_cnt);
      grads_counter_.clear();
    }  // while(!exit_flag)
  }

 protected:
  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
  std::shared_ptr<std::thread> server_thread_;
  mutable std::unordered_map<std::string, int> grads_counter_;
};

class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddComment(R"DOC(
Recv operator

This operator will receive tensors from send_op.
)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164) "
                         "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<framework::BlockDesc *>(
        kOptimizeBlock,
        "The optimize block that recv executes on the server side.");
    AddAttr<std::vector<std::string>>(
        "ParamList",
        "(list of string) grad->param name mapping to find which parameters "
        "to optimize.")
        .SetDefault({});
    AddAttr<std::vector<std::string>>(
        "GradList",
        "(list of string) grad->param name mapping to find which parameters "
        "to optimize.")
        .SetDefault({});
    AddAttr<int>("Fanin",
                 "(int) Number of trainers in the current cluster job.")
        .SetDefault(1);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

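// Register recv with its proto maker; no gradient op is registered because
// recv only runs on the parameter-server side.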
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);