/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle
namespace paddle {
namespace operators {

namespace distributed {
class RPCClient;
}  // namespace distributed

class RecvOp : public framework::OperatorBase {
 public:
41 42 43
  RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
T
typhoonzero 已提交
44 45
      : OperatorBase(type, inputs, outputs, attrs) {}

46 47
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
T
typhoonzero 已提交
48
    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
49 50
    std::vector<std::string> varnames =
        Attr<std::vector<std::string>>("varnames");
51

52 53
    auto outs = Outputs("Out");
    bool with_barrier = Attr<bool>("with_barrier");
Y
Yancey1989 已提交
54

55 56
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &ctx = *pool.Get(place);
Q
Qiao Longfei 已提交
57
    auto trainer_id = Attr<int>("trainer_id");
Y
Yancey1989 已提交
58

59
    distributed::RPCClient *rpc_client =
Q
Qiao Longfei 已提交
60
        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
T
typhoonzero 已提交
61

Q
Qiao Longfei 已提交
62 63 64 65
    std::vector<std::string> recv_varnames =
        Attr<std::vector<std::string>>("recv_varnames");

    if (recv_varnames.size() > 0) {
66 67
      auto *communicator = distributed::Communicator::GetInstance();

T
tangwei12 已提交
68
      if (communicator != nullptr) {
69
        PADDLE_THROW(platform::errors::InvalidArgument(
T
tangwei12 已提交
70
            "execute startup program must before fleet.init_worker"));
71
      }
Q
Qiao Longfei 已提交
72
    } else {
73
      std::vector<distributed::VarHandlePtr> rets;
Q
Qiao Longfei 已提交
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
      if (with_barrier) {
        for (size_t i = 0; i < outs.size(); i++) {
          std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
          VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
                  << varname << " and with AsyncGetVar";
          rets.push_back(
              rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
        }
      } else {
        for (size_t i = 0; i < outs.size(); i++) {
          std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
          VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
                  << varname << " and with AsyncGetVarNoBarrier";
          rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
                                                          varname, outs[i]));
        }
90 91 92
      }
      for (size_t i = 0; i < rets.size(); i++) {
        VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
T
tangwei12 已提交
93 94 95
        PADDLE_ENFORCE_NE(
            rets[i]->Wait(), 0U,
            platform::errors::ExecutionTimeout("internal error in RPCClient"));
96
        VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
97
      }
Y
Yancey1989 已提交
98
    }
武毅 已提交
99 100 101 102 103
  }
};

// Declares the proto (inputs, outputs, attributes) of the "recv" operator.
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    // "X" exists only to express control dependencies in the graph.
    AddInput("X", "(Any) Dummy inputs, used for control dependency")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
    AddComment(R"DOC(
Recv operator

This operator can get variables from server side.
)DOC");
    // One endpoint per output variable, in output order.
    AddAttr<std::vector<std::string>>("epmap",
                                      "(string vector, default "
                                      "127.0.0.1:6164)Server endpoints in the "
                                      "order of input variables for mapping")
        .SetDefault(std::vector<std::string>{});
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    // Controls whether the barrier-free GET path is taken in RunImpl.
    AddAttr<bool>("with_barrier",
                  "(bool, default True) if with_barrier=False, "
                  "will use AsyncGetVarNoBarrier get variable "
                  "from pserver immediately")
        .SetDefault(true);
    // Optional remote aliases for the outputs (same length as "Out").
    AddAttr<std::vector<std::string>>(
        "varnames",
        "(string vector, default {}) sometimes we need to put "
        "received var in another name for example: we need var "
        "named 'moment_1@127.0.0.1:1001', and it real name on "
        "parameter server is 'moment_1'. ")
        .SetDefault(std::vector<std::string>{});
    // Non-empty value switches RunImpl into Communicator mode.
    AddAttr<std::vector<std::string>>(
        "recv_varnames",
        "(vector<string>) the split parameter "
        "varnames to be recved from pserver")
        .SetDefault({});
    AddAttr<int>("do_not_run", "if recv need to really run").SetDefault(0);
  }
};

// Recv fills its outputs with data fetched from remote servers at runtime,
// so no static shape information can be inferred here.
class RecvOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the "recv" op with empty gradient makers for both static-graph
// (OpDesc) and imperative (OpBase) modes: receiving remote variables has no
// backward pass.
REGISTER_OPERATOR(
    recv, ops::RecvOp,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::RecvOpMaker, ops::RecvOpShapeInference);