From 0a29fc85d64dd7c65aee2e4b8d17f2e2b4cbf4b1 Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 2 Sep 2020 10:39:46 +0800 Subject: [PATCH] fix the argmin,argmax op for the paddlepaddle 2.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix the argmin,argmax op for the paddlepaddle 2.0, add checkPoint for the argmax/argmin --- paddle/fluid/operators/arg_max_op.cc | 18 +++++ paddle/fluid/operators/arg_min_max_op_base.h | 35 +++++++-- paddle/fluid/operators/arg_min_op.cc | 18 +++++ .../tests/unittests/test_arg_min_max_v2_op.py | 22 ++++-- python/paddle/tensor/search.py | 71 +++++++------------ 5 files changed, 108 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc index fd7fa17ac9a..a82134921ef 100644 --- a/paddle/fluid/operators/arg_max_op.cc +++ b/paddle/fluid/operators/arg_max_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OPERATOR( @@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL( int16_t>, paddle::operators::ArgMaxKernel); +REGISTER_OP_VERSION(arg_max) + .AddCheckpoint( + R"ROC( + Upgrade argmax add a new attribute [flatten] and modify the attribute of dtype)ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("flatten", + "In order to compute the argmax over the flattened array " + "when the " + "argument `axis` in python API is None.", + false) + .ModifyAttr( + "dtype", + "change the default value of dtype, the older version " + "is -1, means return the int64 indices." + "The new version is 3, return the int64 indices directly." 
+ "And supporting the dtype of -1 in new version.", + 3)); diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index ae3637f6f99..69365357084 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -70,6 +70,8 @@ struct VisitDataArgMinMaxFunctor { auto axis = ctx.Attr("axis"); auto keepdims = ctx.Attr("keepdims"); const bool& flatten = ctx.Attr("flatten"); + // paddle do not have the scalar tensor, just return the shape [1] tensor + if (flatten) keepdims = true; // if flatten, will construct the new dims for the cacluate framework::DDim x_dims; @@ -164,15 +166,30 @@ class ArgMinMaxOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size())); + auto x_rank = x_dims.size(); + if (axis < 0) axis += x_rank; + if (ctx->IsRuntime()) { + const int& dtype = ctx->Attrs().Get("dtype"); + if (dtype == framework::proto::VarType::INT32) { + int64_t all_element_num = 0; + if (flatten) { + all_element_num = framework::product(x_dims); + + } else { + all_element_num = x_dims[axis]; + } + PADDLE_ENFORCE_LE( + all_element_num, INT_MAX, + "The element num of the argmin/argmax input at axis is " + "%d, is larger than int32 maximum value:%d, you must " + "set the dtype of argmin/argmax to 'int64'.", + all_element_num, INT_MAX); + } + } std::vector vec; if (flatten) { - // if is flatten, will return the only on element - if (keepdims) { - vec.emplace_back(static_cast(1)); - } + vec.emplace_back(static_cast(1)); } else { - auto x_rank = x_dims.size(); - if (axis < 0) axis += x_rank; for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]); if (keepdims) { vec.emplace_back(static_cast(1)); @@ -194,10 +211,14 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "Output tensor."); AddAttr("axis", "The axis in which to compute the arg indics."); 
AddAttr("keepdims", "Keep the dim that to reduce.").SetDefault(false); - AddAttr("dtype", "Keep the dim that to reduce.").SetDefault(-1); AddAttr("flatten", "Flatten the input value, and search the min or max indices") .SetDefault(false); + AddAttr("dtype", + "(int, 3), the dtype of indices, the indices dtype must be " + "int32, int64." + "default dtype is int64, and proto value is 3.") + .SetDefault(3); AddComment(string::Sprintf(R"DOC( %s Operator. diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc index 74fc3292746..23ed7d727c5 100644 --- a/paddle/fluid/operators/arg_min_op.cc +++ b/paddle/fluid/operators/arg_min_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OPERATOR( @@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL( int16_t>, paddle::operators::ArgMinKernel); +REGISTER_OP_VERSION(arg_min) + .AddCheckpoint( + R"ROC( + Upgrade argmin add a new attribute [flatten] and modify the attribute of dtype)ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("flatten", + "In order to compute the argmin over the flattened array " + "when the " + "argument `axis` in python API is None.", + false) + .ModifyAttr( + "dtype", + "change the default value of dtype, the older version " + "is -1, means return the int64 indices." + "The new version is 3, return the int64 indices directly." 
+ "And supporting the dtype of -1 in new version.", + 3)); diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py index 7c1f9d802c3..0fd9863948a 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_v2_op.py @@ -218,7 +218,7 @@ def create_test_case(op_type): self.assertTrue("test_arg_api" in result.name) def run_dygraph(self, place): - paddle.disable_static() + paddle.disable_static(place) op = eval("paddle.%s" % (op_type)) data_tensor = paddle.to_tensor(self.input_data) @@ -240,7 +240,7 @@ def create_test_case(op_type): #case 4 result_data = op(data_tensor, axis=-1, keepdim=True) excepted_data = self.numpy_op(self.input_data, axis=-1) - excepted_data = excepted_data.reshape((10)) + excepted_data = excepted_data.reshape((10, 1)) self.assertTrue((result_data.numpy() == excepted_data).all(), True) #case 5 @@ -299,14 +299,28 @@ class TestArgMinMaxOpError(unittest.TestCase): name="test_argmax", shape=[10], dtype="float32") output = paddle.argmax(x=data, dtype="float32") - self.assertRaises(ValueError, test_argmax_attr_type) + self.assertRaises(TypeError, test_argmax_attr_type) def test_argmin_attr_type(): data = paddle.static.data( name="test_argmax", shape=[10], dtype="float32") output = paddle.argmin(x=data, dtype="float32") - self.assertRaises(ValueError, test_argmin_attr_type) + self.assertRaises(TypeError, test_argmin_attr_type) + + def test_argmax_axis_type(): + data = paddle.static.data( + name="test_argmax", shape=[10], dtype="float32") + output = paddle.argmax(x=data, axis=1.2) + + self.assertRaises(TypeError, test_argmax_axis_type) + + def test_argmin_axis_type(): + data = paddle.static.data( + name="test_argmin", shape=[10], dtype="float32") + output = paddle.argmin(x=data, axis=1.2) + + self.assertRaises(TypeError, test_argmin_axis_type) if __name__ == '__main__': diff --git 
a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index eede022e05b..552da3401c6 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -18,7 +18,6 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtyp from ..fluid import core, layers # TODO: define searching & indexing functions of a tensor -from ..fluid.layers import argmin #DEFINE_ALIAS from ..fluid.layers import has_inf #DEFINE_ALIAS from ..fluid.layers import has_nan #DEFINE_ALIAS @@ -124,7 +123,7 @@ def argsort(x, axis=-1, descending=False, name=None): return ids -def argmax(x, axis=None, dtype=None, keepdim=False, name=None): +def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): """ This OP computes the indices of the max elements of the input tensor's element along the provided axis. @@ -135,10 +134,10 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None): axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index. - dtype(str): Data type of the output tensor which can - be int32, int64. The default value is None, and it will - return the int64 indices. keepdim(bool, optional): Keep the axis that selecting max. The defalut value is False. + dtype(str|np.dtype, optional): Data type of the output tensor which can + be int32, int64. The default value is 'int64', and it will + return the int64 indices. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -166,48 +165,39 @@ def argmax(x, axis=None, dtype=None, keepdim=False, name=None): print(out3.numpy()) # [2 3 1] """ + if axis is not None and not isinstance(axis, int): + raise TypeError( + "The type of 'axis' must be int or None in argmax, but received %s." 
+ % (type(axis))) + var_dtype = convert_np_dtype_to_dtype_(dtype) + check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') flatten = False if axis is None: flatten = True axis = 0 if in_dygraph_mode(): - if dtype != None: - var_dtype = convert_np_dtype_to_dtype_(dtype) - out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, - 'keepdim', keepdim, 'flatten', flatten) - else: - out = core.ops.arg_max(x, 'axis', axis, 'keepdim', keepdim, - 'flatten', flatten) + out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims', + keepdim, 'flatten', flatten) return out helper = LayerHelper("argmax", **locals()) check_variable_and_dtype( x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'paddle.argmax') - var_dtype = None attrs = {} - if dtype is not None: - if dtype not in ['int32', 'int64']: - raise ValueError( - "The value of 'dtype' in argmax op must be int32, int64, but received of {}". - format(dtype)) - var_dtype = convert_np_dtype_to_dtype_(dtype) - attrs["dtype"] = var_dtype - else: - var_dtype = VarDesc.VarType.INT64 - out = helper.create_variable_for_type_inference(var_dtype) attrs['keepdims'] = keepdim attrs['axis'] = axis attrs['flatten'] = flatten + attrs['dtype'] = var_dtype helper.append_op( type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs) out.stop_gradient = True return out -def argmin(x, axis=None, dtype=None, keepdim=False, name=None): +def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): """ This OP computes the indices of the min elements of the input tensor's element along the provided axis. @@ -218,10 +208,10 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None): axis(int, optional): Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. when axis < 0, it works the same way as axis + R. Default is None, the input `x` will be into the flatten tensor, and selecting the min value index. 
+ keepdim(bool, optional): Keep the axis that selecting min. The default value is False. dtype(str): Data type of the output tensor which can - be int32, int64. The default value is None, and it will + be int32, int64. The default value is 'int64', and it will return the int64 indices. - keepdim(bool, optional): Keep the axis that selecting min. The defalut value is False. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -249,41 +239,32 @@ def argmin(x, axis=None, dtype=None, keepdim=False, name=None): print(out3.numpy()) # [0 0 2] """ + if axis is not None and not isinstance(axis, int): + raise TypeError( + "The type of 'axis' must be int or None in argmin, but received %s." + % (type(axis))) + var_dtype = convert_np_dtype_to_dtype_(dtype) + check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin') flatten = False if axis is None: flatten = True axis = 0 if in_dygraph_mode(): - if dtype != None: - var_dtype = convert_np_dtype_to_dtype_(dtype) - out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, - 'keepdim', keepdim, 'flatten', flatten) - else: - out = core.ops.arg_min(x, 'axis', axis, 'keepdim', keepdim, - 'flatten', flatten) + out = core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype, 'keepdims', + keepdim, 'flatten', flatten) return out helper = LayerHelper("argmin", **locals()) check_variable_and_dtype( x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'paddle.argmin') - var_dtype = None - attrs = {} - if dtype is not None: - if dtype not in ['int32', 'int64']: - raise ValueError( - "The value of 'dtype' in argmin op must be int32, int64, but received of {}". 
- format(dtype)) - var_dtype = convert_np_dtype_to_dtype_(dtype) - attrs["dtype"] = var_dtype - else: - var_dtype = VarDesc.VarType.INT64 - out = helper.create_variable_for_type_inference(var_dtype) + attrs = {} attrs['keepdims'] = keepdim attrs['axis'] = axis attrs['flatten'] = flatten + attrs['dtype'] = var_dtype helper.append_op( type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs) out.stop_gradient = True -- GitLab