diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e02f0268b5e510ac8262543db58ee98ef20e517 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +namespace paddle { +namespace operators { + +template +class ReduceMaxMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + int out_dtype = context.Attr("out_dtype"); + bool reduce_all = context.Attr("reduce_all"); + auto dims = context.Attr>("dim"); + auto input_dims = framework::vectorize(input->dims()); + const auto& input_dim_size = input->dims().size(); + std::vector reduce_dims; + if (reduce_all) { + for (size_t i = 0; i < input_dims.size(); i++) { + reduce_dims.push_back(static_cast(i)); + } + } else { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + reduce_dims.push_back(dims[i] + input_dim_size); + } else { + reduce_dims.push_back(dims[i]); + } + } + } + + auto place = context.GetPlace(); + framework::Tensor cast_out(input->type()); + cast_out.Resize(output->dims()); + cast_out.mutable_data(place); + + auto cast_out_dtype = framework::TransToProtoVarType(input->dtype()); + + if (out_dtype != -1) { + cast_out_dtype = static_cast(out_dtype); + } + if (framework::TransToProtoVarType(input->type()) != cast_out_dtype) { + if (cast_out_dtype == framework::proto::VarType::FP32) { + output->mutable_data(place); + } else if (cast_out_dtype == framework::proto::VarType::FP16) { + output->mutable_data(place); + } else if (cast_out_dtype == framework::proto::VarType::INT32) { + output->mutable_data(place); + } + } else { + output->ShareDataWith(cast_out); + } + + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->dtype())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->dtype())); + + MLUCnnlReduceDesc reduction_desc( + reduce_dims, CNNL_REDUCE_MAX, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), + nullptr, input_desc.get(), GetBasePtr(input), + 0 /*indices_size*/, nullptr, nullptr, output_desc.get(), + GetBasePtr(output)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(reduce_max, ops::ReduceMaxMLUKernel, + ops::ReduceMaxMLUKernel, + ops::ReduceMaxMLUKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_min_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..daf5965fd54628a097ad1d53057ec54b9a5d329a --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_min_op_mlu.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" + +namespace paddle { +namespace operators { + +template +class ReduceMinMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + int out_dtype = context.Attr("out_dtype"); + bool reduce_all = context.Attr("reduce_all"); + auto dims = context.Attr>("dim"); + auto input_dims = framework::vectorize(input->dims()); + const auto& input_dim_size = input->dims().size(); + std::vector reduce_dims; + if (reduce_all) { + for (size_t i = 0; i < input_dims.size(); i++) { + reduce_dims.push_back(static_cast(i)); + } + } else { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + reduce_dims.push_back(dims[i] + input_dim_size); + } else { + reduce_dims.push_back(dims[i]); + } + } + } + + auto place = context.GetPlace(); + framework::Tensor cast_out(input->type()); + cast_out.Resize(output->dims()); + cast_out.mutable_data(place); + + auto cast_out_dtype = framework::TransToProtoVarType(input->dtype()); + + if (out_dtype != -1) { + cast_out_dtype = static_cast(out_dtype); + } + if (framework::TransToProtoVarType(input->type()) != cast_out_dtype) { + if (cast_out_dtype == framework::proto::VarType::FP32) { + output->mutable_data(place); + } else if (cast_out_dtype == framework::proto::VarType::FP16) { + output->mutable_data(place); + } else if (cast_out_dtype == framework::proto::VarType::INT32) { + output->mutable_data(place); + } + } else { + output->ShareDataWith(cast_out); + } + + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->dtype())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->dtype())); + + MLUCnnlReduceDesc reduction_desc( + reduce_dims, CNNL_REDUCE_MIN, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), + nullptr, input_desc.get(), GetBasePtr(input), + 0 /*indices_size*/, nullptr, nullptr, output_desc.get(), + GetBasePtr(output)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(reduce_min, ops::ReduceMinMLUKernel, + ops::ReduceMinMLUKernel, + ops::ReduceMinMLUKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index ca3575f5dea84321e4fb46cbaa5606652ef267d4..eb39f069e56b73aa34e3323768a4b72cd6c737f4 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -541,11 +541,12 @@ class ReduceOp : public framework::OperatorWithKernel { #endif if (input_data_type == framework::proto::VarType::FP16) { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) || - platform::is_npu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument( - "float16 can only be used on GPU or NPU place")); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()) || + platform::is_npu_place(ctx.GetPlace()) || + platform::is_mlu_place(ctx.GetPlace()), + true, platform::errors::InvalidArgument( + "float16 can only be used on GPU or NPU or MLU place")); } return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reduce_max_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reduce_max_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..ef33719d368e8b2ecb248f4e14f6bd6a031c26bb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_reduce_max_op_mlu.py @@ -0,0 +1,170 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + +paddle.enable_static() + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestMLUReduceMaxOp(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'dim': [-1]} + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output_with_place(self.place) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxOpMultiAxises(TestMLUReduceMaxOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'dim': [-2, -1]} + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceAll(TestMLUReduceMaxOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'reduce_all': True} + self.outputs = {'Out': self.inputs['X'].max()} + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxOpWithOutDtype_int32(TestMLUReduceMaxOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.INT32) + } + self.outputs = { + 'Out': + self.inputs['X'].max(axis=tuple(self.attrs['dim'])).astype(np.int32) + } + + def init_dtype(self): + self.dtype = np.int32 + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxOpWithOutDtype_fp16(TestMLUReduceMaxOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.FP16) + } + self.outputs = { + 'Out': self.inputs['X'].max( + axis=tuple(self.attrs['dim'])).astype(np.float16) + } + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxOpWithOutDtype_fp32(TestMLUReduceMaxOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.outputs = { + 'Out': self.inputs['X'].max( + axis=tuple(self.attrs['dim'])).astype(np.float32) + } + + def init_dtype(self): + self.dtype = np.float32 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reduce_min_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reduce_min_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..284f8f984c232d3c178bcb13977b958ac4775a30 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_reduce_min_op_mlu.py @@ -0,0 +1,170 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + +paddle.enable_static() + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestMLUReduceMinOp(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'dim': [-1]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output_with_place(self.place) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMinOpMultiAxises(TestMLUReduceMinOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'dim': [-2, -1]} + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceAll(TestMLUReduceMinOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = {'reduce_all': True} + self.outputs = {'Out': self.inputs['X'].min()} + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMinOpWithOutDtype_int32(TestMLUReduceMinOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.INT32) + } + self.outputs = { + 'Out': + self.inputs['X'].min(axis=tuple(self.attrs['dim'])).astype(np.int32) + } + + def init_dtype(self): + self.dtype = np.int32 + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMinOpWithOutDtype_fp16(TestMLUReduceMinOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.FP16) + } + self.outputs = { + 'Out': self.inputs['X'].min( + axis=tuple(self.attrs['dim'])).astype(np.float16) + } + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMinOpWithOutDtype_fp32(TestMLUReduceMinOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.set_mlu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.outputs = { + 'Out': self.inputs['X'].min( + axis=tuple(self.attrs['dim'])).astype(np.float32) + } + + def init_dtype(self): + self.dtype = np.float32 + + +if __name__ == '__main__': + unittest.main()