未验证 提交 0afc4b02 编写于 作者: J Jason 提交者: GitHub

Merge pull request #10 from jiangjiajun/master

fix bug for avg_pool
......@@ -594,8 +594,9 @@ class PaddleEmitter(object):
pad_param)
node.code.add_layer("pool2d", node.output_name,
node.output_name, param_attr)
return
node.code.add_layer("pool2d", data.ref_name, node.output_name,
param_attr)
param_attr)
def emit_rsqrt(self, node):
data = node.inputs[0]
......@@ -670,9 +671,22 @@ class PaddleEmitter(object):
for i in range(len(strides)):
assert strides[i] == 1
param_attr = {"axes": range(len(begin)), "starts": begin, "ends": end}
node.code.add_layer("slice", data.ref_name, node.output_name,
param_attr)
if len(set(end)) == 1 and end[0] == 0:
output_shape = list(self.infer.get_tensor_shape(node.layer))
if node.data_format == NHWC and len(output_shape) == 4:
output_shape = [output_shape[0],
output_shape[3],
output_shape[1],
output_shape[2]]
begin = [begin[0], begin[3], begin[1], begin[2]]
param = {"shape":output_shape, "offsets":begin}
node.code.add_layer("crop", data.ref_name,
node.output_name, param)
else:
param = {"axes": range(len(begin)), "starts": begin, "ends": end}
node.code.add_layer("slice", data.ref_name,
node.output_name, param)
def emit_resizenearestneighbor(self, node):
data = node.inputs[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册