Unverified commit f1873b90, authored by Weilong Wu, committed by GitHub

[Eager] use eager final state instead of intermediate state (#44722)

* [Eager] call final_state_slice under eager mode

* rm useless comments

* use eager final state instead of intermediate state

* update fill_constant yaml

* update fill_constant yaml

* modify wrapped_infermeta_gen logic to fix special case

* fix slice in manipulation

* use fill_constant_

* modify slice infermeta

* rm final_state_conv2d

* use final_state_slice

* use final_state_slice only

* polish slice, use final state

* add paddle_throw for SplitInferMeta

* rm fill_constant_ temporarily

* recover array_equal, not allclose

* recover original code
Parent 2cf2e786
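The commit swaps intermediate-state `_C_ops` calls for their eager final-state counterparts behind the usual mode checks. A minimal sketch of the dispatch pattern the files below adopt (`some_op` is a placeholder, not an actual operator name):

import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

def forward(x):
    if in_dygraph_mode():
        # Eager mode: call the auto-generated final-state kernel directly.
        return _C_ops.final_state_some_op(x)
    if _in_legacy_dygraph():
        # Legacy dygraph: keep the attribute-string based call.
        return _C_ops.some_op(x)
    # The static-graph path (append_op) is unchanged and omitted here.
    raise NotImplementedError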
@@ -825,8 +825,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
       }
       paddle::experimental::Tensor new_out;
-      framework::AttributeMap attrs = {{"axes", none_axes}};
-      new_out = std::get<0>(unsqueeze2_dygraph_function(out, std::move(attrs)));
+      new_out = unsqueeze_final_state_dygraph_function(out, none_axes);
       return ToPyObject(new_out);
     }
   }
...
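The rewritten branch covers `None` indices in tensor `__getitem__`: the sliced result is unsqueezed along those axes, now through the final-state function instead of `unsqueeze2_dygraph_function`. A minimal Python-level sketch of the behaviour this path serves (shapes only; the exact call chain is the C++ above):

import paddle

x = paddle.rand([2, 3, 4])
# Indexing with None inserts size-1 axes; in eager mode the branch above
# performs the corresponding unsqueeze on the sliced result.
y = x[:, None, :, None]
print(y.shape)  # [2, 1, 3, 1, 4]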
@@ -276,7 +276,7 @@
     func : assign_value
     param : [shape, dtype, values]
     data_type : dtype
-    backend : place > output
+    backend : place > output
 # atan
 - api : atan
...
@@ -2614,16 +2614,22 @@ void SliceRawInferMeta(const MetaTensor& input,
     // To be compatible with other op tests in which infer_flags is not set.
     infer_flags = std::vector<int64_t>(axes.size(), 1);
   }
+  auto new_axes = axes;
+  for (auto& axis : new_axes) {
+    if (axis < 0) {
+      axis = std::max(int64_t(0), axis + int64_t(in_dims.size()));
+    }
+  }
   // 2.1 Check attrs.
   std::vector<int64_t> starts = starts_arr.GetData();
   std::vector<int64_t> ends = ends_arr.GetData();
   phi::funcs::CheckAndUpdateSliceAttrs<int64_t>(
-      in_dims, axes, &starts, &ends, nullptr, &infer_flags);
+      in_dims, new_axes, &starts, &ends, nullptr, &infer_flags);
   auto slice_dims = phi::funcs::GetSliceDims<int64_t>(
-      in_dims, axes, starts, ends, nullptr, &infer_flags);
+      in_dims, new_axes, starts, ends, nullptr, &infer_flags);
   if (config.is_runtime) {
     out_dims = phi::funcs::GetDecreasedDims<int64_t>(
         slice_dims, decrease_axis, &infer_flags);
@@ -2633,7 +2639,7 @@ void SliceRawInferMeta(const MetaTensor& input,
   }
   out->set_dims(out_dims);
-  if (axes.size() > 0 && axes[0] != 0) {
+  if (new_axes.size() > 0 && new_axes[0] != 0) {
     out->share_lod(input);
   }
 }
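SliceRawInferMeta now normalizes negative axes before the attribute checks, so slicing along axis -1 resolves to the last dimension rather than tripping the validation. A hedged plain-Python sketch of the same rule as the C++ loop above:

# Illustration of the added normalization: shift negative axes by the rank,
# clamping at zero, before starts/ends are checked.
def normalize_axes(axes, rank):
    return [max(0, a + rank) if a < 0 else a for a in axes]

print(normalize_axes([-1, 1], rank=3))  # [2, 1]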
@@ -2662,6 +2668,13 @@ void SplitInferMeta(const MetaTensor& x,
                     const Scalar& axis,
                     std::vector<MetaTensor*> out,
                     MetaConfig config) {
+  if (axis.dtype() == DataType::FLOAT32 || axis.dtype() == DataType::FLOAT64) {
+    PADDLE_THROW(
+        phi::errors::InvalidArgument("%s(): argument (position 3) must be "
+                                     "int, but got %s",
+                                     "split",
+                                     "float"));  // NOLINT
+  }
   int axis_value = axis.to<int>();
   int rank = x.dims().size();
   PADDLE_ENFORCE_EQ(
...
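SplitInferMeta now rejects a floating-point axis up front with an explicit error instead of silently truncating it to int. A hedged usage sketch (assuming the check surfaces as a Python exception from paddle.split in eager mode):

import paddle

x = paddle.rand([4, 6])
parts = paddle.split(x, num_or_sections=2, axis=1)    # fine: integer axis
try:
    paddle.split(x, num_or_sections=2, axis=1.0)      # rejected by the new check
except (TypeError, ValueError, RuntimeError) as err:
    print("float axis rejected:", err)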
@@ -290,6 +290,8 @@ def monkey_patch_math_varbase():
             axis = -1
         math_op = getattr(_C_ops, op_type)
         if call_final_api:
+            if op_type == "final_state_matmul":
+                return math_op(self, other_var, False, False)
             return math_op(self, other_var, -1)
         return math_op(self, other_var, 'axis', axis)
@@ -385,10 +387,16 @@ def monkey_patch_math_varbase():
                            None)),
         ('__floordiv__',
          _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
-        ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
-                                     None)),
-        ('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False,
-                                        None)),
+        ('__mod__',
+         _binary_creator_('__mod__', 'final_state_modulo', False, None, True))
+        if framework._in_eager_mode_ else
+        ('__mod__',
+         _binary_creator_('__mod__', 'elementwise_mod', False, None)),
+        ('__matmul__',
+         _binary_creator_('__matmul__', "final_state_matmul", False, None,
+                          True)) if framework._in_eager_mode_ else
+        ('__matmul__',
+         _binary_creator_('__matmul__', "matmul_v2", False, None)),
         ## for logical compare
         ('__eq__',
          _binary_creator_('__eq__', 'final_state_equal', False, None, True))
...
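With these patches, % and @ on eager tensors resolve to final_state_modulo and final_state_matmul; the matmul path passes two explicit transpose flags instead of an axis. A hedged usage sketch:

import paddle

a = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
b = paddle.to_tensor([[5.0, 6.0], [7.0, 8.0]])

print(a % b)  # eager mode: routed to final_state_modulo
print(a @ b)  # eager mode: routed to final_state_matmul(self, other, False, False)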
@@ -23,6 +23,7 @@ from .core import VarDesc
 from . import unique_name
 from .data_feeder import check_variable_and_dtype, check_type, check_dtype
 from paddle import _C_ops
+import paddle
 __all__ = [
     'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
@@ -599,9 +600,15 @@ class XavierInitializer(Initializer):
         if framework._non_static_mode():
             if self._uniform:
                 limit = math.sqrt(6.0 / float(fan_in + fan_out))
-                out_var = _C_ops.uniform_random('shape', out_var.shape, 'min',
-                                                -limit, 'max', limit, 'seed',
-                                                self._seed, 'dtype', out_dtype)
+                if in_dygraph_mode():
+                    out_var = _C_ops.final_state_uniform_random(
+                        out_var.shape, out_dtype, -limit, limit, self._seed,
+                        _current_expected_place())
+                elif _in_legacy_dygraph():
+                    out_var = _C_ops.uniform_random('shape', out_var.shape,
+                                                    'min', -limit, 'max', limit,
+                                                    'seed', self._seed, 'dtype',
+                                                    out_dtype)
             else:
                 std = math.sqrt(2.0 / float(fan_in + fan_out))
@@ -617,8 +624,11 @@ class XavierInitializer(Initializer):
             if var.dtype == VarDesc.VarType.FP16 or (
                     var.dtype == VarDesc.VarType.BF16 and not self._uniform):
-                var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                      'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
+                                          'out_dtype', var.dtype)
             var_tmp._share_underline_tensor_to(var)
         else:
             out_var._share_underline_tensor_to(var)
...
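The Xavier initializer gains eager final-state paths for both the uniform sampling and the FP16/BF16 cast. A hedged sketch of the public API that exercises this code in dygraph mode:

import paddle

weight_attr = paddle.ParamAttr(
    initializer=paddle.nn.initializer.XavierUniform())
linear = paddle.nn.Linear(8, 4, weight_attr=weight_attr)
print(linear.weight.shape)  # [8, 4]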
@@ -21,7 +21,7 @@ from op_test import OpTest, convert_float_to_uint16
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle
-from paddle.fluid.framework import _test_eager_guard
+from paddle.fluid.framework import _test_eager_guard, _enable_legacy_dygraph
 paddle.enable_static()
@@ -787,7 +787,6 @@ class TestInferShape(unittest.TestCase):
         self.assertEqual(out0.shape, (3, 3, 5))
     def test_axis_less_than_zero(self):
         # Using paddle.disable_static will make other unittests fail.
         with fluid.dygraph.guard():
             x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4])
@@ -829,6 +828,7 @@ class TestInferShape(unittest.TestCase):
 class TestImperativeCUDAPinnedInput(unittest.TestCase):
     def test_input_cuda_pinned_var(self):
+        _enable_legacy_dygraph()
         with fluid.dygraph.guard():
             data = np.random.random((2, 80, 16128)).astype('float32')
             var = core.VarBase(value=data,
...
@@ -490,18 +490,31 @@ def _getitem_impl_(var, item):
         out = var
     if len(axes) > 0:
-        target_block = default_main_program().current_block()
         op_type = "strided_slice" if use_strided_slice else "slice"
-        slice_out_var = target_block.create_var(
-            name=unique_name.generate_with_ignorable_key(var.name + "_" +
-                                                         op_type),
-            dtype=var.dtype)
-        target_block.append_op(type=op_type,
-                               inputs=inputs,
-                               outputs={'Out': [slice_out_var]},
-                               attrs=attrs)
-        out = slice_out_var
+        if paddle.fluid.framework.in_dygraph_mode() and op_type == "slice":
+            if "StartsTensorList" in inputs.keys():
+                st = inputs['StartsTensorList']
+            else:
+                st = attrs['starts']
+            if "EndsTensorList" in inputs.keys():
+                end = inputs['EndsTensorList']
+            else:
+                end = attrs['ends']
+            out = paddle._C_ops.final_state_slice(var, axes, st, end,
+                                                  attrs['infer_flags'],
+                                                  attrs['decrease_axis'])
+        else:
+            target_block = default_main_program().current_block()
+            slice_out_var = target_block.create_var(
+                name=unique_name.generate_with_ignorable_key(var.name + "_" +
+                                                             op_type),
+                dtype=var.dtype)
+            target_block.append_op(type=op_type,
+                                   inputs=inputs,
+                                   outputs={'Out': [slice_out_var]},
+                                   attrs=attrs)
+            out = slice_out_var
     if len(reverse_axes) > 0:
         from .layers.tensor import reverse
...
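In dygraph mode, basic __getitem__ slicing now calls final_state_slice directly instead of appending a slice op to the current block. A hedged usage sketch:

import paddle

x = paddle.arange(24, dtype='float32').reshape([2, 3, 4])
y = x[1:2, 0:2]   # eager mode: dispatched to paddle._C_ops.final_state_slice
print(y.shape)    # [1, 2, 4]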
@@ -964,7 +964,9 @@ def silu(x, name=None):
             out = F.silu(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ]
     """
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_silu(x)
+    if _in_legacy_dygraph():
         return _C_ops.silu(x)
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'silu')
...
@@ -82,7 +82,7 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
     if in_dygraph_mode():
         eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype)
         out = _C_ops.final_state_p_norm(x, float(p), axis, epsilon, True, False)
-        return x / _C_ops.elementwise_max(out, eps)
+        return x / _C_ops.final_state_maximum(out, eps)
     if _in_legacy_dygraph():
         eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype)
...
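normalize keeps the same semantics, dividing x by max(||x||_p, epsilon) along the chosen axis; only the clamp now uses the final-state maximum op. A hedged usage sketch:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[3.0, 4.0]])
y = F.normalize(x, p=2, axis=1)  # each row divided by max(l2 norm, epsilon)
print(y)                         # approximately [[0.6, 0.8]]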
@@ -49,7 +49,7 @@ from .. import functional as F
 from paddle import _C_ops
 from .. import Layer
 from paddle import in_dynamic_mode
-from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 __all__ = []
@@ -411,7 +411,15 @@ class GroupNorm(Layer):
         variance_out = self._helper.create_variable_for_type_inference(
             dtype=input.dtype, stop_gradient=True)
-        if _non_static_mode():
+        if in_dygraph_mode():
+            pre_act = _C_ops.final_state_group_norm(input, self.weight,
+                                                    self.bias, self._epsilon,
+                                                    self._num_groups, "NCHW")
+            return dygraph_utils._append_activation_in_dygraph(pre_act,
+                                                               act=None)
+        elif _in_legacy_dygraph():
             pre_act, _, _ = _C_ops.group_norm(
                 input,
                 self.weight,
...
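GroupNorm now takes the final-state kernel in eager mode, which returns only the normalized output (the legacy op also produced mean and variance). A hedged usage sketch:

import paddle

gn = paddle.nn.GroupNorm(num_groups=2, num_channels=4)
x = paddle.rand([8, 4, 16, 16])  # NCHW, matching the layout passed above
y = gn(x)
print(y.shape)                   # [8, 4, 16, 16]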
@@ -197,10 +197,9 @@ def slice(input, axes, starts, ends):
                 if isinstance(item, tmp_tensor_type) else item
                 for item in starts
             ]
-            attrs += ('starts', starts)
         elif isinstance(starts, tmp_tensor_type):
-            starts_tensor = starts
-            starts.stop_gradient = True
+            tensor_t = starts.numpy()
+            starts = [ele for ele in tensor_t]
             infer_flags = list(-1 for i in range(len(axes)))
         if isinstance(ends, (list, tuple)):
@@ -208,13 +207,13 @@ def slice(input, axes, starts, ends):
                 item.numpy().item(0)
                 if isinstance(item, tmp_tensor_type) else item for item in ends
             ]
-            attrs += ('ends', ends)
         elif isinstance(ends, tmp_tensor_type):
-            ends_tensor = ends
-            ends_tensor.stop_gradient = True
+            etensor_t = ends.numpy()
+            ends = [ele for ele in tensor_t]
             infer_flags = list(-1 for i in range(len(axes)))
-        return _C_ops.slice(input, starts_tensor, ends_tensor, None, None,
-                            'axes', axes, 'infer_flags', infer_flags, *attrs)
+        return _C_ops.final_state_slice(input, axes, starts, ends, infer_flags,
+                                        [])
     else:
         if _in_legacy_dygraph():
             attrs = ()
@@ -1817,9 +1816,14 @@ def split(x, num_or_sections, axis=0, name=None):
             raise TypeError(
                 "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
                 "received %s." % (type(num_or_sections)))
-        out = [_varbase_creator() for n in range(num)]
-        _C_ops.split(input, out, *attrs)
-        return out
+        if in_dygraph_mode():
+            return _C_ops.final_state_split(
+                input, [num_or_sections]
+                if isinstance(num_or_sections, int) else num_or_sections, dim)
+        elif _in_legacy_dygraph():
+            out = [_varbase_creator() for n in range(num)]
+            _C_ops.split(input, out, *attrs)
+            return out
     check_variable_and_dtype(input, 'input', [
         'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
...
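Both paddle.slice and paddle.split now return directly from the final-state ops in eager mode. A hedged usage sketch:

import paddle

x = paddle.arange(12, dtype='float32').reshape([3, 4])

s = paddle.slice(x, axes=[0], starts=[1], ends=[3])  # rows 1..2
print(s.shape)                                        # [2, 4]

parts = paddle.split(x, num_or_sections=2, axis=1)    # two [3, 2] tensors
print([p.shape for p in parts])                       # [[3, 2], [3, 2]]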