未验证 提交 4a66e7cf 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the python overhead of reshape and layer_norm. (#48635)

上级 5bcf35cd
...@@ -382,16 +382,11 @@ def layer_norm( ...@@ -382,16 +382,11 @@ def layer_norm(
) )
if in_dygraph_mode(): if in_dygraph_mode():
( out, _, _ = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
pre_act, return out
_,
_,
) = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
pre_act, _, _ = _legacy_C_ops.layer_norm( out, _, _ = _legacy_C_ops.layer_norm(
x, x,
weight, weight,
bias, bias,
...@@ -400,7 +395,7 @@ def layer_norm( ...@@ -400,7 +395,7 @@ def layer_norm(
'begin_norm_axis', 'begin_norm_axis',
begin_norm_axis, begin_norm_axis,
) )
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) return out
check_variable_and_dtype( check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm' x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm'
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# TODO: define functions to manipulate a tensor # TODO: define functions to manipulate a tensor
import warnings
from collections import Counter from collections import Counter
import numpy as np import numpy as np
...@@ -22,7 +21,7 @@ import numpy as np ...@@ -22,7 +21,7 @@ import numpy as np
import paddle import paddle
from paddle import _C_ops, _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from ..common_ops_import import _varbase_creator, dygraph_utils, fill_constant from ..common_ops_import import _varbase_creator, fill_constant
from ..fluid.data_feeder import ( from ..fluid.data_feeder import (
check_dtype, check_dtype,
check_type, check_type,
...@@ -3564,16 +3563,9 @@ def reshape(x, shape, name=None): ...@@ -3564,16 +3563,9 @@ def reshape(x, shape, name=None):
""" """
actual_shape = None actual_shape = None
act = None
inplace = False
if in_dygraph_mode(): if in_dygraph_mode():
tmp_tensor_type = core.eager.Tensor tmp_tensor_type = core.eager.Tensor
# TODO(zhiqiu): enable inplace in dygraph mode.
if inplace:
warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
)
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
shape = [ shape = [
item.numpy().item(0) item.numpy().item(0)
...@@ -3581,8 +3573,11 @@ def reshape(x, shape, name=None): ...@@ -3581,8 +3573,11 @@ def reshape(x, shape, name=None):
else item else item
for item in shape for item in shape
] ]
out = _C_ops.reshape(x, shape) if shape == x.shape:
elif isinstance(shape, tmp_tensor_type): out = x
else:
out = _C_ops.reshape(x, shape)
elif isinstance(shape, core.eager.Tensor):
shape.stop_gradient = True shape.stop_gradient = True
out = _C_ops.reshape(x, shape) out = _C_ops.reshape(x, shape)
else: else:
...@@ -3591,14 +3586,10 @@ def reshape(x, shape, name=None): ...@@ -3591,14 +3586,10 @@ def reshape(x, shape, name=None):
" got '{}.'".format(type(shape)) " got '{}.'".format(type(shape))
) )
return dygraph_utils._append_activation_in_dygraph(out, act) return out
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
tmp_tensor_type = Variable tmp_tensor_type = Variable
if inplace:
warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
)
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
shape = [ shape = [
item.numpy().item(0) if isinstance(item, Variable) else item item.numpy().item(0) if isinstance(item, Variable) else item
...@@ -3614,7 +3605,7 @@ def reshape(x, shape, name=None): ...@@ -3614,7 +3605,7 @@ def reshape(x, shape, name=None):
" got '{}.'".format(type(shape)) " got '{}.'".format(type(shape))
) )
return dygraph_utils._append_activation_in_dygraph(out, act) return out
check_variable_and_dtype( check_variable_and_dtype(
x, x,
...@@ -3690,11 +3681,7 @@ def reshape(x, shape, name=None): ...@@ -3690,11 +3681,7 @@ def reshape(x, shape, name=None):
actual_shape.stop_gradient = True actual_shape.stop_gradient = True
inputs["Shape"] = actual_shape inputs["Shape"] = actual_shape
out = ( out = helper.create_variable_for_type_inference(dtype=x.dtype)
x
if inplace
else helper.create_variable_for_type_inference(dtype=x.dtype)
)
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape2", type="reshape2",
...@@ -3703,7 +3690,7 @@ def reshape(x, shape, name=None): ...@@ -3703,7 +3690,7 @@ def reshape(x, shape, name=None):
outputs={"Out": out, "XShape": x_shape}, outputs={"Out": out, "XShape": x_shape},
) )
return helper.append_activation(out) return out
@inplace_apis_in_dygraph_only @inplace_apis_in_dygraph_only
...@@ -3721,7 +3708,10 @@ def reshape_(x, shape, name=None): ...@@ -3721,7 +3708,10 @@ def reshape_(x, shape, name=None):
else item else item
for item in shape for item in shape
] ]
out = _C_ops.reshape_(x, shape) if shape == x.shape:
out = x
else:
out = _C_ops.reshape_(x, shape)
elif isinstance(shape, tmp_tensor_type): elif isinstance(shape, tmp_tensor_type):
shape.stop_gradient = True shape.stop_gradient = True
out = _C_ops.reshape_(x, shape) out = _C_ops.reshape_(x, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册