提交 0d4d9c4e 编写于 作者: T typhoonzero

fix grpc short connection

上级 9942565f
...@@ -82,7 +82,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -82,7 +82,7 @@ class ListenAndServOp : public framework::OperatorBase {
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
} }
void Run(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
......
...@@ -32,7 +32,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -32,7 +32,7 @@ class RecvOp : public framework::OperatorBase {
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
......
...@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase { ...@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
......
...@@ -121,7 +121,6 @@ def split_dense_variable(var_list, ...@@ -121,7 +121,6 @@ def split_dense_variable(var_list,
block_size += dim1 - remains block_size += dim1 - remains
# update split_count after aligning # update split_count after aligning
split_count = int(math.ceil(var_numel / float(block_size))) split_count = int(math.ceil(var_numel / float(block_size)))
print("###split var ", var.name, var.shape, block_size, split_count)
for block_id in xrange(split_count): for block_id in xrange(split_count):
curr_block_size = min(block_size, var_numel - ( curr_block_size = min(block_size, var_numel - (
(block_id) * block_size)) (block_id) * block_size))
...@@ -207,7 +206,7 @@ class DistributeTranspiler: ...@@ -207,7 +206,7 @@ class DistributeTranspiler:
rpc_client_var = program.global_block().create_var( rpc_client_var = program.global_block().create_var(
name="RPC_CLIENT_VAR", name="RPC_CLIENT_VAR",
psersistable=True, persistable=True,
dtype='float32', # dtype and shape is not used in fact dtype='float32', # dtype and shape is not used in fact
shape=[0]) shape=[0])
...@@ -256,15 +255,13 @@ class DistributeTranspiler: ...@@ -256,15 +255,13 @@ class DistributeTranspiler:
splited_shape = [rows] splited_shape = [rows]
if len(orig_shape) >= 2: if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:]) splited_shape.extend(orig_shape[1:])
print("###splited: ", size, rows, splited_shape)
var = program.global_block().create_var( var = program.global_block().create_var(
name="%s.block%d" % (varname, i), name="%s.block%d" % (varname, i),
psersistable=False, persistable=False,
dtype=orig_var.dtype, dtype=orig_var.dtype,
type=orig_var.type, type=orig_var.type,
shape=splited_shape) # flattend splited var shape=splited_shape) # flattend splited var
var_mapping[varname].append(var) var_mapping[varname].append(var)
print("###created split var ", var)
return var_mapping return var_mapping
def _clone_var(self, block, var): def _clone_var(self, block, var):
...@@ -322,7 +319,7 @@ class DistributeTranspiler: ...@@ -322,7 +319,7 @@ class DistributeTranspiler:
for i in xrange(trainers): for i in xrange(trainers):
var_each = block.create_var( var_each = block.create_var(
name="%s.trainer_%d" % (var.name, i), name="%s.trainer_%d" % (var.name, i),
psersistable=var.persistable, persistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
type=var.type, type=var.type,
shape=var.shape) shape=var.shape)
...@@ -531,8 +528,6 @@ class DistributeTranspiler: ...@@ -531,8 +528,6 @@ class DistributeTranspiler:
""" """
# step5 # step5
pserver_program = Program() pserver_program = Program()
print("param mapping on pserver: #### ",
self.param_grad_ep_mapping[endpoint]["params"])
for v in self.param_grad_ep_mapping[endpoint]["params"]: for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v) self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]: for v in self.param_grad_ep_mapping[endpoint]["grads"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册