From 74cc73bb1be46690378a9b3e078c6d642cdf6091 Mon Sep 17 00:00:00 2001 From: qipengh Date: Fri, 17 Jun 2022 17:48:22 +0800 Subject: [PATCH] [MLU]add elementwise op (#43491) --- .../elementwise/elementwise_min_op_mlu.cc | 139 +++++++++++ .../operators/elementwise/elementwise_mlu.h | 10 + .../platform/device/mlu/device_context.cc | 6 +- .../platform/device/mlu/device_context.h | 1 + paddle/fluid/platform/device/mlu/mlu_info.cc | 7 + paddle/fluid/platform/device/mlu/mlu_info.h | 3 + .../mlu/test_elementwise_min_op_mlu.py | 230 ++++++++++++++++++ 7 files changed, 395 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/elementwise/elementwise_min_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_elementwise_min_op_mlu.py diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_min_op_mlu.cc new file mode 100644 index 00000000000..7ddf9cd679e --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_min_op_mlu.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/elementwise/elementwise_mlu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ElementwiseMinMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + MLUBinaryOp(ctx); + } +}; + +template +class ElementwiseMinGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) + : axis); + int max_dim = std::max(x_dims.size(), y_dims.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), max_dim, + axis); + + // mask = LessEqual(x, y) + Tensor mask(x->dtype()); + mask.Resize(phi::make_ddim(out_dims_array)); + mask.mutable_data(ctx.GetPlace()); + + cnnlDataType_t data_type = ToCnnlDataType(); + MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type); + MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type); + MLUCnnlTensorDesc mask_desc(max_dim, out_dims_array.data(), data_type); + MLUCnnl::Logic(ctx, CNNL_LOGIC_OP_LE, x_desc.get(), GetBasePtr(x), + y_desc.get(), GetBasePtr(y), mask_desc.get(), + GetBasePtr(&mask)); + + // dx = Mul(dz, mask) + Tensor dx_temp(x->dtype()); + dx_temp.Resize(dout->dims()); + dx_temp.mutable_data(ctx.GetPlace()); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, data_type, + CNNL_NOT_PROPAGATE_NAN); + MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(), GetBasePtr(dout), + dout_desc.get(), GetBasePtr(&mask), dout_desc.get(), + GetBasePtr(&dx_temp), data_type); + + // dy = Sub(dz, dx) + Tensor dy_temp(y->dtype()); + dy_temp.Resize(dout->dims()); + dy_temp.mutable_data(ctx.GetPlace()); + MLUCnnlOpTensorDesc sub_op_desc(CNNL_OP_TENSOR_SUB, data_type, + CNNL_NOT_PROPAGATE_NAN); + MLUCnnl::OpTensor(ctx, sub_op_desc.get(), dout_desc.get(), GetBasePtr(dout), + dout_desc.get(), GetBasePtr(&dx_temp), dout_desc.get(), + GetBasePtr(&dy_temp), data_type); + + if (dx) { + if (dx->dims() != dout->dims()) { + dx->mutable_data(ctx.GetPlace()); + std::vector reduce_axes; + GetReduceAxes(axis, dx_temp.dims(), dx->dims(), &reduce_axes); + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dx_desc(*dx); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(&dx_temp), 0, + nullptr, nullptr, dx_desc.get(), GetBasePtr(dx)); + } else { + dx->ShareDataWith(dx_temp); + } + } + + if (dy) { + if (dy->dims() != dout->dims()) { + dy->mutable_data(ctx.GetPlace()); + std::vector reduce_axes; + GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes); + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dy_desc(*dy); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0, + nullptr, nullptr, dy_desc.get(), GetBasePtr(dy)); + } else { + dy->ShareDataWith(dy_temp); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(elementwise_min, ops::ElementwiseMinMLUKernel, + ops::ElementwiseMinMLUKernel, + ops::ElementwiseMinMLUKernel); + +REGISTER_OP_MLU_KERNEL(elementwise_min_grad, + ops::ElementwiseMinGradMLUKernel, + ops::ElementwiseMinGradMLUKernel, + ops::ElementwiseMinGradMLUKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_mlu.h b/paddle/fluid/operators/elementwise/elementwise_mlu.h index a6a153c34d4..ea3de211ada 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mlu.h +++ b/paddle/fluid/operators/elementwise/elementwise_mlu.h @@ -109,6 +109,7 @@ enum BINARY_FUNCTOR { DIV, DIVNONAN, MAXIMUM, + MINIMUM, }; template @@ -137,6 +138,15 @@ inline void MLUBinary( MLUCnnl::Maximum(ctx, x_desc, x, y_desc, y, out_desc, out); } +template <> +inline void MLUBinary( + const framework::ExecutionContext& ctx, cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t in1_desc, const void* in1, + const cnnlTensorDescriptor_t in2_desc, const void* in2, + const cnnlTensorDescriptor_t out_desc, void* out) { + MLUCnnl::Minimum(ctx, in1_desc, in1, in2_desc, in2, out_desc, out); +} + template void MLUBinaryOp(const framework::ExecutionContext& ctx) { auto* x = ctx.Input("X"); diff --git a/paddle/fluid/platform/device/mlu/device_context.cc b/paddle/fluid/platform/device/mlu/device_context.cc index e737432ecb4..c3c5546a12a 100644 --- a/paddle/fluid/platform/device/mlu/device_context.cc +++ b/paddle/fluid/platform/device/mlu/device_context.cc @@ -40,6 +40,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { compute_capability_ = GetMLUComputeCapability(place_.device); driver_version_ = GetMLUDriverVersion(place_.device); runtime_version_ = GetMLURuntimeVersion(place_.device); + cnnl_version_ = GetMLUCnnlVersion(place_.device); LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device << ", MLU Compute Capability: " @@ -50,7 +51,10 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { << driver_version_ % 100 << ", Runtime API Version: " << runtime_version_ / 10000 << "." << (runtime_version_ / 100) % 100 << "." - << runtime_version_ % 100; + << runtime_version_ % 100 + << ", Cnnl API Version: " << cnnl_version_ / 10000 + << "." << (cnnl_version_ / 100) % 100 << "." + << cnnl_version_ % 100; default_ctx_.reset(new MLUContext(place_)); } diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h index d607b1e12f5..d8bb7623159 100644 --- a/paddle/fluid/platform/device/mlu/device_context.h +++ b/paddle/fluid/platform/device/mlu/device_context.h @@ -134,6 +134,7 @@ class MLUDeviceContext : public DeviceContext { int compute_capability_; int driver_version_; int runtime_version_; + int cnnl_version_; MLUPlace place_; std::shared_ptr default_ctx_; diff --git a/paddle/fluid/platform/device/mlu/mlu_info.cc b/paddle/fluid/platform/device/mlu/mlu_info.cc index e3672707210..63b45495cd2 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.cc +++ b/paddle/fluid/platform/device/mlu/mlu_info.cc @@ -116,6 +116,13 @@ int GetMLURuntimeVersion(int id) { return x * 10000 + y * 100 + z; } +int GetMLUCnnlVersion(int id) { + CheckDeviceId(id); + int x, y, z; + cnnlGetLibVersion(&x, &y, &z); + return x * 10000 + y * 100 + z; +} + int GetMLUCurrentDeviceId() { int device_id; PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id)); diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h index 12c206ef2c4..6ee754fcc89 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.h +++ b/paddle/fluid/platform/device/mlu/mlu_info.h @@ -46,6 +46,9 @@ int GetMLUDriverVersion(int id); //! Get the runtime version of the ith MLU. int GetMLURuntimeVersion(int id); +//! Get the cnnl version of the ith MLU. +int GetMLUCnnlVersion(int id); + //! Get the total number of MLU devices in system. int GetMLUDeviceCount(); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_elementwise_min_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_min_op_mlu.py new file mode 100644 index 00000000000..f04f0eb781e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_min_op_mlu.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys + +sys.path.append("..") +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle.fluid.core as core + +paddle.enable_static() +SEED = 2022 + + +class TestElementwiseMinOp(OpTest): + + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_min" + self.init_dtype() + self.init_input_output() + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.outputs = {'Out': self.out} + self.attrs = {'axis': self.axis} + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_input_output(self): + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype) + self.y = self.x + self.sgn * np.random.uniform(0.1, 1, [13, 17]).astype( + self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + self.check_grad_with_place(self.place, ['X', 'Y'], + 'Out', + max_relative_error=0.5) + else: + self.check_grad_with_place( + self.place, + ['X', 'Y'], + 'Out', + ) + + def test_check_grad_ingore_x(self): + if self.dtype == np.float16: + self.check_grad_with_place(self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + max_relative_error=0.9) + else: + self.check_grad_with_place( + self.place, + ['Y'], + 'Out', + no_grad_set=set("X"), + ) + + def test_check_grad_ingore_y(self): + if self.dtype == np.float16: + self.check_grad_with_place(self.place, ['X'], + 'Out', + no_grad_set=set("Y"), + max_relative_error=0.1) + else: + self.check_grad_with_place( + self.place, + ['X'], + 'Out', + no_grad_set=set("Y"), + ) + + +class TestElementwiseMinOpFp16(TestElementwiseMinOp): + + def init_dtype(self): + self.dtype = np.float16 + + +class TestElementwiseMinOp_Vector(TestElementwiseMinOp): + + def init_input_output(self): + self.x = np.random.uniform(1, 2, (100, )).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype) + self.y = self.x + self.sgn * np.random.uniform(0.1, 1, (100, )).astype( + self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + + +class TestElementwiseMinOpFp16_Vector(TestElementwiseMinOp_Vector): + + def init_dtype(self): + self.dtype = np.float16 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseMinOp_scalar(TestElementwiseMinOp): + + def init_input_output(self): + self.x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(self.dtype) + self.y = np.array([0.5]).astype(self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseMinOpFp16_scalar(TestElementwiseMinOp_scalar): + + def init_dtype(self): + self.dtype = np.float16 + + +class TestElementwiseMinOp_broadcast(TestElementwiseMinOp): + + def init_input_output(self): + self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype) + self.y = self.x[0, 0, :] + self.sgn * \ + np.random.uniform(1, 2, (100, )).astype(self.dtype) + self.out = np.minimum(self.x, self.y.reshape(1, 1, 100)) + self.axis = -1 + + +class TestElementwiseMinOpFp16_broadcast(TestElementwiseMinOp_broadcast): + + def init_dtype(self): + self.dtype = np.float16 + + +class TestElementwiseMinOpNet(unittest.TestCase): + + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data(name="label", + shape=[32, 1], + dtype='int64') + + c = paddle.minimum(a, b) + + fc_1 = fluid.layers.fc(input=c, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.device.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run(main_prog, + feed={ + "a": a_np, + "b": b_np, + "label": label_np + }, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() -- GitLab