提交 aa84b21e 编写于 作者: Y Yancey1989

fix unit tests

上级 d723022e
......@@ -262,7 +262,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() > 1,
PADDLE_ENFORCE(optimize_blocks.size() >= 1,
"optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
......
......@@ -558,19 +558,20 @@ class Operator(object):
if (attr_name not in self.attrs) or (
self.attrs[attr_name] is None):
continue
if isinstance(self.attrs[attr_name], Block):
attr_val = self.attrs[attr_name]
if isinstance(attr_val, Block):
self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc)
elif isinstance(self.attrs[attr_name], list) and \
all(isinstance(v, Block) for v in self.attrs[attr_name]):
self.desc.set_blocks_attr(
attr_name, [v.desc for v in self.attrs[attr_name]])
elif isinstance(self.attrs[attr_name], core.BlockDesc) or \
isinstance(self.attrs[attr_name], core.ProgramDesc):
elif isinstance(attr_val, list) and attr_val and \
all(isinstance(v, Block) for v in attr_val):
self.desc.set_blocks_attr(attr_name,
[v.desc for v in attr_val])
elif isinstance(attr_val, core.BlockDesc) or \
isinstance(attr_val, core.ProgramDesc):
self.desc.set_serialized_attr(
attr_name, self.attrs[attr_name].serialize_to_string())
attr_name, attr_val.serialize_to_string())
else:
self.desc.set_attr(attr_name, self.attrs[attr_name])
self.desc.set_attr(attr_name, attr_val)
self.desc.check_attrs()
if self.has_kernel(type):
self.desc.infer_var_type(self.block.desc)
......@@ -719,7 +720,8 @@ class Operator(object):
self.attrs[name] = val
if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and all(isinstance(v, Block) for v in val):
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
self.desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
......
......@@ -186,7 +186,6 @@ class ListenAndServ(object):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
empty_block = Program().global_block()
parent_block.append_op(
type='listen_and_serv',
......@@ -195,8 +194,9 @@ class ListenAndServ(object):
attrs={
'endpoint': self.endpoint,
'Fanin': self.fan_in,
'OptimizeBlock': current_block,
'PrefetchBlock': empty_block,
'optimize_blocks': [
current_block
], # did not support multiple optimize blocks in layers
'sync_mode': True, # did not support async now in layers
'grad_to_block_id': [""]
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册