diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0d2691e8115ad6de46dcd4fcd5b7fd79ed60ecb9..88863ab99eb765124bc825b4e9ec9dff890ba3cc 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) { var->GetMutable(); } else if (var_type == proto::VarType::CHANNEL) { var->GetMutable(); - } else if (var_type == proto::VarType::NCCL_COM) { - // GetMutable will be called in ncclInit + } else if (var_type == proto::VarType::RAW) { + // GetMutable will be called in operator } else { PADDLE_THROW( "Variable type %d is not in " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]", + "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", var_type); } } diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 5b43f5a8a4a1c128b04ac206d387e30c55f533fe..53725d3d802c27202a6379cee518991a628cf9a1 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -113,7 +113,10 @@ message VarType { PLACE_LIST = 14; READER = 15; CHANNEL = 16; - NCCL_COM = 17; + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17; } required Type type = 1; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ee9044b1f5d46dc725c9583d0d90ab5681d2850c..7266f3276477891d3c7b6827316a428ef7a31c6e 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -177,8 +177,8 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - auto ch = std::shared_ptr( - grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args)); + auto ch = + grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); channels_[ep] = ch; return ch; diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 0994bba782b42be994ae479f4c9c4de5a2e384ed..9185666c56c4621d42429c9cfdb079001c6336f1 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference { framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Communicator").front(); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::NCCL_COM; + auto var_type = framework::proto::VarType::RAW; out_var.SetType(var_type); } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 58850bf566e00f88de19305110e2ef696b73467e..178976f96fdbd08cead7b7c518ea1fbaaa2a5db8 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server. } }; +class SendOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto out_var_name = op_desc.Output("RPCClient").front(); + auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SendOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker); +REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, + ops::SendOpMaker, ops::SendOpVarTypeInference, + ops::SendOpShapeInference); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index b725be79529c5ccdde12446b5b5c09eaf47550e6..b0a2497d919b65afbe5eeaf4fe47c19baa1aba1c 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) { .value("CHANNEL", proto::VarType::CHANNEL) .value("PLACE_LIST", proto::VarType::PLACE_LIST) .value("READER", proto::VarType::READER) - .value("NCCL_COM", proto::VarType::NCCL_COM); + .value("RAW", proto::VarType::RAW); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 2fcf3753c5f1211d3b27f38fbdc8d097c437c79a..8da9ca290b22ae69b1fd195d8614c31dc4e13e00 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -226,8 +226,7 @@ class DistributeTranspiler: rpc_client_var = program.global_block().create_var( name="RPC_CLIENT_VAR", persistable=True, - dtype='float32', # dtype and shape is not used in fact - shape=[0]) + type=core.VarDesc.VarType.RAW) # create send_op program.global_block().append_op(