From f0af1398b8216428255b7981a4fe0b490d2c03e6 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 30 Mar 2018 11:30:05 +0800 Subject: [PATCH] add prefetch_op (#9495) * add prefetch_op * fix ci * optimize code * optimize code * fix include --- paddle/fluid/operators/CMakeLists.txt | 6 +- paddle/fluid/operators/detail/grpc_client.cc | 50 +++++++- paddle/fluid/operators/detail/grpc_client.h | 7 ++ paddle/fluid/operators/prefetch_op.cc | 115 +++++++++++++++++++ paddle/fluid/operators/send_op.cc | 20 +--- paddle/fluid/operators/send_recv_util.h | 36 ++++++ paddle/fluid/operators/send_vars_op.cc | 23 +--- 7 files changed, 213 insertions(+), 44 deletions(-) create mode 100644 paddle/fluid/operators/prefetch_op.cc create mode 100644 paddle/fluid/operators/send_recv_util.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8341170d68..9ed79453b9 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -183,6 +183,8 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(send_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) + set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) @@ -191,9 +193,9 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op listen_and_serv_op sum_op executor) + cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op recv_op listen_and_serv_op send_vars_op send_barrier_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) endif() op_library(cond_op DEPS framework_proto tensor net_op) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 03b789f326..9652bb888b 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -88,10 +88,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const auto ch = GetChannel(ep_val); framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - // varhandle + // var handle VarHandle var_h; var_h.ep = ep_val; var_h.scope = p_scope; @@ -103,9 +106,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); - auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); call->StartCall(); @@ -117,6 +117,48 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } +bool RPCClient::AsyncPrefetchVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { + const platform::DeviceContext* p_ctx = &ctx; + const std::string ep_val = ep; + const std::string in_var_name_val = in_var_name; + const std::string out_var_name_val = out_var_name; + const framework::Scope* p_scope = &scope; + const auto ch = GetChannel(ep_val); + + framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { + auto* var = p_scope->FindVar(in_var_name_val); + + ::grpc::ByteBuffer req; + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req); + + // var handle + VarHandle var_h; + var_h.ep = ep_val; + var_h.scope = p_scope; + var_h.name = out_var_name_val; + var_h.ctx = p_ctx; + + // stub context + GetProcessor* s = new GetProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = ProcGetResponse; + + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/GetVariable", req, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, (void*)s); + }); + + req_count_++; + return true; +} + void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 8216ac52fb..fe237e54ef 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -172,6 +172,13 @@ class RPCClient { const std::string& var_name, int64_t time_out = 600 * 1000); + bool AsyncPrefetchVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = 600 * 1000); + void AsyncSendBatchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc new file mode 100644 index 0000000000..09ab7da663 --- /dev/null +++ b/paddle/fluid/operators/prefetch_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/send_recv_util.h" + +namespace paddle { +namespace operators { + +class PrefetchOp : public framework::OperatorBase { + public: + PrefetchOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + auto ins = Inputs("X"); + auto outs = Outputs("Out"); + + std::vector epmap = Attr>("epmap"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + auto client_var_name = Output("RPCClient"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), + "Can not find variable '%s' in the scope.", + client_var_name); + auto* client_var = scope.FindVar(client_var_name); + detail::RPCClient* rpc_client = client_var->GetMutable(); + + for (size_t i = 0; i < ins.size(); i++) { + if (NeedSend(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << "to get " + << outs[i] << "back"; + rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i], + outs[i]); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } + } + PADDLE_ENFORCE(rpc_client->Wait()); + } +}; + +class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PrefetchOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); + AddOutput("RPCClient", + "(RPCClient) The RPC client object which will be" + "initialized at most once."); + AddOutput("Out", + "(SelectedRows) result " + "to be fetched from parameter server") + .AsDuplicable(); + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({"127.0.0.1:6164"}); + AddComment(R"DOC( +Prefetch operator + +This operator will send Ids variables to listen_and_serve op at +the parameter server and fetch result back. +)DOC"); + } +}; + +class PrefetchOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output("RPCClient").front(); + auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class PrefetchOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(prefetch, ops::PrefetchOp, + paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, + ops::PrefetchOpVarTypeInference, + ops::PrefetchOpShapeInference); diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 0752bd1bbd..d47f66de21 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -12,35 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" - -#include #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -static bool NeedSend(const framework::Scope& scope, - const std::string& varname) { - auto* var = scope.FindVar(varname); - PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", - varname); - if (var->IsType()) { - return var->Get().IsInitialized(); - } else if (var->IsType()) { - return var->Get().rows().size() > 0UL; - } else { - PADDLE_THROW( - "Variable type in send side should be in " - "[LodTensor, SelectedRows]"); - } - return false; -} class SendOp : public framework::OperatorBase { public: diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h new file mode 100644 index 0000000000..196f56f634 --- /dev/null +++ b/paddle/fluid/operators/send_recv_util.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +namespace paddle { +namespace operators { + +inline bool NeedSend(const framework::Scope& scope, + const std::string& varname) { + auto* var = scope.FindVar(varname); + PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", + varname); + if (var->IsType()) { + return var->Get().IsInitialized(); + } else if (var->IsType()) { + return var->Get().rows().size() > 0UL; + } else { + PADDLE_THROW( + "Variable type in send side should be in " + "[LodTensor, SelectedRows]"); + } + return false; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index 523e9e2780..2cbd9e2394 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -12,34 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" - -#include #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/send_recv_util.h" namespace paddle { namespace operators { -static bool NeedSend(const framework::Scope& scope, - const std::string& varname) { - auto* var = scope.FindVar(varname); - PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", - varname); - if (var->IsType()) { - return var->Get().IsInitialized(); - } else if (var->IsType()) { - return var->Get().rows().size() > 0UL; - } else { - PADDLE_THROW( - "Variable type in send side should be in " - "[LodTensor, SelectedRows]"); - } - return false; -} class SendVarsOp : public framework::OperatorBase { public: @@ -95,7 +78,7 @@ Send operator This operator will send variables to listen_and_serve op at the parameter server. )DOC"); - AddAttr("ync_send", + AddAttr("sync_send", "(int, default 0)" "sync send or async send.") .SetDefault(0); -- GitLab