Commit 9b5be6ef authored by typhoonzero

fix short connection again

Parent d4dabe3e
@@ -58,13 +58,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<ReaderHolder>();
   } else if (var_type == proto::VarType::CHANNEL) {
     var->GetMutable<ChannelHolder>();
-  } else if (var_type == proto::VarType::NCCL_COM) {
-    // GetMutable will be called in ncclInit
+  } else if (var_type == proto::VarType::RAW) {
+    // GetMutable will be called in operator
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
         var_type);
   }
 }
......
@@ -113,7 +113,10 @@ message VarType {
     PLACE_LIST = 14;
     READER = 15;
     CHANNEL = 16;
-    NCCL_COM = 17;
+    // Any runtime decided variable type is raw
+    // raw variables should manage their own allocations
+    // in operators like nccl_op
+    RAW = 17;
   }
 
   required Type type = 1;
......
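For context on the new RAW type above: the enum comment says RAW variables manage their own allocations inside operators such as nccl_op, so the framework allocates nothing when the scope is built and the owning operator calls GetMutable on first use. Below is a minimal, self-contained sketch of that lazy-allocation idea; it is illustrative code only, not Paddle source, and the Variable/Holder/NCCLCommState names are hypothetical.

// Illustrative only: a RAW-style variable holds nothing until the
// operator that owns it asks for a concrete type.
#include <cassert>
#include <memory>

struct Variable {
  template <typename T>
  T* GetMutable() {
    if (holder_ == nullptr) {
      holder_ = std::make_shared<Holder<T>>();  // allocate lazily, on first use
    }
    return &static_cast<Holder<T>*>(holder_.get())->obj;
  }
  bool IsEmpty() const { return holder_ == nullptr; }

 private:
  struct PlaceholderBase {
    virtual ~PlaceholderBase() = default;
  };
  template <typename T>
  struct Holder : PlaceholderBase {
    T obj;
  };
  std::shared_ptr<PlaceholderBase> holder_;
};

struct NCCLCommState {  // hypothetical stand-in for a communicator object
  int nranks = 0;
};

int main() {
  Variable var;           // the scope creates a RAW variable like this:
  assert(var.IsEmpty());  // nothing is allocated up front
  // the operator (e.g. an nccl init op) materializes its own state:
  auto* comm = var.GetMutable<NCCLCommState>();
  comm->nranks = 4;
  return 0;
}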
@@ -177,8 +177,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
   args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
 
-  auto ch = std::shared_ptr<grpc::Channel>(
-      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
+  auto ch =
+      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
 
   channels_[ep] = ch;
   return ch;
......
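For reference, grpc::CreateCustomChannel already returns a std::shared_ptr<grpc::Channel>, which is why the extra shared_ptr wrapper above can be dropped; the surrounding code then caches the channel per endpoint so repeated sends reuse one long-lived connection. The following is a standalone sketch of that caching pattern, with a hypothetical ChannelCache class rather than Paddle's actual RPCClient.

#include <grpcpp/grpcpp.h>

#include <limits>
#include <map>
#include <memory>
#include <string>

// Hypothetical illustration of a per-endpoint channel cache.
class ChannelCache {
 public:
  std::shared_ptr<grpc::Channel> Get(const std::string& ep) {
    auto it = channels_.find(ep);
    if (it != channels_.end()) {
      return it->second;  // reuse the existing long-lived channel
    }
    grpc::ChannelArguments args;
    args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
    args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
    // CreateCustomChannel already returns std::shared_ptr<grpc::Channel>,
    // so no extra wrapping is needed.
    auto ch =
        grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
    channels_[ep] = ch;
    return ch;
  }

 private:
  std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
};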
@@ -65,7 +65,7 @@ class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     auto out_var_name = op_desc.Output("Communicator").front();
     auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::NCCL_COM;
+    auto var_type = framework::proto::VarType::RAW;
     out_var.SetType(var_type);
   }
 };
......
@@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server.
 }
 };
 
+class SendOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output("RPCClient").front();
+    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::RAW;
+    out_var.SetType(var_type);
+  }
+};
+
+class SendOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {}
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
+REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
+                  ops::SendOpMaker, ops::SendOpVarTypeInference,
+                  ops::SendOpShapeInference);
@@ -252,7 +252,7 @@ void BindVarDsec(py::module &m) {
       .value("CHANNEL", proto::VarType::CHANNEL)
      .value("PLACE_LIST", proto::VarType::PLACE_LIST)
       .value("READER", proto::VarType::READER)
-      .value("NCCL_COM", proto::VarType::NCCL_COM);
+      .value("RAW", proto::VarType::RAW);
 }
 
 void BindOpDesc(py::module &m) {
......
@@ -226,8 +226,7 @@ class DistributeTranspiler:
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
             persistable=True,
-            dtype='float32',  # dtype and shape is not used in fact
-            shape=[0])
+            type=core.VarDesc.VarType.RAW)
 
         # create send_op
         program.global_block().append_op(
......