提交 1b8bc3c5 编写于 作者: T typhoonzero

rename rpc ops

上级 686024c0
...@@ -12,179 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,179 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream> #include <ostream>
#include <thread>
#include <unistd.h> #include "paddle/framework/data_type.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #include <future>
#include "paddle/operators/detail/grpc_client.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; class SendOp : public framework::OperatorBase {
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
}
static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
class RecvOp : public framework::OperatorBase {
public: public:
RecvOp(const std::string &type, const framework::VariableNameMap &inputs, SendOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) { : OperatorBase(type, inputs, outputs, attrs) {}
if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint"); void Run(const framework::Scope& scope,
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); const platform::Place& place) const override {
server_thread_.reset(new std::thread(RunServer, rpc_service_)); auto outs = Outputs("Out");
} std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
void Stop() override { auto& ctx = *pool.Get(place);
detail::MessageWithName term_msg;
term_msg.first = LISTEN_TERMINATE_MESSAGE; for (size_t i = 0; i < outs.size(); i++) {
rpc_service_->Push(term_msg); VLOG(3) << "getting " << outs[i];
rpc_service_->ShutDown(); client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
server_thread_->join();
}
std::string GetGradVarNameForTrainer(const std::string &varname) const {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
} }
void Run(const framework::Scope &scope, PADDLE_ENFORCE(client_.Wait());
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
framework::Executor executor(dev_place);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
size_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
if (exit_flag) {
break;
}
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size);
grads_counter_.clear();
} // while(true)
} }
protected: private:
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_; mutable detail::RPCClient client_;
std::shared_ptr<std::thread> server_thread_;
mutable std::unordered_map<std::string, int> grads_counter_;
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker) SendOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("RX", "(Tensor) Input tensor to be optimized").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Recv operator Send operator
This operator will recieve tensor from send_op This operator will send tensor to recv_op at the parameter server.
)DOC"); )DOC");
AddAttr<std::string>("endpoint", AddAttr<std::vector<std::string>>("endpoints",
"(string, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"IP address to listen on.") "Server endpoints to send variables to.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDesc *>(
kOptimizeBlock, "Serialized ProgramDesc string for recv to run.");
AddAttr<std::vector<std::string>>(
"ParamList", "type list of string",
"grad->param name mapping to find which parameters to optimize.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>("epmap",
"GradList", "type list of string", "(string vector, default 127.0.0.1:6164)"
"grad->param name mapping to find which parameters to optimize.") "Server endpoints in the order of input "
"variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
} }
}; };
...@@ -193,4 +81,4 @@ This operator will recieve tensor from send_op ...@@ -193,4 +81,4 @@ This operator will recieve tensor from send_op
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker); REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
...@@ -37,6 +37,7 @@ class SendOp : public framework::OperatorBase { ...@@ -37,6 +37,7 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
bool do_get = Attr<bool>("DoGet");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -46,10 +47,12 @@ class SendOp : public framework::OperatorBase { ...@@ -46,10 +47,12 @@ class SendOp : public framework::OperatorBase {
} }
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(client_.Wait());
if (do_get) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i]; VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
}
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(client_.Wait());
} }
...@@ -79,6 +82,10 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -79,6 +82,10 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("DoGet",
"(bool, default true)"
"Whether do GetVariable call after send")
.SetDefault(true);
} }
}; };
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
USE_NO_KERNEL_OP(send); USE_NO_KERNEL_OP(send);
USE_NO_KERNEL_OP(recv); USE_NO_KERNEL_OP(listen_and_serv);
USE_OP(sum); USE_OP(sum);
namespace f = paddle::framework; namespace f = paddle::framework;
...@@ -33,7 +33,7 @@ namespace p = paddle::platform; ...@@ -33,7 +33,7 @@ namespace p = paddle::platform;
namespace m = paddle::operators::math; namespace m = paddle::operators::math;
// global for simplicity. // global for simplicity.
std::unique_ptr<f::OperatorBase> recv_op; std::unique_ptr<f::OperatorBase> listen_and_serv_op;
void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
...@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) { ...@@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) {
InitTensorsInScope(scope, place); InitTensorsInScope(scope, place);
} }
// sub program run in recv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
// X for server side tensors, RX for received tensers, must be of same shape. // X for server side tensors, RX for received tensers, must be of same shape.
...@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) { ...@@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", block}); attrs.insert({"OptimizeBlock", block});
recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs); listen_and_serv_op =
recv_op->Run(scope, place); f::OpRegistry::CreateOp("listen_and_serv", {{"RX", {"x1"}}}, {}, attrs);
listen_and_serv_op->Run(scope, place);
} }
TEST(SendRecvOp, CPUDense) { TEST(SendRecvOp, CPUDense) {
...@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) { ...@@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) {
for (int64_t i = 0; i < target->numel(); ++i) { for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]); EXPECT_EQ(expected[i] * 2, actual[i]);
} }
recv_op->Stop(); listen_and_serv_op->Stop();
server_thread.join(); server_thread.join();
recv_op.reset(nullptr); listen_and_serv_op.reset(nullptr);
} }
TEST(SendRecvOp, CPUSparse) { TEST(SendRecvOp, CPUSparse) {
...@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) {
EXPECT_EQ(expect_value->mutable_data<float>(place)[i], EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
actual->mutable_data<float>(place)[i]); actual->mutable_data<float>(place)[i]);
} }
recv_op->Stop(); listen_and_serv_op->Stop();
server_thread.join(); server_thread.join();
recv_op.reset(); listen_and_serv_op.reset();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册