From 4a66e7cffd519db23106892aed69423d827ade1b Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 2 Dec 2022 10:38:36 +0800 Subject: [PATCH] Optimize the python overhead of reshape and layer_norm. (#48635) --- python/paddle/nn/functional/norm.py | 13 +++------- python/paddle/tensor/manipulation.py | 38 ++++++++++------------------ 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index f2546b6244..6e248af333 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -382,16 +382,11 @@ def layer_norm( ) if in_dygraph_mode(): - ( - pre_act, - _, - _, - ) = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis) - - return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) + out, _, _ = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis) + return out if _in_legacy_dygraph(): - pre_act, _, _ = _legacy_C_ops.layer_norm( + out, _, _ = _legacy_C_ops.layer_norm( x, weight, bias, @@ -400,7 +395,7 @@ def layer_norm( 'begin_norm_axis', begin_norm_axis, ) - return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) + return out check_variable_and_dtype( x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm' diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index cb4fec4a33..fceae51e14 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -14,7 +14,6 @@ # TODO: define functions to manipulate a tensor -import warnings from collections import Counter import numpy as np @@ -22,7 +21,7 @@ import numpy as np import paddle from paddle import _C_ops, _legacy_C_ops -from ..common_ops_import import _varbase_creator, dygraph_utils, fill_constant +from ..common_ops_import import _varbase_creator, fill_constant from ..fluid.data_feeder import ( check_dtype, check_type, @@ -3564,16 +3563,9 @@ def reshape(x, shape, name=None): """ actual_shape = None - act = None - inplace = False if in_dygraph_mode(): tmp_tensor_type = core.eager.Tensor - # TODO(zhiqiu): enable inplace in dygraph mode. - if inplace: - warnings.warn( - "Inplace on reshape is not allowed and will be discarded in dygraph mode currently." - ) if isinstance(shape, (list, tuple)): shape = [ item.numpy().item(0) @@ -3581,8 +3573,11 @@ def reshape(x, shape, name=None): else item for item in shape ] - out = _C_ops.reshape(x, shape) - elif isinstance(shape, tmp_tensor_type): + if shape == x.shape: + out = x + else: + out = _C_ops.reshape(x, shape) + elif isinstance(shape, core.eager.Tensor): shape.stop_gradient = True out = _C_ops.reshape(x, shape) else: @@ -3591,14 +3586,10 @@ def reshape(x, shape, name=None): " got '{}.'".format(type(shape)) ) - return dygraph_utils._append_activation_in_dygraph(out, act) + return out else: if _in_legacy_dygraph(): tmp_tensor_type = Variable - if inplace: - warnings.warn( - "Inplace on reshape is not allowed and will be discarded in dygraph mode currently." - ) if isinstance(shape, (list, tuple)): shape = [ item.numpy().item(0) if isinstance(item, Variable) else item @@ -3614,7 +3605,7 @@ def reshape(x, shape, name=None): " got '{}.'".format(type(shape)) ) - return dygraph_utils._append_activation_in_dygraph(out, act) + return out check_variable_and_dtype( x, @@ -3690,11 +3681,7 @@ def reshape(x, shape, name=None): actual_shape.stop_gradient = True inputs["Shape"] = actual_shape - out = ( - x - if inplace - else helper.create_variable_for_type_inference(dtype=x.dtype) - ) + out = helper.create_variable_for_type_inference(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type="reshape2", @@ -3703,7 +3690,7 @@ def reshape(x, shape, name=None): outputs={"Out": out, "XShape": x_shape}, ) - return helper.append_activation(out) + return out @inplace_apis_in_dygraph_only @@ -3721,7 +3708,10 @@ def reshape_(x, shape, name=None): else item for item in shape ] - out = _C_ops.reshape_(x, shape) + if shape == x.shape: + out = x + else: + out = _C_ops.reshape_(x, shape) elif isinstance(shape, tmp_tensor_type): shape.stop_gradient = True out = _C_ops.reshape_(x, shape) -- GitLab