未验证 提交 c0fc50d5 编写于 作者: 武毅 提交者: GitHub

Merge pull request #8409 from typhoonzero/fix_grpc_short_conn

Fix grpc short connection
......@@ -82,8 +82,8 @@ class ListenAndServOp : public framework::OperatorBase {
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
......
......@@ -32,8 +32,8 @@ class RecvOp : public framework::OperatorBase {
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
......
......@@ -48,8 +48,8 @@ class SendOp : public framework::OperatorBase {
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
......
......@@ -121,7 +121,6 @@ def split_dense_variable(var_list,
block_size += dim1 - remains
# update split_count after aligning
split_count = int(math.ceil(var_numel / float(block_size)))
print("###split var ", var.name, var.shape, block_size, split_count)
for block_id in xrange(split_count):
curr_block_size = min(block_size, var_numel - (
(block_id) * block_size))
......@@ -207,7 +206,7 @@ class DistributeTranspiler:
rpc_client_var = program.global_block().create_var(
name="RPC_CLIENT_VAR",
psersistable=True,
persistable=True,
dtype='float32', # dtype and shape is not used in fact
shape=[0])
......@@ -256,15 +255,13 @@ class DistributeTranspiler:
splited_shape = [rows]
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
print("###splited: ", size, rows, splited_shape)
var = program.global_block().create_var(
name="%s.block%d" % (varname, i),
psersistable=False,
persistable=False,
dtype=orig_var.dtype,
type=orig_var.type,
shape=splited_shape) # flattend splited var
var_mapping[varname].append(var)
print("###created split var ", var)
return var_mapping
def _clone_var(self, block, var):
......@@ -322,7 +319,7 @@ class DistributeTranspiler:
for i in xrange(trainers):
var_each = block.create_var(
name="%s.trainer_%d" % (var.name, i),
psersistable=var.persistable,
persistable=var.persistable,
dtype=var.dtype,
type=var.type,
shape=var.shape)
......@@ -531,8 +528,6 @@ class DistributeTranspiler:
"""
# step5
pserver_program = Program()
print("param mapping on pserver: #### ",
self.param_grad_ep_mapping[endpoint]["params"])
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册