diff --git a/x2paddle/op_mapper/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/pytorch2paddle/prim2code.py index 17ad50427d5a6ec0969042ba2bb094083ed1e32d..673456ce931d1329a29720915f183f67a6ad1dae 100755 --- a/x2paddle/op_mapper/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/pytorch2paddle/prim2code.py @@ -72,10 +72,15 @@ def prim_add_(layer, forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {} + {} * {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), - layer.attrs["alpha"], - get_value(layer, "y", different_attrs)) + if abs(layer.attrs["alpha"] - 1.) < 1e-6: + line = "{} = {} + {}".format(layer.outputs[0], + get_value(layer, "x", different_attrs), + get_value(layer, "y", different_attrs)) + else: + line = "{} = {} + {} * {}".format( + layer.outputs[0], + get_value(layer, "x", different_attrs), layer.attrs["alpha"], + get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/tensor.py b/x2paddle/project_convertor/pytorch/torch2paddle/tensor.py index 1c9635dba3b46325d4456d88f0a401776469785b..76e97aa7925c8a9ba88945bb6e6e588cfe7d0416 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/tensor.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/tensor.py @@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape) @add_tensor_function def reshape(self, *shape): + # deal with list or tuple type + if isinstance(shape, (list, tuple)): + shape = shape[0] return pd_reshape(self, shape)