diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc index 43d662830c0c8b7b14d7dd023666e8f11f2817a4..f66b75fd1f3197230f9c4d304c49a92d18bbcad0 100644 --- a/paddle/fluid/operators/activation_op_mlu.cc +++ b/paddle/fluid/operators/activation_op_mlu.cc @@ -15,12 +15,8 @@ limitations under the Licnse. */ #include #include -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" -#include "paddle/fluid/platform/device/mlu/device_context.h" -#include "paddle/phi/core/ddim.h" namespace paddle { namespace operators { @@ -38,20 +34,39 @@ class ActivationMLUKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); MLUCnnlActivationDesc act_desc(act_mode, alpha); - MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(input->dtype())); - MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(output->dtype())); - - MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), - reinterpret_cast(input->data()), - output_desc.get(), - reinterpret_cast(output->data())); + MLUCnnlTensorDesc input_desc(*input); + MLUCnnlTensorDesc output_desc(*output); + + MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), GetBasePtr(input), + output_desc.get(), GetBasePtr(output)); } }; +// For gelu, leaky_relu template -class ActivationGradMLUKernel : public framework::OpKernel { +class ActivationGradMLUKernelV1 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); + MLUCnnlActivationDesc act_desc(act_mode, alpha); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, + dout_desc.get(), GetBasePtr(dout), x_desc.get(), + GetBasePtr(x), dx_desc.get(), GetBasePtr(dx)); + } +}; + +// For tanh, sigmoid +template +class ActivationGradMLUKernelV2 : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* out = ctx.Input("Out"); @@ -61,18 +76,35 @@ class ActivationGradMLUKernel : public framework::OpKernel { dx->mutable_data(ctx.GetPlace()); - MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(dout->dtype())); - MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(out->dtype())); - MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(dx->dtype())); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); MLUCnnlActivationDesc act_desc(act_mode, alpha); - MLUCnnl::ActiveGrad( - ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, - dout_desc.get(), reinterpret_cast(dout->data()), - out_desc.get(), reinterpret_cast(out->data()), - dx_desc.get(), reinterpret_cast(dx->data())); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, out_desc.get(), + GetBasePtr(out), dout_desc.get(), GetBasePtr(dout), + nullptr, nullptr, dx_desc.get(), GetBasePtr(dx)); + } +}; + +// For relu, relu6 +template +class ActivationGradMLUKernelV3 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); + MLUCnnlActivationDesc act_desc(act_mode, alpha); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, + dout_desc.get(), GetBasePtr(dout), out_desc.get(), + GetBasePtr(out), dx_desc.get(), GetBasePtr(dx)); } }; @@ -81,10 +113,60 @@ class ActivationGradMLUKernel : public framework::OpKernel { namespace ops = paddle::operators; +// relu REGISTER_OP_MLU_KERNEL( relu, ops::ActivationMLUKernel, ops::ActivationMLUKernel); REGISTER_OP_MLU_KERNEL( - relu_grad, ops::ActivationGradMLUKernel, - ops::ActivationGradMLUKernel); + relu_grad, ops::ActivationGradMLUKernelV3, + ops::ActivationGradMLUKernelV3); + +// relu6 +REGISTER_OP_MLU_KERNEL( + relu6, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + relu6_grad, ops::ActivationGradMLUKernelV3, + ops::ActivationGradMLUKernelV3); + +// sigmoid +REGISTER_OP_MLU_KERNEL(sigmoid, + ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + sigmoid_grad, + ops::ActivationGradMLUKernelV2, + ops::ActivationGradMLUKernelV2); + +// tanh +REGISTER_OP_MLU_KERNEL( + tanh, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + tanh_grad, ops::ActivationGradMLUKernelV2, + ops::ActivationGradMLUKernelV2); + +// gelu +REGISTER_OP_MLU_KERNEL( + gelu, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + gelu_grad, ops::ActivationGradMLUKernelV1, + ops::ActivationGradMLUKernelV1); + +// leaky_relu +REGISTER_OP_MLU_KERNEL( + leaky_relu, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + leaky_relu_grad, + ops::ActivationGradMLUKernelV1, + ops::ActivationGradMLUKernelV1); diff --git a/paddle/fluid/operators/fill_constant_op_mlu.cc b/paddle/fluid/operators/fill_constant_op_mlu.cc index 10e7c72d158e6f6ec16d9d62886e9efe82145c9d..f7463c5dd8821385bf067780aa80850eb3765cb2 100644 --- a/paddle/fluid/operators/fill_constant_op_mlu.cc +++ b/paddle/fluid/operators/fill_constant_op_mlu.cc @@ -51,6 +51,8 @@ class FillConstantMLUKernel : public framework::OpKernel { } } } + const T *value_data = &value; + cnnlPointerMode_t pointer_mode = CNNL_POINTER_MODE_HOST; if (ctx.HasInput("ValueTensor")) { auto *value_tensor = ctx.Input("ValueTensor"); PADDLE_ENFORCE_EQ( @@ -59,22 +61,18 @@ class FillConstantMLUKernel : public framework::OpKernel { "When use Tensor as value to set Tensor value in fill_cosntant, " "value input(ValueTensor) size must be 1, but get %d", value_tensor->numel())); - const T *tensor_data = value_tensor->data(); - framework::Tensor mlu_tensor; + value_data = value_tensor->data(); auto tmp_place = value_tensor->place(); if (platform::is_mlu_place(tmp_place)) { - framework::TensorCopySync(*value_tensor, platform::CPUPlace(), - &mlu_tensor); - tensor_data = mlu_tensor.data(); + pointer_mode = CNNL_POINTER_MODE_DEVICE; } - value = tensor_data[0]; } auto shape = GetShape(ctx); out_var->mutable_data(shape, ctx.GetPlace()); - MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(out_var->dtype())); - MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var)); + MLUCnnlTensorDesc output_desc(*out_var); + MLUCnnl::Fill(ctx, pointer_mode, value_data, output_desc.get(), + GetBasePtr(out_var)); } }; } // namespace operators diff --git a/paddle/fluid/operators/mean_op_mlu.cc b/paddle/fluid/operators/mean_op_mlu.cc index 1fed01194c1a6c4f5743d98e09db1993c8c8e998..1456e749b1343603851548899b6aa522f06a88b8 100644 --- a/paddle/fluid/operators/mean_op_mlu.cc +++ b/paddle/fluid/operators/mean_op_mlu.cc @@ -95,7 +95,8 @@ class MeanMLUGradKernel : public framework::OpKernel { MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY, ToCnnlDataType(mean_var.dtype())); auto value = static_cast(1.0 / static_cast(input_grad->numel())); - MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var)); + MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, mean_var_desc.get(), + GetBasePtr(&mean_var)); // means mul output_grad MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY, diff --git a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc index 1ce02ff4525c9692f88ed42b79ff336cc0113c41..26c31d82e36eb38d97a4d3c94c418c653cb174b0 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc @@ -136,15 +136,17 @@ class AccuracyMLUKernel : public framework::OpKernel { // [total] total->mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc total_desc(*total); - MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples, total_desc.get(), + GetBasePtr(total)); // use `total` of type `float32` for calculating accuracy Tensor total_fp32(framework::TransToPhiDataType(VT::FP32)); total_fp32.Resize(total->dims()); total_fp32.mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc total_fp32_desc(total_fp32); - MLUCnnl::Fill(ctx, static_cast(num_samples), total_fp32_desc.get(), - GetBasePtr(&total_fp32)); + float num_samples_fp32 = static_cast(num_samples); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples_fp32, + total_fp32_desc.get(), GetBasePtr(&total_fp32)); // [accuracy] accuracy->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 1fdaa153e3c27ed1a83696bf03d68dbfd2b93ae9..df091a7dc7535745a0fbe33c77f15478265d2217 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -208,8 +208,20 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { MLUCnnlActivationDesc::MLUCnnlActivationDesc( const cnnlActivationMode_t act_mode, const float ceof) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor( - active_desc_, act_mode, CNNL_NOT_PROPAGATE_NAN, ceof)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4( + active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, ceof, 1.0f /*sliced_dim*/, + 1.67326319217681884765625 /*selu_alpha*/, + 1.05070102214813232421875 /*selu_lambda*/)); +} + +MLUCnnlActivationDesc::MLUCnnlActivationDesc( + const cnnlActivationMode_t act_mode, const float ceof, + const float sliced_dim, const float selu_alpha, const float selu_lambda) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4( + active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, ceof, sliced_dim, selu_alpha, selu_lambda)); } const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const { @@ -541,12 +553,15 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { output_desc, output)); } -/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, float value, +/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, + const cnnlPointerMode_t pointer_mode, + const void* value_ptr, const cnnlTensorDescriptor_t output_desc, void* output) { cnnlHandle_t handle = GetHandleFromCTX(ctx); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlFill(handle, value, output_desc, output)); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlFill_v3(handle, pointer_mode, value_ptr, output_desc, output)); } /* static */ void MLUCnnl::QuantifyOffline( diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index b55b10686e92e2b1b5b3a7390289f8329ac04a04..64a99b2a6d27365d624ba42d43907fa770d25648 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -218,6 +218,9 @@ class MLUCnnlActivationDesc { MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete; MLUCnnlActivationDesc& operator=(const MLUCnnlActivationDesc& desc) = delete; MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof); + MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof, + const float sliced_dim, const float selu_alpha, + const float selu_lambda); const cnnlActivationDescriptor_t get() const; ~MLUCnnlActivationDesc(); @@ -418,7 +421,8 @@ class MLUCnnl { const cnnlTensorDescriptor_t in1_desc, const void* in1, const cnnlTensorDescriptor_t output_desc, void* output); - static void Fill(const ExecutionContext& ctx, float value, + static void Fill(const ExecutionContext& ctx, + const cnnlPointerMode_t pointer_mode, const void* value_ptr, const cnnlTensorDescriptor_t output_desc, void* output); static void LRN(const ExecutionContext& ctx, const int local_size, diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc index e5399ee36ba7ff4a983448d607c108db8870138c..b84a2bc579d3e7b9a1c9d594f9316c2ff38aff72 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc @@ -69,7 +69,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { "the same Tensors.")); } - auto mu = ctx.Attr("mu"); + auto mu = static_cast(ctx.Attr("mu")); auto lrs = ctx.MultiInput("LearningRate"); if (lrs.size() != 1) { PADDLE_ENFORCE_EQ( @@ -114,7 +114,8 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { Tensor mu_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc mu_tensor_desc(mu_tensor); - MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(), + GetBasePtr(&mu_tensor)); for (size_t idx = 0; idx < n; ++idx) { RegularizationType regularization_flag = diff --git a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc index 91e8aa643b98160badd79f7669b4223fcf3afccb..71af14fd91c8c5f75f469840581052bdc068b2bd 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc @@ -52,7 +52,8 @@ class MLUMomentumOpKernel : public framework::OpKernel { Tensor mu_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc mu_tensor_desc(mu_tensor); - MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(), + GetBasePtr(&mu_tensor)); Tensor regularized_grad; MLUCnnlTensorDesc param_desc(*param); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc index 45f4e43378f4444b959ecd17567d2b70ee9f417f..89e578dbdb6b7cae33ea4911f6497da810bc29ff 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc @@ -103,8 +103,8 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel { ToCnnlDataType(input_grad->dtype())); auto value = static_cast(1.0 / static_cast(reduce_numel)); - MLUCnnl::Fill(context, value, input_grad_desc.get(), - GetBasePtr(input_grad)); + MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, + input_grad_desc.get(), GetBasePtr(input_grad)); MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN); diff --git a/paddle/fluid/operators/scale_op_mlu.cc b/paddle/fluid/operators/scale_op_mlu.cc index 5237e70e319ad3c99efa670f3f8329eacd8d6220..f9e313e64b1e14d92e9ed6d030eacb93e9b0b5bc 100644 --- a/paddle/fluid/operators/scale_op_mlu.cc +++ b/paddle/fluid/operators/scale_op_mlu.cc @@ -27,7 +27,7 @@ class ScaleMLUKernel : public framework::OpKernel { auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var); // cnnl require input, scale, bias with same type. And all in device side. - auto& scale = ctx.Attr("scale"); + auto scale = static_cast(ctx.Attr("scale")); framework::Tensor scale_tensor; if (ctx.HasInput("ScaleTensor")) { framework::Tensor float_scale_tensor = @@ -49,14 +49,16 @@ class ScaleMLUKernel : public framework::OpKernel { } else { scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc scale_desc(scale_tensor); - MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &scale, scale_desc.get(), + GetBasePtr(&scale_tensor)); } - auto& bias = ctx.Attr("bias"); + auto bias = static_cast(ctx.Attr("bias")); framework::Tensor bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc bias_desc(bias_tensor); - MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &bias, bias_desc.get(), + GetBasePtr(&bias_tensor)); auto* out_var = ctx.OutputVar("Out"); if (in_var->IsType() && in_var != out_var) { diff --git a/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..c62d30d43c08984937cafcae4e613528229e8103 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +from scipy import special +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +def np_gelu(x): + y = 0.5 * x * (1 + special.erf(x / np.sqrt(2))) + return y + + +class TestGelu(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "gelu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np_gelu(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.007) + + +class TestGeluFp16(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "gelu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) + out = np_gelu(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestGeluNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + c = paddle.multiply(a, b) + + fc_1 = fluid.layers.fc(input=c, size=128) + fc_1_gelu = fluid.layers.gelu(fc_1) + prediction = fluid.layers.fc(input=fc_1_gelu, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred, atol=1e-3)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss, atol=1e-3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2150fceb133084135b6969fbaf859aa8b01579 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +from test_activation_op import ref_leaky_relu +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestLeadyRelu(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "leaky_relu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + + self.set_inputs() + self.set_attrs() + self.set_outputs() + + def set_inputs(self): + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + + def set_attrs(self): + self.attrs = {} + + def set_outputs(self): + alpha = 0.02 if 'alpha' not in self.attrs else self.attrs['alpha'] + out = ref_leaky_relu(self.inputs['X'], alpha) + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.006) + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestLeadyReluFP16(TestLeadyRelu): + def init_dtype(self): + self.dtype = np.float16 + + +class TestLeadyRelu2(TestLeadyRelu): + def set_attrs(self): + self.attrs = {'alpha': 0.5} + + +class TestLeadyRelu3(TestLeadyRelu): + def set_attrs(self): + self.attrs = {'alpha': -0.5} + + +class TestLeakyReluNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + x_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data(name="x", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + y = paddle.nn.functional.leaky_relu(x) + + fc_1 = fluid.layers.fc(input=y, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run(main_prog, + feed={"x": x_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..54b1afd03633175501f0e670b97a065e18e984bb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle.fluid as fluid +import paddle +from op_test import OpTest + +import numpy as np +import unittest +import sys +sys.path.append("..") + +paddle.enable_static() +SEED = 2021 + + +def ref_relu6(x, threshold=6.0): + out = np.copy(x) + out[np.abs(x - threshold) < 0.005] = threshold + 0.02 + out = np.minimum(np.maximum(x, 0), threshold) + return out + + +class TestRelu6(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "relu6" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype) + x[np.abs(x) < 0.005] = 0.02 + out = ref_relu6(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': 6.0} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def init_dtype(self): + self.dtype = np.float32 + + +class TestRelu6Float16(TestRelu6): + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def set_attrs(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestReluNeg(TestRelu6): + def setUp(self): + self.set_mlu() + self.op_type = "relu6" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(-10, -1, [10, 12]).astype(self.dtype) + x[np.abs(x) < 0.005] = 0.02 + out = ref_relu6(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': 6.0} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestRelu6Net(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.nn.functional.relu6(sum) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c5612377e1c82025debf97f24a9cc529486440 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +from paddle.fluid.tests.unittests.op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestMLUSigmoid(OpTest): + def setUp(self): + self.op_type = "sigmoid" + self.set_mlu() + self.init_dtype() + + np.random.seed(SEED) + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = 1 / (1 + np.exp(-x)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.01) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestMLUSigmoidFp16(TestMLUSigmoid): + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def init_dtype(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aeeac0ffb9e2ffd425e27decf827b229e18d3f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestTanh(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "tanh" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + self.check_grad(['X'], 'Out', max_relative_error=0.009) + else: + self.check_grad(['X'], 'Out', max_relative_error=0.009) + + +class TestTanhFp16(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "tanh" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) + out = np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestTanhNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + c = paddle.multiply(a, b) + d = paddle.tanh(c) + + fc_1 = fluid.layers.fc(input=d, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main()