From 2b8b16d715ae528e30b3bb8db66c98a7b6c53d75 Mon Sep 17 00:00:00 2001
From: furnace <34057289+windstamp@users.noreply.github.com>
Date: Thu, 10 Feb 2022 20:45:43 +0800
Subject: [PATCH] [NPU] add reduce_min (#39019)

Add the Ascend NPU kernel for the reduce_min operator. The kernel honors
the dim/keep_dim/reduce_all attributes, casts the result when out_dtype
is set, and registers float32, float16, int32 and (behind
PADDLE_WITH_ASCEND_INT64) int64, the latter computed in int32 through
NpuOpRunner::TypeAdapter. Unit tests cover single-axis, multi-axis,
reduce_all, out_dtype, and int64 cases.
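
A minimal usage sketch (assuming a single Ascend device available as
NPUPlace(0)) that exercises the new kernel through paddle.min, which
lowers to the reduce_min op in static-graph mode:

    import numpy as np
    import paddle

    paddle.enable_static()
    main_prog = paddle.static.Program()
    start_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, start_prog):
        x = paddle.static.data(name='x', shape=[5, 6, 10], dtype='float32')
        # Reduce over the last two axes; maps to reduce_min with dim=[-2, -1].
        out = paddle.min(x, axis=[-2, -1], keepdim=False)

    exe = paddle.static.Executor(paddle.NPUPlace(0))
    exe.run(start_prog)
    res, = exe.run(main_prog,
                   feed={'x': np.random.random([5, 6, 10]).astype('float32')},
                   fetch_list=[out])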
---
 .../operators/reduce_ops/reduce_min_op_npu.cc | 124 +++++++
 .../unittests/npu/test_reduce_min_op_npu.py   | 300 ++++++++++++++++++
 2 files changed, 424 insertions(+)
 create mode 100644 paddle/fluid/operators/reduce_ops/reduce_min_op_npu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_reduce_min_op_npu.py

diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op_npu.cc b/paddle/fluid/operators/reduce_ops/reduce_min_op_npu.cc
new file mode 100644
index 00000000000..d9a62ce4dc9
--- /dev/null
+++ b/paddle/fluid/operators/reduce_ops/reduce_min_op_npu.cc
@@ -0,0 +1,124 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
+#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename DeviceContext, typename T>
+class ReduceMinNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto dims = ctx.Attr<std::vector<int>>("dim");
+    bool keep_dim = ctx.Attr<bool>("keep_dim");
+    bool reduce_all = ctx.Attr<bool>("reduce_all");
+    int out_dtype = ctx.Attr<int>("out_dtype");
+
+    auto place = ctx.GetPlace();
+
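+    // Hold the reduction result in the input dtype first; it is cast to the
+    // requested "out_dtype" at the end if that differs from the input dtype.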
+    framework::Tensor cast_out(x->type());
+    cast_out.Resize(out->dims());
+    cast_out.mutable_data<T>(place);
+
+    auto cast_out_dtype = x->type();
+    if (out_dtype != -1) {
+      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
+    }
+
+    if (x->type() != cast_out_dtype) {
+      if (cast_out_dtype == framework::proto::VarType::FP32) {
+        out->mutable_data<float>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::FP16) {
+        out->mutable_data<paddle::platform::float16>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::INT16) {
+        out->mutable_data<int16_t>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::INT32) {
+        out->mutable_data<int32_t>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::INT64) {
+        out->mutable_data<int64_t>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::FP64) {
+        out->mutable_data<double>(place);
+      } else if (cast_out_dtype == framework::proto::VarType::BOOL) {
+        out->mutable_data<bool>(place);
+      }
+    } else {
+      out->ShareDataWith(cast_out);
+    }
+
+    framework::NPUAttributeMap attr_input = {{"axes", dims},
+                                             {"keep_dims", keep_dim}};
+
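+    // When "reduce_all" is set, ignore "dim" and reduce over every axis.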
+    if (reduce_all) {
+      std::vector<int> dim_vec;
+      for (int i = 0; i < x->dims().size(); i++) {
+        dim_vec.push_back(i);
+      }
+
+      attr_input = {{"axes", dim_vec}, {"keep_dims", keep_dim}};
+    }
+
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
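+    // ReduceMinD does not take int64 input directly, so int64 is computed in
+    // int32 through NpuOpRunner::TypeAdapter and cast back to int64.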
+    if (x->type() == framework::proto::VarType::INT64) {
+      auto op_func = [](const std::vector<Tensor>& inputs,
+                        const std::vector<Tensor>& outputs,
+                        const NPUAttributeMap& attrs,
+                        const platform::NPUDeviceContext& dev_ctx) {
+        const auto& runner =
+            NpuOpRunner("ReduceMinD", {inputs[0]}, {outputs[0]}, attrs);
+        runner.Run(dev_ctx.stream());
+      };
+
+      NpuOpRunner::TypeAdapter({*x}, {cast_out}, attr_input, dev_ctx, op_func,
+                               {framework::proto::VarType::INT32},
+                               {framework::proto::VarType::INT32});
+    } else {
+      const auto& runner =
+          NpuOpRunner("ReduceMinD", {*x}, {cast_out}, attr_input);
+      runner.Run(dev_ctx.stream());
+    }
+
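+    // Cast the intermediate result to the requested "out_dtype".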
+    if (x->type() != cast_out_dtype) {
+      auto dst_dtype = ConvertToNpuDtype(cast_out_dtype);
+      const auto& runner_cast =
+          NpuOpRunner("Cast", {cast_out}, {*out},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast.Run(dev_ctx.stream());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_NPU_KERNEL(
+    reduce_min, ops::ReduceMinNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, plat::float16>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, int64_t>,
+#endif
+    ops::ReduceMinNPUKernel<plat::NPUDeviceContext, int>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_reduce_min_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_reduce_min_op_npu.py
new file mode 100644
index 00000000000..bbf23e1be3e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_reduce_min_op_npu.py
@@ -0,0 +1,300 @@
+#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
+from paddle.fluid.framework import convert_np_dtype_to_dtype_
+
+paddle.enable_static()
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestNPUReduceMinOp(OpTest):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {'dim': [-1]}
+        self.outputs = {
+            'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpMultiAxises(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {'dim': [-2, -1]}
+        self.outputs = {
+            'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceAll(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {'reduce_all': True}
+        self.outputs = {'Out': self.inputs['X'].min()}
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_bool(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.BOOL)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.bool_)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_int16(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.INT16)
+        }
+
+        self.outputs = {
+            'Out':
+            self.inputs['X'].min(axis=tuple(self.attrs['dim'])).astype(np.int16)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_int32(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {
+            'Out':
+            self.inputs['X'].min(axis=tuple(self.attrs['dim'])).astype(np.int32)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_int64(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.INT64)
+        }
+        self.outputs = {
+            'Out':
+            self.inputs['X'].min(axis=tuple(self.attrs['dim'])).astype(np.int64)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_fp16(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.FP16)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.float16)
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-3)
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_fp32(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.FP32)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.float32)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_fp64(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.FP64)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.float64)
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpWithOutDtype_fp32_2(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.FP32)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.float32)
+        }
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is non-differentiable where the minimum is tied;"
+    " its gradient check is not supported by the unittest framework.")
+class TestReduceMinOpInt64(TestNPUReduceMinOp):
+    """Min has only a subgradient, so only the forward output is checked in CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.INT64)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].min(
+                axis=tuple(self.attrs['dim'])).astype(np.int64)
+        }
+
+    def init_dtype(self):
+        self.dtype = np.int64
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab