未验证 提交 449a207d 编写于 作者: J Jason 提交者: GitHub

Merge pull request #835 from wjj19950828/rm_scaleforadd

Remove scale for add
......@@ -72,9 +72,14 @@ def prim_add_(layer,
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {} * {}".format(layer.outputs[0],
if abs(layer.attrs["alpha"] - 1.) < 1e-6:
line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
layer.attrs["alpha"],
get_value(layer, "y", different_attrs))
else:
line = "{} = {} + {} * {}".format(
layer.outputs[0],
get_value(layer, "x", different_attrs), layer.attrs["alpha"],
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
......
......@@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape)
@add_tensor_function
def reshape(self, *shape):
# deal with list or tuple type
if isinstance(shape, (list, tuple)):
shape = shape[0]
return pd_reshape(self, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册