提交 aa84b21e 编写于 作者: Y Yancey1989

fix unit tests

上级 d723022e
...@@ -262,7 +262,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -262,7 +262,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() > 1, PADDLE_ENFORCE(optimize_blocks.size() >= 1,
"optimize blocks should be 1 at least on the pserver side."); "optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_blocks[0]->Program(); auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
......
...@@ -558,19 +558,20 @@ class Operator(object): ...@@ -558,19 +558,20 @@ class Operator(object):
if (attr_name not in self.attrs) or ( if (attr_name not in self.attrs) or (
self.attrs[attr_name] is None): self.attrs[attr_name] is None):
continue continue
if isinstance(self.attrs[attr_name], Block): attr_val = self.attrs[attr_name]
if isinstance(attr_val, Block):
self.desc.set_block_attr(attr_name, self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc) self.attrs[attr_name].desc)
elif isinstance(self.attrs[attr_name], list) and \ elif isinstance(attr_val, list) and attr_val and \
all(isinstance(v, Block) for v in self.attrs[attr_name]): all(isinstance(v, Block) for v in attr_val):
self.desc.set_blocks_attr( self.desc.set_blocks_attr(attr_name,
attr_name, [v.desc for v in self.attrs[attr_name]]) [v.desc for v in attr_val])
elif isinstance(self.attrs[attr_name], core.BlockDesc) or \ elif isinstance(attr_val, core.BlockDesc) or \
isinstance(self.attrs[attr_name], core.ProgramDesc): isinstance(attr_val, core.ProgramDesc):
self.desc.set_serialized_attr( self.desc.set_serialized_attr(
attr_name, self.attrs[attr_name].serialize_to_string()) attr_name, attr_val.serialize_to_string())
else: else:
self.desc.set_attr(attr_name, self.attrs[attr_name]) self.desc.set_attr(attr_name, attr_val)
self.desc.check_attrs() self.desc.check_attrs()
if self.has_kernel(type): if self.has_kernel(type):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
...@@ -719,7 +720,8 @@ class Operator(object): ...@@ -719,7 +720,8 @@ class Operator(object):
self.attrs[name] = val self.attrs[name] = val
if isinstance(val, Block): if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc) self.desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and all(isinstance(v, Block) for v in val): elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
self.desc.set_blocks_attr(name, [v.desc for v in val]) self.desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \ elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc): isinstance(val, core.ProgramDesc):
......
...@@ -186,7 +186,6 @@ class ListenAndServ(object): ...@@ -186,7 +186,6 @@ class ListenAndServ(object):
main_program = self.helper.main_program main_program = self.helper.main_program
current_block = main_program.current_block() current_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self.parent_block()
empty_block = Program().global_block()
parent_block.append_op( parent_block.append_op(
type='listen_and_serv', type='listen_and_serv',
...@@ -195,8 +194,9 @@ class ListenAndServ(object): ...@@ -195,8 +194,9 @@ class ListenAndServ(object):
attrs={ attrs={
'endpoint': self.endpoint, 'endpoint': self.endpoint,
'Fanin': self.fan_in, 'Fanin': self.fan_in,
'OptimizeBlock': current_block, 'optimize_blocks': [
'PrefetchBlock': empty_block, current_block
], # did not support multiple optimize blocks in layers
'sync_mode': True, # did not support async now in layers 'sync_mode': True, # did not support async now in layers
'grad_to_block_id': [""] 'grad_to_block_id': [""]
}) })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册