未验证 提交 4a66e7cf 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the python overhead of reshape and layer_norm. (#48635)

上级 5bcf35cd
......@@ -382,16 +382,11 @@ def layer_norm(
)
if in_dygraph_mode():
(
pre_act,
_,
_,
) = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
out, _, _ = _C_ops.layer_norm(x, weight, bias, epsilon, begin_norm_axis)
return out
if _in_legacy_dygraph():
pre_act, _, _ = _legacy_C_ops.layer_norm(
out, _, _ = _legacy_C_ops.layer_norm(
x,
weight,
bias,
......@@ -400,7 +395,7 @@ def layer_norm(
'begin_norm_axis',
begin_norm_axis,
)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
return out
check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm'
......
......@@ -14,7 +14,6 @@
# TODO: define functions to manipulate a tensor
import warnings
from collections import Counter
import numpy as np
......@@ -22,7 +21,7 @@ import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from ..common_ops_import import _varbase_creator, dygraph_utils, fill_constant
from ..common_ops_import import _varbase_creator, fill_constant
from ..fluid.data_feeder import (
check_dtype,
check_type,
......@@ -3564,16 +3563,9 @@ def reshape(x, shape, name=None):
"""
actual_shape = None
act = None
inplace = False
if in_dygraph_mode():
tmp_tensor_type = core.eager.Tensor
# TODO(zhiqiu): enable inplace in dygraph mode.
if inplace:
warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
)
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0)
......@@ -3581,8 +3573,11 @@ def reshape(x, shape, name=None):
else item
for item in shape
]
out = _C_ops.reshape(x, shape)
elif isinstance(shape, tmp_tensor_type):
if shape == x.shape:
out = x
else:
out = _C_ops.reshape(x, shape)
elif isinstance(shape, core.eager.Tensor):
shape.stop_gradient = True
out = _C_ops.reshape(x, shape)
else:
......@@ -3591,14 +3586,10 @@ def reshape(x, shape, name=None):
" got '{}.'".format(type(shape))
)
return dygraph_utils._append_activation_in_dygraph(out, act)
return out
else:
if _in_legacy_dygraph():
tmp_tensor_type = Variable
if inplace:
warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
)
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
......@@ -3614,7 +3605,7 @@ def reshape(x, shape, name=None):
" got '{}.'".format(type(shape))
)
return dygraph_utils._append_activation_in_dygraph(out, act)
return out
check_variable_and_dtype(
x,
......@@ -3690,11 +3681,7 @@ def reshape(x, shape, name=None):
actual_shape.stop_gradient = True
inputs["Shape"] = actual_shape
out = (
x
if inplace
else helper.create_variable_for_type_inference(dtype=x.dtype)
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="reshape2",
......@@ -3703,7 +3690,7 @@ def reshape(x, shape, name=None):
outputs={"Out": out, "XShape": x_shape},
)
return helper.append_activation(out)
return out
@inplace_apis_in_dygraph_only
......@@ -3721,7 +3708,10 @@ def reshape_(x, shape, name=None):
else item
for item in shape
]
out = _C_ops.reshape_(x, shape)
if shape == x.shape:
out = x
else:
out = _C_ops.reshape_(x, shape)
elif isinstance(shape, tmp_tensor_type):
shape.stop_gradient = True
out = _C_ops.reshape_(x, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册