From 9b5be6ef43c31ccb6f9c306c0c44a8f19a72a24f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Sat, 24 Feb 2018 15:13:18 +0800 Subject: [PATCH] fix short connection again --- paddle/fluid/framework/executor.cc | 6 +++--- paddle/fluid/framework/framework.proto | 5 ++++- paddle/fluid/operators/detail/grpc_client.cc | 4 ++-- paddle/fluid/operators/nccl_op.cc | 2 +- paddle/fluid/operators/send_op.cc | 20 ++++++++++++++++++- paddle/fluid/pybind/protobuf.cc | 2 +- .../paddle/v2/fluid/distribute_transpiler.py | 3 +-- 7 files changed, 31 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0d2691e8115..88863ab99eb 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) { var->GetMutable(); } else if (var_type == proto::VarType::CHANNEL) { var->GetMutable(); - } else if (var_type == proto::VarType::NCCL_COM) { - // GetMutable will be called in ncclInit + } else if (var_type == proto::VarType::RAW) { + // GetMutable will be called in operator } else { PADDLE_THROW( "Variable type %d is not in " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]", + "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", var_type); } } diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 5b43f5a8a4a..23064541a05 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -113,7 +113,10 @@ message VarType { PLACE_LIST = 14; READER = 15; CHANNEL = 16; - NCCL_COM = 17; + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17; } required Type type = 1; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 
ee9044b1f5d..7266f327647 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -177,8 +177,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { args.SetMaxSendMessageSize(std::numeric_limits<int>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); - auto ch = std::shared_ptr<grpc::Channel>( - grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args)); + auto ch = + grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); channels_[ep] = ch; return ch; diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 0994bba782b..9185666c56c 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference { framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Communicator").front(); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::NCCL_COM; + auto var_type = framework::proto::VarType::RAW; out_var.SetType(var_type); } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 58850bf566e..178976f96fd 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server. 
} }; +class SendOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output("RPCClient").front(); + auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SendOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker); +REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, + ops::SendOpMaker, ops::SendOpVarTypeInference, + ops::SendOpShapeInference); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index b725be79529..b0a2497d919 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) { .value("CHANNEL", proto::VarType::CHANNEL) .value("PLACE_LIST", proto::VarType::PLACE_LIST) .value("READER", proto::VarType::READER) - .value("NCCL_COM", proto::VarType::NCCL_COM); + .value("RAW", proto::VarType::RAW); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 2fcf3753c5f..8da9ca290b2 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -226,8 +226,7 @@ class DistributeTranspiler: rpc_client_var = program.global_block().create_var( name="RPC_CLIENT_VAR", persistable=True, - dtype='float32', # dtype and shape is not used in fact - shape=[0]) + type=core.VarDesc.VarType.RAW) # create send_op program.global_block().append_op( -- GitLab