From aaabb796a8ae6cf5b4ab24997f2cb898ef7d0b48 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Mon, 18 Apr 2022 20:49:20 +0800 Subject: [PATCH] [Eager] use final op in maskrcnn and hrnet (#41927) * update * add conv yaml * add backward * remove useless code * fix bug * fix bug * revert fluid dygraph conv2d * remove useless infermeta function * fix meta fn deluplicat error * conv using custom impl * remove amp include * fix bug * use final op in maskrcnn and hrnet * refine Co-authored-by: phlrain --- paddle/fluid/pybind/eager_method.cc | 23 ++-- paddle/fluid/pybind/eager_utils.cc | 24 ++-- python/paddle/fluid/clip.py | 2 +- python/paddle/fluid/dygraph/math_op_patch.py | 125 ++++++++++++------- python/paddle/fluid/dygraph/tracer.py | 11 ++ python/paddle/fluid/layers/control_flow.py | 17 +-- python/paddle/fluid/layers/nn.py | 44 ++++++- python/paddle/fluid/optimizer.py | 6 +- python/paddle/fluid/regularizer.py | 6 +- python/paddle/nn/functional/loss.py | 54 ++++++-- python/paddle/nn/layer/loss.py | 8 +- python/paddle/optimizer/optimizer.py | 4 +- python/paddle/tensor/linalg.py | 50 +++++--- python/paddle/tensor/math.py | 46 +++++-- python/paddle/tensor/search.py | 4 +- 15 files changed, 297 insertions(+), 127 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 542d59318bb..17908e80de4 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/phi/core/sparse_csr_tensor.h" #include "pybind11/detail/internals.h" #pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/pybind/tensor_py.h" @@ -713,15 +714,19 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, break; } } + std::vector slice_axes_tmp(slice_axes.begin(), slice_axes.end()); + std::vector infer_flags_tmp(infer_flags.begin(), + infer_flags.end()); + std::vector decrease_axis_tmp(decrease_axis.begin(), + decrease_axis.end()); + if (op_type == "slice") { - out = slice_dygraph_function(self->tensor, paddle::experimental::Tensor(), - paddle::experimental::Tensor(), {}, {}, - std::move(attrs)); + out = slice_final_state_dygraph_function( + self->tensor, slice_axes_tmp, slice_starts, slice_ends, + infer_flags_tmp, decrease_axis_tmp); } else if (op_type == "strided_slice") { - out = strided_slice_dygraph_function( - self->tensor, paddle::experimental::Tensor(), - paddle::experimental::Tensor(), paddle::experimental::Tensor(), {}, - {}, {}, attrs); + out = strided_slice_final_state_dygraph_function( + self->tensor, slice_axes, slice_starts, slice_ends, slice_strides); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Slice is only support slice and strided_slice, but we got %s which " @@ -776,8 +781,8 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, paddle::framework::TensorFromVector(list_select_idxs, *dev_ctx, idx_tensor.get()); framework::AttributeMap attrs = {{"dim", 0}}; - out = index_select_dygraph_function(self->tensor, select_index, - std::move(attrs)); + out = index_select_final_state_dygraph_function(self->tensor, select_index, + 0); } return ToPyObject(out); diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 081e2783826..18337f36ca2 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1027,22 +1027,22 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, // obj could be: int, float, bool, paddle.Tensor PyTypeObject* type = obj->ob_type; auto type_name = std::string(type->tp_name); - if (type_name == "int") { + if (PyBool_Check(obj)) { + bool value = CastPyArg2Boolean(obj, op_type, arg_pos); + return paddle::experimental::Scalar(value); + } else if (PyLong_Check(obj)) { int value = CastPyArg2Int(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); - } else if (type_name == "float") { + } else if (PyFloat_Check(obj)) { float value = CastPyArg2Float(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); - - } else if (type_name == "bool") { - bool value = CastPyArg2Boolean(obj, op_type, arg_pos); - return paddle::experimental::Scalar(value); - - } else if (type_name == "Tensor") { + } else if (IsEagerTensor(obj)) { paddle::experimental::Tensor& value = GetTensorFromPyObject( op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/); return paddle::experimental::Scalar(value); - + } else if (PyObject_CheckLongOrToLong(&obj)) { + int value = CastPyArg2Int(obj, op_type, arg_pos); + return paddle::experimental::Scalar(value); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -1159,11 +1159,7 @@ paddle::Place CastPyArg2Place(PyObject* obj, const std::string& op_type, paddle::DataType CastPyArg2DataType(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { if (obj == Py_None) { - PADDLE_THROW(platform::errors::InvalidArgument( - "%s(): argument (position %d) must be " - "data_type, but got %s", - op_type, arg_pos + 1, - ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + return paddle::experimental::DataType::UNDEFINED; } framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos); diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 826deae498c..0ba980c3e92 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -543,7 +543,7 @@ class ClipGradByGlobalNorm(ClipGradBase): clip_input = (clip_var.astype('float16') if g.dtype == core.VarDesc.VarType.FP16 else clip_var) - new_grad = layers.elementwise_mul(x=g, y=clip_input) + new_grad = _C_ops.elementwise_mul(g, clip_input) params_and_grads.append((p, new_grad)) else: params_and_grads.append((p, g)) diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 5b305325f3d..8ce56d5a926 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -62,15 +62,6 @@ _complex_dtypes = [ _already_patch_varbase = False _already_patch_eager_tensor = False -# Dispatch to final state Python-C functions -_final_state_op_type_mapping = { - "elementwise_add": "final_state_add", - "elementwise_sub": "final_state_subtract", - "elementwise_div": "final_state_divide", - "elementwise_mul": "final_state_multiply", - "matmul_v2": "final_state_matmul", -} - def monkey_patch_math_varbase(): """ @@ -80,9 +71,13 @@ def monkey_patch_math_varbase(): @no_grad def create_tensor(value, dtype, shape): - out = _varbase_creator(dtype=dtype) - out = _C_ops.fill_constant(out, 'dtype', dtype, 'shape', shape, 'value', - value, 'force_cpu', False) + if framework._in_eager_mode_: + out = _C_ops.final_state_full(shape, value, dtype, + framework._current_expected_place()) + else: + out = _varbase_creator(dtype=dtype) + out = _C_ops.fill_constant(out, 'dtype', dtype, 'shape', shape, + 'value', value, 'force_cpu', False) out.stop_gradient = True return out @@ -120,9 +115,9 @@ def monkey_patch_math_varbase(): return _C_ops.final_state_cast(self, dtype) def _scalar_elementwise_op_(var, scale, bias): - if _in_legacy_dygraph(): - return _C_ops.scale(var, 'scale', scale, 'bias', bias) - return _C_ops.final_state_scale(var, float(scale), bias, True) + if framework.in_dygraph_mode(): + return _C_ops.final_state_scale(var, float(scale), bias, True) + return _C_ops.scale(var, 'scale', scale, 'bias', bias) def _neg_(var): return _scalar_elementwise_op_(var, -1.0, 0.0) @@ -203,7 +198,8 @@ def monkey_patch_math_varbase(): def _binary_creator_(method_name, op_type, reverse=False, - scalar_method=None): + scalar_method=None, + call_final_api=False): def __impl__(self, other_var): # 1. scalar exists cases # we need combine the tensor.dtype and scalar.dtype, cast correct object @@ -287,15 +283,15 @@ def monkey_patch_math_varbase(): # 4. calculation axis = -1 - if in_dygraph_mode( - ) and op_type in _final_state_op_type_mapping.keys(): - math_op = getattr(_C_ops, _final_state_op_type_mapping[op_type]) - return math_op(self, other_var) - else: - math_op = getattr(_C_ops, op_type) - return math_op(self, other_var, 'axis', axis) + math_op = getattr(_C_ops, op_type) + if call_final_api: + return math_op(self, other_var, -1) + return math_op(self, other_var, 'axis', axis) - comment = OpProtoHolder.instance().get_op_proto(op_type).comment + if call_final_api: + comment = "" + else: + comment = OpProtoHolder.instance().get_op_proto(op_type).comment __impl__.__doc__ = """ {0} @@ -321,28 +317,48 @@ def monkey_patch_math_varbase(): ('ndim', _ndim_), ('size', _size_), ('T', _T_), - ('__add__', - _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)), + ('__add__', _binary_creator_('__add__', 'final_state_add', False, + _scalar_add_, True)) + if framework._in_eager_mode_ else ('__add__', _binary_creator_( + '__add__', 'elementwise_add', False, _scalar_add_)), ## a+b == b+a. Do not need to reverse explicitly - ('__radd__', - _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)), - ('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False, - _scalar_sub_)), - ('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True, - _scalar_rsub_)), - ('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False, - _scalar_mul_)), + ('__radd__', _binary_creator_('__radd__', 'final_state_add', False, + _scalar_add_, True)) + if framework._in_eager_mode_ else ('__radd__', _binary_creator_( + '__radd__', 'elementwise_add', False, _scalar_add_)), + ('__sub__', _binary_creator_('__sub__', 'final_state_subtract', False, + _scalar_sub_, True)) + if framework._in_eager_mode_ else ('__sub__', _binary_creator_( + '__sub__', 'elementwise_sub', False, _scalar_sub_)), + ('__rsub__', _binary_creator_('__rsub__', 'final_state_subtract', True, + _scalar_rsub_, True)) + if framework._in_eager_mode_ else ('__rsub__', _binary_creator_( + '__rsub__', 'elementwise_sub', True, _scalar_rsub_)), + ('__mul__', _binary_creator_('__mul__', 'final_state_multiply', False, + _scalar_mul_, True)) + if framework._in_eager_mode_ else ('__mul__', _binary_creator_( + '__mul__', 'elementwise_mul', False, _scalar_mul_)), ## a*b == b*a. Do not need to reverse explicitly - ('__rmul__', - _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), - ('__div__', _binary_creator_('__div__', 'elementwise_div', False, - _scalar_div_)), - ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', - False, _scalar_div_)), - ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, - None)), - ('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, - None)), + ('__rmul__', _binary_creator_('__rmul__', 'final_state_multiply', False, + _scalar_mul_, True)) + if framework._in_eager_mode_ else ('__rmul__', _binary_creator_( + '__rmul__', 'elementwise_mul', False, _scalar_mul_)), + ('__div__', _binary_creator_('__div__', 'final_state_divide', False, + _scalar_div_, True)) + if framework._in_eager_mode_ else ('__div__', _binary_creator_( + '__div__', 'elementwise_div', False, _scalar_div_)), + ('__truediv__', _binary_creator_('__truediv__', 'final_state_divide', + False, _scalar_div_, True)) + if framework._in_eager_mode_ else ('__truediv__', _binary_creator_( + '__truediv__', 'elementwise_div', False, _scalar_div_)), + ('__rdiv__', _binary_creator_('__rdiv__', 'final_state_divide', True, + None, True)) if framework._in_eager_mode_ + else ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, + None)), + ('__rtruediv__', _binary_creator_('rtruediv__', 'final_state_divide', + True, None, True)) + if framework._in_eager_mode_ else ('__rtruediv__', _binary_creator_( + 'rtruediv__', 'elementwise_div', True, None)), ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)), ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, @@ -354,11 +370,26 @@ def monkey_patch_math_varbase(): ('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False, None)), ## for logical compare + ('__eq__', + _binary_creator_('__eq__', 'final_state_equal', False, None, True)) + if framework._in_eager_mode_ else ('__eq__', _binary_creator_('__eq__', 'equal', False, None)), - ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), - ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)), - ('__le__', _binary_creator_('__le__', 'less_equal', False, None)), + ('__ne__', _binary_creator_('__ne__', 'final_state_not_equal', False, + None, True)) if framework._in_eager_mode_ + else ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), + ('__lt__', _binary_creator_('__lt__', 'final_state_less_than', False, + None, True)) if framework._in_eager_mode_ + else ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)), + ('__le__', _binary_creator_('__le__', 'final_state_less_equal', False, + None, True)) if framework._in_eager_mode_ + else ('__le__', _binary_creator_('__le__', 'less_equal', False, None)), + ('__gt__', _binary_creator_('__gt__', 'final_state_greater_than', False, + None, True)) + if framework._in_eager_mode_ else ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)), + ('__ge__', _binary_creator_('__ge__', 'final_state_greater_equal', + False, None, True)) + if framework._in_eager_mode_ else ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)), ('__array_ufunc__', None) ] diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 6e1ed6b0a1d..44a49148ca0 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -72,6 +72,17 @@ final_state_name_mapping = { "axis2": "axis2", "out": "Out", }, + "roi_align": { + "final_op_name": "final_state_roi_align", + "x": "X", + "boxes": "ROIs", + "boxes_num": "RoisNum", + "pooled_height": "pooled_height", + "pooled_width": "pooled_width", + "spatial_scale": "spatial_scale", + "sampling_ratio": "sampling_ratio", + "aligned": "aligned", + }, # "one_hot": { # "final_op_name": "final_state_one_hot", # "x": "X", diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 184453a6fcb..d143a6637f8 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1796,13 +1796,16 @@ def greater_than(x, y, cond=None, name=None): attrs = dict() - helper.append_op( - type='greater_than', - inputs={'X': [x], - 'Y': [y]}, - outputs={'Out': [cond]}, - attrs=attrs) - return cond + if in_dygraph_mode(): + return _C_ops.final_state_greater_than(x, y, -1) + else: + helper.append_op( + type='greater_than', + inputs={'X': [x], + 'Y': [y]}, + outputs={'Out': [cond]}, + attrs=attrs) + return cond @templatedoc() diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 47f40a2e6a5..1fdf5994834 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -196,6 +196,17 @@ __all__ = [ 'unbind', ] +OP_NAMEMAPPING = { + 'elementwise_max': 'final_state_maximum', + 'elementwise_min': 'final_state_minimum', + 'elementwise_pow': 'final_state_elementwise_pow', + 'elementwise_floordiv': 'final_state_floor_divide', + 'elementwise_add': 'final_state_add', + 'elementwise_sub': 'final_state_subtract', + 'elementwise_mul': 'final_state_multiply', + 'elementwise_div': 'final_state_divide', +} + @dygraph_only def _elementwise_op_in_dygraph(x, @@ -204,8 +215,21 @@ def _elementwise_op_in_dygraph(x, act=None, use_mkldnn=False, op_name=None): - op = getattr(_C_ops, op_name) - out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) + def is_inplace(op_name): + return op_name[-1] == "_" + + if op_name not in OP_NAMEMAPPING.keys() or axis != -1: + op = getattr(_C_ops, op_name) + out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) + else: + if in_dygraph_mode(): + op = getattr(_C_ops, OP_NAMEMAPPING[op_name] + if not is_inplace(op_name) else op_name) + out = op(x, y) + + if _in_legacy_dygraph(): + op = getattr(_C_ops, op_name) + out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) return dygraph_utils._append_activation_in_dygraph( out, act, use_mkldnn=use_mkldnn) @@ -5093,9 +5117,12 @@ def split(input, num_or_sections, dim=-1, name=None): raise TypeError( "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but " "received %s." % (type(num_or_sections))) - out = [_varbase_creator() for n in range(num)] - _C_ops.split(input, out, *attrs) - return out + if in_dygraph_mode(): + return _C_ops.final_state_split(input, [num], dim) + elif _in_legacy_dygraph(): + out = [_varbase_creator() for n in range(num)] + _C_ops.split(input, out, *attrs) + return out check_variable_and_dtype( input, 'input', @@ -7284,7 +7311,12 @@ def roi_align(input, sampling_ratio=-1, rois_num=rois_num) """ - if _non_static_mode(): + if in_dygraph_mode(): + assert rois_num is not None, "rois_num should not be None in dygraph mode." + return _C_ops.final_state_roi_align( + input, rois, rois_num, pooled_height, pooled_width, spatial_scale, + sampling_ratio, False) + if _in_legacy_dygraph(): assert rois_num is not None, "rois_num should not be None in dygraph mode." align_out = _C_ops.roi_align( input, rois, rois_num, "pooled_height", pooled_height, diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 95db9d39c1e..bb14fb9a86f 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2848,7 +2848,11 @@ class AdamaxOptimizer(Optimizer): beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param) if framework._non_static_mode(): - tmp = _C_ops.scale(beta1_pow_acc, "scale", self._beta1) + if framework.in_dygraph_mode(): + tmp = _C_ops.final_state_scale(beta1_pow_acc, + self._beta1, 0.0, True) + else: + tmp = _C_ops.scale(beta1_pow_acc, "scale", self._beta1) beta1_pow_acc.copy_(tmp, False) else: block.append_op( diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index d58ef6ddd52..ed28a2813e2 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -134,7 +134,11 @@ class L2DecayRegularizer(WeightDecayRegularizer): assert isinstance(block, framework.Block) if framework._non_static_mode(): - return _C_ops.scale(param, "scale", self._regularization_coeff) + if framework.in_dygraph_mode(): + return _C_ops.final_state_scale( + param, self._regularization_coeff, 0.0, True) + else: + return _C_ops.scale(param, "scale", self._regularization_coeff) else: decay = block.create_var( dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 62f034c7b41..ca3ac177282 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -37,7 +37,7 @@ from paddle.utils import deprecated from paddle import _C_ops from paddle import in_dynamic_mode from paddle.framework import core -from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode +from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode, _current_expected_place __all__ = [] @@ -116,13 +116,13 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean', if in_dygraph_mode(): out = _C_ops.final_state_bce_loss(input, label) if weight is not None: - out = _C_ops.elementwise_mul(out, weight, 'axis', -1) + out = _C_ops.final_state_multiply(out, weight, 'axis', -1) if reduction == 'sum': return _C_ops.reduce_sum(out, 'dim', [0], 'keep_dim', False, "reduce_all", True) elif reduction == 'mean': - return _C_ops.mean(out) + return _C_ops.final_state_mean_all(out) else: return out else: @@ -260,14 +260,17 @@ def binary_cross_entropy_with_logits(logit, % reduction) if _non_static_mode(): - one = _varbase_creator(dtype=logit.dtype) - _C_ops.fill_constant(one, 'value', - float(1.0), 'force_cpu', False, 'dtype', one.dtype, - 'str_value', '1.0', 'shape', [1]) if in_dygraph_mode(): + one = _C_ops.final_state_full([1], + float(1.0), core.VarDesc.VarType.FP32, + _current_expected_place()) out = _C_ops.final_state_sigmoid_cross_entropy_with_logits( logit, label, False, -100) else: + one = _varbase_creator(dtype=logit.dtype) + _C_ops.fill_constant(one, 'value', + float(1.0), 'force_cpu', False, 'dtype', + one.dtype, 'str_value', '1.0', 'shape', [1]) out = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) if pos_weight is not None: log_weight = _C_ops.elementwise_add( @@ -405,7 +408,7 @@ def hsigmoid_loss(input, # [2.2407534]] """ - if in_dynamic_mode(): + if _non_static_mode(): out, _, _ = _C_ops.hierarchical_sigmoid( input, weight, label, path_table, path_code, bias, 'num_classes', num_classes, 'is_sparse', is_sparse, 'remote_prefetch', is_sparse) @@ -582,7 +585,19 @@ def margin_ranking_loss(input, raise ValueError( "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but " "received %s, which is not allowed." % reduction) - if in_dynamic_mode(): + if in_dygraph_mode(): + out = _C_ops.final_state_subtract(other, input) + out = _C_ops.final_state_multiply(out, label) + if margin != 0.0: + margin = fluid.dygraph.base.to_variable([margin], dtype=out.dtype) + out = _C_ops.elementwise_add(out, margin) + out = _C_ops.relu(out) + if reduction == 'sum': + return _C_ops.reduce_sum(out, 'reduce_all', True) + elif reduction == 'mean': + return _C_ops.final_state_mean_all(out) + return out + elif _in_legacy_dygraph(): out = _C_ops.elementwise_sub(other, input) out = _C_ops.elementwise_mul(out, label) if margin != 0.0: @@ -698,7 +713,17 @@ def l1_loss(input, label, reduction='mean', name=None): "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but " "received %s, which is not allowed." % reduction) - if in_dynamic_mode(): + if in_dygraph_mode(): + unreduced = _elementwise_op_in_dygraph( + input, label, axis=-1, act='abs', op_name='elementwise_sub') + if reduction == 'mean': + return _C_ops.final_state_mean_all(unreduced) + elif reduction == 'sum': + return _C_ops.reduce_sum(unreduced, 'dim', [0], 'keep_dim', False, + 'reduce_all', True) + else: + return unreduced + elif in_dynamic_mode(): unreduced = _elementwise_op_in_dygraph( input, label, axis=-1, act='abs', op_name='elementwise_sub') if reduction == 'mean': @@ -1819,7 +1844,10 @@ def cross_entropy(input, 'reduce_all', True) return out_sum / (total_weight + (total_weight == 0.0)) else: - return _C_ops.mean(out) + if in_dygraph_mode(): + return _C_ops.final_state_mean_all(out) + else: + return _C_ops.mean(out) else: if input_dims - 1 == label_dims: @@ -2064,6 +2092,8 @@ def sigmoid_focal_loss(logit, if reduction == "sum": return _C_ops.reduce_sum(loss, 'reduce_all', True) elif reduction == "mean": + if in_dygraph_mode(): + return _C_ops.final_state_mean_all(loss) return _C_ops.mean(loss) return loss @@ -2179,7 +2209,7 @@ def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None): "'reduction' in 'hinge_embedding_loss' should be 'sum', 'mean' or 'none', " "but received {}.".format(reduction)) - if not in_dynamic_mode(): + if not _non_static_mode(): check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'hinge_embedding_loss') check_variable_and_dtype(label, 'label', ['float32', 'float64'], diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 7e40c029a02..d4e059b6dfa 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -18,7 +18,7 @@ import numpy as np import paddle.fluid as fluid import paddle from .. import functional as F -from paddle.fluid.framework import _varbase_creator +from paddle.fluid.framework import _varbase_creator, in_dygraph_mode, _in_legacy_dygraph from .. import Layer from paddle import in_dynamic_mode @@ -597,7 +597,11 @@ class MSELoss(Layer): fluid.data_feeder.check_variable_and_dtype( label, 'label', ['float32', 'float64'], 'MSELoss') - square_out = paddle.square(paddle.subtract(input, label)) + if in_dygraph_mode(): + square_out = paddle._C_ops.final_state_square( + paddle.subtract(input, label)) + else: + square_out = paddle.square(paddle.subtract(input, label)) if self.reduction == 'none': return square_out diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 0af8b8bb894..0dfe294c00d 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -987,7 +987,9 @@ class Optimizer(object): assert regularization_term is not None - if framework._non_static_mode(): + if framework.in_dygraph_mode(): + return _C_ops.final_state_add_n([grad, regularization_term]) + elif framework._in_legacy_dygraph(): return _C_ops.sum([grad, regularization_term]) new_grad = grad diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 9c2074bbe3c..098f17e6759 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -758,10 +758,13 @@ def cond(x, p=None, name=None): axis = axis if axis != None and axis != [] else [0] keepdim = False - if paddle.in_dynamic_mode(): + if _non_static_mode(): abs_out = _C_ops.abs(input) - sum_out = _C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim', - keepdim, 'reduce_all', reduce_all) + if in_dygraph_mode(): + sum_out = _C_ops.final_state_sum(abs_out, axis, None, keepdim) + else: + sum_out = _C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim', + keepdim, 'reduce_all', reduce_all) if porder == 1 or porder == np.inf: return _C_ops.reduce_max(sum_out, 'dim', [-1], 'keepdim', keepdim, 'reduce_all', reduce_all) @@ -815,7 +818,12 @@ def cond(x, p=None, name=None): reduce_all = True if axis is None or axis == [] else False keepdim = False - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + pow_out = _C_ops.pow(input, 'factor', porder) + sum_out_1 = _C_ops.final_state_sum(pow_out, axis, None, keepdim) + sum_out_2 = _C_ops.final_state_sum(sum_out_1, axis, None, keepdim) + return _C_ops.pow(sum_out_2, 'factor', float(1. / porder)) + elif paddle.in_dynamic_mode(): pow_out = _C_ops.pow(input, 'factor', porder) sum_out_1 = _C_ops.reduce_sum(pow_out, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all) @@ -869,10 +877,13 @@ def cond(x, p=None, name=None): u, s, vh = svd(input, full_matrices=False) - if paddle.in_dynamic_mode(): + if _non_static_mode(): if porder == "nuc": - return _C_ops.reduce_sum(s, 'dim', axis, 'keepdim', keepdim, - 'reduce_all', reduce_all) + if in_dygraph_mode(): + return _C_ops.final_state_sum(s, axis, None, keepdim) + else: + return _C_ops.reduce_sum(s, 'dim', axis, 'keepdim', keepdim, + 'reduce_all', reduce_all) max_out = _C_ops.reduce_max(s, 'dim', axis, 'keepdim', keepdim, 'reduce_all', reduce_all) min_out = _C_ops.reduce_min(s, 'dim', axis, 'keepdim', keepdim, @@ -2530,7 +2541,7 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None): # or out * x * out = x ; """ - if paddle.in_dynamic_mode(): + if _non_static_mode(): if not hermitian: # combine svd and matmul op u, s, vt = _C_ops.svd(x, 'full_matrices', False) @@ -2554,8 +2565,11 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None): v, _ = _C_ops.transpose2(vt, 'axis', perm) out_1 = v * st - out_2 = _C_ops.matmul_v2(out_1, u, 'trans_x', False, 'trans_y', - True) + if in_dygraph_mode(): + out_2 = _C_ops.final_state_matmul(out_1, u, False, True) + else: + out_2 = _C_ops.matmul_v2(out_1, u, 'trans_x', False, 'trans_y', + True) return out_2 else: # combine eigh and matmul op @@ -2578,8 +2592,11 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None): out_1 = u * st u_conj = _C_ops.conj(u) - out_2 = _C_ops.matmul_v2(out_1, u_conj, 'trans_x', False, 'trans_y', - True) + if in_dygraph_mode(): + out_2 = _C_ops.final_state_matmul(out_1, u_conj, False, True) + else: + out_2 = _C_ops.matmul_v2(out_1, u_conj, 'trans_x', False, + 'trans_y', True) return out_2 else: if not hermitian: @@ -3080,7 +3097,7 @@ def lstsq(x, y, rcond=None, driver=None, name=None): elif x.dtype == paddle.float64: rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) - if paddle.in_dynamic_mode(): + if _non_static_mode(): solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond, "driver", driver) if x.shape[-2] > x.shape[-1]: @@ -3089,8 +3106,11 @@ def lstsq(x, y, rcond=None, driver=None, name=None): False) minus_out = _C_ops.elementwise_sub(matmul_out, y) pow_out = _C_ops.pow(minus_out, 'factor', 2) - residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim', - False, 'reduce_all', False) + if in_dygraph_mode(): + residuals = _C_ops.final_state_sum(pow_out, [-2], None, False) + else: + residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim', + False, 'reduce_all', False) else: residuals = paddle.empty(shape=[0], dtype=x.dtype) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index cfc9abb8698..6bbeb4e77be 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -424,6 +424,10 @@ OP_NAMEMAPPING = { 'elementwise_pow': 'final_state_elementwise_pow', 'elementwise_floordiv': 'final_state_floor_divide', 'elementwise_mod': 'final_state_modulo', + 'elementwise_add': 'final_state_add', + 'elementwise_sub': 'final_state_subtract', + 'elementwise_mul': 'final_state_multiply', + 'elementwise_div': 'final_state_divide', } @dygraph_only @@ -436,7 +440,7 @@ def _elementwise_op_in_dygraph(x, def is_inplace(op_name): return op_name[-1] == "_" - if op_name not in OP_NAMEMAPPING.keys(): + if op_name not in OP_NAMEMAPPING.keys() or axis != -1: op = getattr(_C_ops, op_name) out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn) else: @@ -1528,7 +1532,9 @@ def mm(input, mat2, name=None): """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_matmul(input, mat2, False, False) + elif paddle.in_dynamic_mode(): return _C_ops.matmul_v2(input, mat2) def __check_input(x, y): @@ -1751,7 +1757,9 @@ def inner(x, y, name=None): nx = x.reshape((-1, xshape[-1])) ny = y.reshape((-1, yshape[-1])) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_matmul(nx, ny.T, False, False).reshape(dstshape) + elif paddle.in_dynamic_mode(): return _C_ops.matmul_v2(nx, ny.T).reshape(dstshape) def __check_input(x, y): @@ -1814,7 +1822,9 @@ def outer(x, y, name=None): nx = x.reshape((-1, 1)) ny = y.reshape((1, -1)) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_matmul(nx, ny, False, False) + elif paddle.in_dynamic_mode(): return _C_ops.matmul_v2(nx, ny) def __check_input(x, y): @@ -3965,7 +3975,11 @@ def rad2deg(x, name=None): # [57.29578018]) """ rad2deg_scale = 180 / np.pi - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if convert_dtype(x.dtype) in ['int32', 'int64']: + x = cast(x, dtype="float32") + return _C_ops.final_state_scale(x, rad2deg_scale, 0.0, True) + elif paddle.in_dynamic_mode(): if convert_dtype(x.dtype) in ['int32', 'int64']: x = cast(x, dtype="float32") return _C_ops.scale(x, 'scale', rad2deg_scale) @@ -4018,7 +4032,11 @@ def deg2rad(x, name=None): # [3.14159274]) """ deg2rad_scale = np.pi / 180.0 - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + if convert_dtype(x.dtype) in ['int32', 'int64']: + x = cast(x, dtype="float32") + return _C_ops.final_state_scale(x, deg2rad_scale, 0.0, True) + elif paddle.in_dynamic_mode(): if convert_dtype(x.dtype) in ['int32', 'int64']: x = cast(x, dtype="float32") return _C_ops.scale(x, 'scale', deg2rad_scale) @@ -4263,14 +4281,22 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): attrs_1 += ('starts', starts_1) ends_1 = [dim_len - 1] attrs_1 += ('ends', ends_1) - input_front = _C_ops.slice(new_input, None, None, None, None, 'axes', axes, \ - 'infer_flags', infer_flags, *attrs_1) + if in_dygraph_mode(): + input_front = _C_ops.final_state_slice(new_input, axes, starts_1, ends_1, infer_flags, + []) + else: + input_front = _C_ops.slice(new_input, None, None, None, None, 'axes', axes, \ + 'infer_flags', infer_flags, *attrs_1) starts_2 = [1] attrs_2 += ('starts', starts_2) ends_2 = [dim_len] attrs_2 += ('ends', ends_2) - input_back = _C_ops.slice(new_input, None, None, None, None, 'axes', axes, \ - 'infer_flags', infer_flags, *attrs_2) + if in_dygraph_mode(): + input_back = input_front = _C_ops.final_state_slice(new_input, axes, starts_2, ends_2, infer_flags, + []) + else: + input_back = _C_ops.slice(new_input, None, None, None, None, 'axes', axes, \ + 'infer_flags', infer_flags, *attrs_2) if x.dtype == paddle.bool: op = getattr(_C_ops, "logical_xor") diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 6855b8f0f70..04704981c89 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -398,7 +398,9 @@ def nonzero(x, as_tuple=False): shape = x.shape rank = len(shape) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + outs = _C_ops.final_state_where_index(x) + elif paddle.in_dynamic_mode(): outs = _C_ops.where_index(x) else: helper = LayerHelper("where_index", **locals()) -- GitLab