Unverified commit fe7c1814, authored by 武毅, committed by GitHub

Merge pull request #8538 from typhoonzero/add_raw_var_type

fix short connection again
@@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<ReaderHolder>();
   } else if (var_type == proto::VarType::CHANNEL) {
     var->GetMutable<ChannelHolder>();
-  } else if (var_type == proto::VarType::NCCL_COM) {
-    // GetMutable will be called in ncclInit
+  } else if (var_type == proto::VarType::RAW) {
+    // GetMutable will be called in operator
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
         var_type);
   }
 }
...
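Note on the hunk above: for a RAW variable the executor deliberately allocates nothing, and the first operator that touches the variable calls GetMutable<T>() to create the payload. A minimal self-contained sketch of that lazy-allocation pattern follows; the Variable, Holder, and NcclCommunicator types here are simplified stand-ins for illustration, not the real Paddle classes.

#include <iostream>
#include <memory>

class Variable {
 public:
  // Creates the payload of type T on first call, mirroring the real
  // Variable::GetMutable that the executor skips for RAW variables.
  template <typename T>
  T* GetMutable() {
    if (holder_ == nullptr) holder_.reset(new Holder<T>());  // lazy allocation
    return static_cast<T*>(holder_->Ptr());
  }

 private:
  struct PlaceholderBase {
    virtual ~PlaceholderBase() = default;
    virtual void* Ptr() = 0;
  };
  template <typename T>
  struct Holder : public PlaceholderBase {
    void* Ptr() override { return &value; }
    T value;
  };
  std::unique_ptr<PlaceholderBase> holder_;
};

struct NcclCommunicator {  // stand-in for an operator-managed payload
  int device_count = 0;
};

int main() {
  Variable var;  // the executor leaves a RAW variable empty
  // Later, inside the consuming operator (e.g. ncclInit), the payload
  // is created on demand:
  var.GetMutable<NcclCommunicator>()->device_count = 4;
  std::cout << var.GetMutable<NcclCommunicator>()->device_count << std::endl;
  return 0;
}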
@@ -113,7 +113,10 @@ message VarType {
     PLACE_LIST = 14;
     READER = 15;
     CHANNEL = 16;
-    NCCL_COM = 17;
+    // Any runtime decided variable type is raw
+    // raw variables should manage their own allocations
+    // in operators like nccl_op
+    RAW = 17;
   }

   required Type type = 1;
...
@@ -177,8 +177,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
   args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
-  auto ch = std::shared_ptr<grpc::Channel>(
-      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
+  auto ch =
+      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
   channels_[ep] = ch;
   return ch;
...
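Context for the "fix short connection again" commit message: GetChannel caches one channel per endpoint in channels_, so repeated send operators reuse a long-lived gRPC connection instead of opening a short-lived one per request. The patched lines only drop a redundant wrapper, since grpc::CreateCustomChannel already returns std::shared_ptr<grpc::Channel>. The cache lookup itself is outside this hunk; a sketch of the whole pattern under those assumptions (ChannelCache is a simplified stand-in for Paddle's RPCClient):

#include <grpc++/grpc++.h>

#include <limits>
#include <map>
#include <memory>
#include <string>

class ChannelCache {
 public:
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep) {
    auto it = channels_.find(ep);
    if (it != channels_.end()) {
      return it->second;  // reuse the long-lived channel for this endpoint
    }

    grpc::ChannelArguments args;
    args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
    args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

    // CreateCustomChannel already returns std::shared_ptr<grpc::Channel>,
    // so no extra shared_ptr construction is needed.
    auto ch =
        grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
    channels_[ep] = ch;
    return ch;
  }

 private:
  std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
};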
@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Communicator").front();
     auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::NCCL_COM;
+    auto var_type = framework::proto::VarType::RAW;
     out_var.SetType(var_type);
   }
 };
...
@@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server.
   }
 };

+class SendOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("RPCClient").front();
+    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class SendOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {}
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
+REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
+                  ops::SendOpMaker, ops::SendOpVarTypeInference,
+                  ops::SendOpShapeInference);
...
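The new registration wires three compile-time hooks onto send: paddle::framework::EmptyGradOpMaker (send defines no gradient op), SendOpVarTypeInference (marks the RPCClient output RAW before execution, so the executor leaves it unallocated), and SendOpShapeInference (a no-op, since a RAW output carries no shape). A self-contained toy sketch of what such a var-type-inference pass does; BlockDesc and VarDesc are reduced to minimal stand-ins, not the real framework classes:

#include <cassert>
#include <map>
#include <string>

enum class VarType { LOD_TENSOR, RAW };

struct VarDesc {
  VarType type = VarType::LOD_TENSOR;  // default, like a fresh variable
};

struct BlockDesc {
  VarDesc& FindRecursiveOrCreateVar(const std::string& name) {
    return vars_[name];  // create-on-miss, like the real BlockDesc helper
  }
  std::map<std::string, VarDesc> vars_;
};

// Mirrors SendOpVarTypeInference: before any kernel runs, mark the
// operator's client-handle output RAW so the executor allocates nothing
// for it and the operator manages the payload itself.
void InferSendVarTypes(BlockDesc* block, const std::string& out_var_name) {
  block->FindRecursiveOrCreateVar(out_var_name).type = VarType::RAW;
}

int main() {
  BlockDesc block;
  InferSendVarTypes(&block, "RPC_CLIENT_VAR");
  assert(block.vars_["RPC_CLIENT_VAR"].type == VarType::RAW);
  return 0;
}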
@@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) {
       .value("CHANNEL", proto::VarType::CHANNEL)
       .value("PLACE_LIST", proto::VarType::PLACE_LIST)
       .value("READER", proto::VarType::READER)
-      .value("NCCL_COM", proto::VarType::NCCL_COM);
+      .value("RAW", proto::VarType::RAW);
 }

 void BindOpDesc(py::module &m) {
...
@@ -226,8 +226,7 @@ class DistributeTranspiler:
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
             persistable=True,
-            dtype='float32',  # dtype and shape is not used in fact
-            shape=[0])
+            type=core.VarDesc.VarType.RAW)

         # create send_op
         program.global_block().append_op(
...