Commit ad15b9c1 authored by wjj19950828

rm scale for add

Parent 07e1fa2b
@@ -72,10 +72,15 @@ def prim_add_(layer,
               forward_func=[],
               layer_id=None,
               different_attrs=None):
-    line = "{} = {} + {} * {}".format(layer.outputs[0],
-                                      get_value(layer, "x", different_attrs),
-                                      layer.attrs["alpha"],
-                                      get_value(layer, "y", different_attrs))
+    if layer.attrs["alpha"] == 1:
+        line = "{} = {} + {}".format(layer.outputs[0],
+                                     get_value(layer, "x", different_attrs),
+                                     get_value(layer, "y", different_attrs))
+    else:
+        line = "{} = {} + {} * {}".format(
+            layer.outputs[0],
+            get_value(layer, "x", different_attrs), layer.attrs["alpha"],
+            get_value(layer, "y", different_attrs))
     forward_func.extend(gen_codes([line], indent=indent))
......
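For context, this change skips the redundant `* alpha` scale in the generated add statement when `alpha` is 1. A minimal standalone sketch of the emission logic (the `emit_add` helper and the argument names are illustrative, not part of the commit; the real code goes through `get_value` and `gen_codes`):

    def emit_add(out, x, y, alpha):
        # When alpha == 1 the scale is a no-op, so emit a plain addition
        # (this commit's change); otherwise keep the scaled form.
        if alpha == 1:
            return "{} = {} + {}".format(out, x, y)
        return "{} = {} + {} * {}".format(out, x, alpha, y)

    print(emit_add("out", "x", "y", 1))  # out = x + y
    print(emit_add("out", "x", "y", 2))  # out = x + 2 * y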
@@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape)
 @add_tensor_function
 def reshape(self, *shape):
+    # deal with list or tuple type
+    if isinstance(shape, (list, tuple)):
+        shape = shape[0]
     return pd_reshape(self, shape)
......
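The `reshape` patch unwraps a shape passed as a single list or tuple before delegating to the original `paddle.Tensor.reshape`. A minimal standalone sketch of that unwrapping, assuming a working Paddle install (the `add_tensor_function` decorator and `partial` plumbing from the real file are omitted):

    import paddle

    pd_reshape = paddle.Tensor.reshape  # original method, as in the diff

    def reshape(self, *shape):
        # deal with list or tuple type: a call such as t.reshape([6, 4]) arrives
        # here as shape == ([6, 4],), so unwrap the inner list before delegating
        if isinstance(shape, (list, tuple)):
            shape = shape[0]
        return pd_reshape(self, shape)

    t = paddle.rand([2, 3, 4])
    print(reshape(t, [6, 4]).shape)  # [6, 4]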