diff --git a/paddle/fluid/operators/dropout_op_mlu.cc b/paddle/fluid/operators/dropout_op_mlu.cc index b88974a51ceff679a60cfe69547857c50e7ba608..f4dbbae05532e086fd4f3be889b1cb097a19b784 100644 --- a/paddle/fluid/operators/dropout_op_mlu.cc +++ b/paddle/fluid/operators/dropout_op_mlu.cc @@ -82,7 +82,7 @@ class DropoutMLUKernel : public framework::OpKernel { *x, ctx.GetPlace(), ctx.template device_context(), out); } else { - float scale = static_cast(1.0f - dropout_prob); + auto scale = static_cast(1.0f - dropout_prob); Tensor scale_tensor(x->dtype()); scale_tensor.mutable_data({1}, ctx.GetPlace()); MLUCnnlTensorDesc scale_desc(scale_tensor); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 56c9dd855734d1a5bb172a9ac28274fb80392ddb..8c907ab0e8dec9915e3c5418d7cc605ca319db0e 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -805,17 +805,17 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { } /* static */ void MLUCnnl::ApplyAdam( - const ExecutionContext& ctx, const cnnlTensorDescriptor_t grad_desc, - const void* grad, const void* lr, const void* beta1, const void* beta2, - const void* beta1_power, const void* beta2_power, const void* epsilon, - const bool use_nesterov, const cnnlTensorDescriptor_t var_desc, void* var, - const cnnlTensorDescriptor_t m_desc, void* m, - const cnnlTensorDescriptor_t v_desc, void* v) { + const ExecutionContext& ctx, const cnnlTensorDescriptor_t var_desc, + void* var, const cnnlTensorDescriptor_t m_desc, void* m, + const cnnlTensorDescriptor_t v_desc, void* v, + const cnnlTensorDescriptor_t grad_desc, const void* grad, const void* lr, + const void* beta1, const void* beta2, const void* beta1_power, + const void* beta2_power, const void* epsilon, const bool use_nesterov) { cnnlHandle_t handle = GetHandleFromCTX(ctx); PADDLE_ENFORCE_MLU_SUCCESS(cnnlApplyAdam( - handle, grad_desc, var, grad_desc, m, grad_desc, v, grad_desc, grad, lr, - beta1, beta2, beta1_power, beta2_power, epsilon, use_nesterov)); + handle, var_desc, var, m_desc, m, v_desc, v, grad_desc, grad, lr, beta1, + beta2, beta1_power, beta2_power, epsilon, use_nesterov)); } /* static */ void MLUCnnl::ApplyAdaMax( diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 71ea27d690f11d57c83ca2a63e77aa81a9bc4545..24db6c760d78abb4c317b42715daf20575994aee 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -503,14 +503,14 @@ class MLUCnnl { const cnnlTensorDescriptor_t mom_desc, void* mom); static void ApplyAdam(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t var_desc, void* var, + const cnnlTensorDescriptor_t m_desc, void* m, + const cnnlTensorDescriptor_t v_desc, void* v, const cnnlTensorDescriptor_t grad_desc, const void* grad, const void* lr, const void* beta1, const void* beta2, const void* beta1_power, const void* beta2_power, const void* epsilon, - const bool use_nesterov, - const cnnlTensorDescriptor_t var_desc, void* var, - const cnnlTensorDescriptor_t m_desc, void* m, - const cnnlTensorDescriptor_t v_desc, void* v); + const bool use_nesterov); static void ApplyAdaMax(const ExecutionContext& ctx, const cnnlTensorDescriptor_t grad_desc, diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d335021234ebb60279975beeb1389380300eb6b --- /dev/null +++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc @@ -0,0 +1,285 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op_mlu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class AdamMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE_EQ(param_var->IsType(), true, + platform::errors::InvalidArgument( + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type()))); + auto* param = ctx.Input("Param"); + auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE_EQ(grad_var->IsType(), true, + platform::errors::InvalidArgument( + "The Grad(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Grad").front(), + framework::ToTypeName(param_var->Type()))); + auto* grad = ctx.Input("Grad"); + auto* mom1 = ctx.Input("Moment1"); + auto* mom2 = ctx.Input("Moment2"); + auto* lr = ctx.Input("LearningRate"); + + auto* beta1_pow = ctx.Input("Beta1Pow"); + auto* beta2_pow = ctx.Input("Beta2Pow"); + + auto* param_out = ctx.Output("ParamOut"); + auto* mom1_out = ctx.Output("Moment1Out"); + auto* mom2_out = ctx.Output("Moment2Out"); + auto* beta1_pow_out = ctx.Output("Beta1PowOut"); + auto* beta2_pow_out = ctx.Output("Beta2PowOut"); + + bool skip_update = false; + if (ctx.HasInput("SkipUpdate")) { + auto* skip_update_tensor = ctx.Input("SkipUpdate"); + PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(SkipUpdate) size must be 1, but get %d", + skip_update_tensor->numel())); + std::vector skip_update_vec; + paddle::framework::TensorToVector(*skip_update_tensor, + ctx.device_context(), &skip_update_vec); + skip_update = skip_update_vec[0]; + } + // skip_update=true, just copy input to output, and TensorCopy will call + // mutable_data + if (skip_update) { + VLOG(4) << "Adam skip update"; + framework::TensorCopy( + *param, ctx.GetPlace(), + ctx.template device_context(), param_out); + framework::TensorCopy( + *mom1, ctx.GetPlace(), + ctx.template device_context(), mom1_out); + framework::TensorCopy( + *mom2, ctx.GetPlace(), + ctx.template device_context(), mom2_out); + framework::TensorCopy( + *beta1_pow, beta1_pow->place(), + ctx.template device_context(), + beta1_pow_out); + framework::TensorCopy( + *beta2_pow, beta2_pow->place(), + ctx.template device_context(), + beta2_pow_out); + return; + } + + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + param_out->ShareDataWith(*param); + mom1_out->ShareDataWith(*mom1); + mom2_out->ShareDataWith(*mom2); + + LoDTensor beta1_pow_tmp; + LoDTensor beta2_pow_tmp; + if (beta1_pow->place() == platform::CPUPlace()) { + T beta1 = *beta1_pow->data(); + beta1_pow_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta1_pow_tmp_desc(beta1_pow_tmp); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &beta1, + beta1_pow_tmp_desc.get(), GetBasePtr(&beta1_pow_tmp)); + beta1_pow = &beta1_pow_tmp; + } + if (beta2_pow->place() == platform::CPUPlace()) { + T beta2 = *beta2_pow->data(); + beta2_pow_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta2_pow_tmp_desc(beta2_pow_tmp); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &beta2, + beta2_pow_tmp_desc.get(), GetBasePtr(&beta2_pow_tmp)); + beta2_pow = &beta2_pow_tmp; + } + + VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() + << "beta2_pow.numel() : " << beta2_pow->numel(); + VLOG(3) << "param.numel(): " << param->numel(); + + PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta1 pow output size should be 1, but received " + "value is:%d.", + beta1_pow_out->numel())); + + PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta2 pow output size should be 1, but received " + "value is:%d.", + beta2_pow_out->numel())); + + const Tensor* beta1_tensor = nullptr; + const Tensor* beta2_tensor = nullptr; + const Tensor* epsilon_tensor = nullptr; + + Tensor beta1_tmp(experimental::DataType::FLOAT32); + Tensor beta2_tmp(experimental::DataType::FLOAT32); + Tensor epsilon_tmp(experimental::DataType::FLOAT32); + + if (ctx.HasInput("Beta1Tensor")) { + beta1_tensor = ctx.Input("Beta1Tensor"); + PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta1Tensor) size must be 1, but get %d", + beta1_tensor->numel())); + } else { + T beta1 = static_cast(ctx.Attr("beta1")); + beta1_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta1_tmp_desc(beta1_tmp); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &beta1, beta1_tmp_desc.get(), + GetBasePtr(&beta1_tmp)); + beta1_tensor = &beta1_tmp; + } + + if (ctx.HasInput("Beta2Tensor")) { + beta2_tensor = ctx.Input("Beta2Tensor"); + PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta2Tensor) size must be 1, but get %d", + beta2_tensor->numel())); + } else { + T beta2 = static_cast(ctx.Attr("beta2")); + beta2_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta2_tmp_desc(beta2_tmp); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &beta2, beta2_tmp_desc.get(), + GetBasePtr(&beta2_tmp)); + beta2_tensor = &beta2_tmp; + } + + if (ctx.HasInput("EpsilonTensor")) { + epsilon_tensor = ctx.Input("EpsilonTensor"); + PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(EpsilonTensor) size must be 1, but get %d", + epsilon_tensor->numel())); + } else { + T epsilon = static_cast(ctx.Attr("epsilon")); + epsilon_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc epsilon_tmp_desc(epsilon_tmp); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &epsilon, + epsilon_tmp_desc.get(), GetBasePtr(&epsilon_tmp)); + epsilon_tensor = &epsilon_tmp; + } + + MLUCnnlTensorDesc param_desc(*param); + MLUCnnlTensorDesc mom1_desc(*mom1); + MLUCnnlTensorDesc mom2_desc(*mom2); + MLUCnnlTensorDesc grad_desc(*grad); + MLUCnnl::ApplyAdam(ctx, param_desc.get(), GetBasePtr(param_out), + mom1_desc.get(), GetBasePtr(mom1_out), mom2_desc.get(), + GetBasePtr(mom2_out), grad_desc.get(), GetBasePtr(grad), + GetBasePtr(lr), GetBasePtr(beta1_tensor), + GetBasePtr(beta2_tensor), GetBasePtr(beta1_pow), + GetBasePtr(beta2_pow), GetBasePtr(epsilon_tensor), + /*use_nesterov*/ false); + + if (!use_global_beta_pow) { + beta1_pow_out->mutable_data(ctx.GetPlace()); + beta2_pow_out->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc beta1_desc(*beta1_tensor); + MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(ctx, mul_op_desc.get(), beta1_desc.get(), + GetBasePtr(beta1_pow), beta1_desc.get(), + GetBasePtr(beta1_tensor), beta1_desc.get(), + GetBasePtr(beta1_pow_out), ToCnnlDataType()); + + MLUCnnl::OpTensor(ctx, mul_op_desc.get(), beta1_desc.get(), + GetBasePtr(beta2_pow), beta1_desc.get(), + GetBasePtr(beta2_tensor), beta1_desc.get(), + GetBasePtr(beta2_pow_out), ToCnnlDataType()); + } + } +}; + +template +class AdamWMLUKernel : public AdamMLUKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + VLOG(3) << "MLU AdamW Kernel"; + bool skip_update = false; + if (ctx.HasInput("SkipUpdate")) { + VLOG(3) << "Has SkipUpdate"; + auto* skip_update_tensor = ctx.Input("SkipUpdate"); + PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(SkipUpdate) size must be 1, but get %d", + skip_update_tensor->numel())); + std::vector skip_update_vec; + paddle::framework::TensorToVector(*skip_update_tensor, + ctx.device_context(), &skip_update_vec); + skip_update = skip_update_vec[0]; + } + VLOG(3) << "Skip update" << skip_update; + bool with_decay = ctx.Attr("with_decay"); + if (!skip_update && with_decay) { + if (ctx.HasInput("MasterParam")) { + PADDLE_THROW(platform::errors::Unimplemented( + "Master Param is not supported on MLU")); + } else { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE_EQ(param_var->IsType(), true, + platform::errors::InvalidArgument( + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type()))); + auto* param = ctx.Input("Param"); + auto* lr = ctx.Input("LearningRate"); + float coeff = ctx.Attr("coeff"); + + // update param with decay coeff: mul(-1 * lr, coeff * param) + param + MLUCnnlTensorDesc lr_desc(*lr); + MLUCnnlTensorDesc param_desc(*param); + MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(ctx, mul_op_desc.get(), lr_desc.get(), GetBasePtr(lr), + param_desc.get(), GetBasePtr(param), param_desc.get(), + const_cast(GetBasePtr(param)), + ToCnnlDataType(), + /*alpha1*/ -1.f, /*alpha2*/ coeff, /*beta*/ 1.f); + } + } + AdamMLUKernel::Compute(ctx); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(adam, ops::AdamMLUKernel, + ops::AdamMLUKernel); + +REGISTER_OP_MLU_KERNEL(adamw, ops::AdamWMLUKernel, + ops::AdamWMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_adam_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_adam_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..f30a391f65385414e1452c8e4648228abf209a92 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_adam_op_mlu.py @@ -0,0 +1,303 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from test_adam_op import adam_step + +paddle.enable_static() +SEED = 2022 + + +class TestAdam(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, + 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2 + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestAdamWithEpsilonTensor(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + } + + self.attrs = {'epsilon': epsilon} + + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, + 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2 + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestAdamOpWithSkipUpdate(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + "SkipUpdate": np.array([True]).astype("bool"), + } + + self.attrs = {'epsilon': epsilon} + + self.outputs = { + 'Moment1Out': moment1, + 'Moment2Out': moment2, + 'ParamOut': param, + 'Beta1PowOut': self.inputs['Beta1Pow'], + 'Beta2PowOut': self.inputs['Beta2Pow'], + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestAdamOpWithGlobalBetaPow(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adam" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + } + + attributes = {'epsilon': epsilon} + + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, attributes) + + self.attrs = {'use_global_beta_pow': True} + + # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([]), + 'Beta2PowOut': np.array([]) + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + adam = fluid.optimizer.Adam(learning_rate=0.01) + adam.minimize(loss) + + if run_mlu: + place = paddle.device.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + mlu_pred, mlu_loss = self._test(True) + cpu_pred, cpu_loss = self._test(False) + self.assertTrue(np.allclose(mlu_pred, cpu_pred, rtol=1e-3)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss, rtol=1e-3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_adamw_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_adamw_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..d2827725a205815dfb85250197753b11561a3026 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_adamw_op_mlu.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from test_adam_op import adamw_step + +paddle.enable_static() +SEED = 2022 + + +class TestAdamW(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adamw" + param = np.random.uniform(-1, 1, (105, 102)).astype("float32") + grad = np.random.uniform(-1, 1, (105, 102)).astype("float32") + moment1 = np.random.uniform(-1, 1, (105, 102)).astype("float32") + # The second moment is positive + moment2 = np.random.random((105, 102)).astype("float32") + + learning_rate = 0.5 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + self.attrs = { + 'epsilon': epsilon, + 'beta1': beta1, + 'beta2': beta2, + "coeff": 0.9, + "with_decay": True + } + + param_out, moment1_out, \ + moment2_out = adamw_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out, + 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, + 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2 + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestAdamOpWithSkipUpdate(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adamw" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + "SkipUpdate": np.array([True]).astype("bool"), + } + + self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": True} + + self.outputs = { + 'Moment1Out': moment1, + 'Moment2Out': moment2, + 'ParamOut': param, + 'Beta1PowOut': self.inputs['Beta1Pow'], + 'Beta2PowOut': self.inputs['Beta2Pow'], + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestAdamOpWithoutDecay(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "adamw" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + # The second moment is positive + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.004 + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + beta1_pow = beta1**10 + beta2_pow = beta2**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32"), + 'Beta1Tensor': np.array([beta1]).astype("float32"), + 'Beta2Tensor': np.array([beta2]).astype("float32"), + 'EpsilonTensor': np.array([epsilon]).astype("float32"), + "SkipUpdate": np.array([True]).astype("bool"), + } + + self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": False} + + self.outputs = { + 'Moment1Out': moment1, + 'Moment2Out': moment2, + 'ParamOut': param, + 'Beta1PowOut': self.inputs['Beta1Pow'], + 'Beta2PowOut': self.inputs['Beta2Pow'], + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + +class TestNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + adam = paddle.optimizer.AdamW(learning_rate=0.01, weight_decay=0.02) + adam.minimize(loss) + + if run_mlu: + place = paddle.device.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + mlu_pred, mlu_loss = self._test(True) + cpu_pred, cpu_loss = self._test(False) + self.assertTrue(np.allclose(mlu_pred, cpu_pred, rtol=1e-3)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss, rtol=1e-3)) + + +if __name__ == '__main__': + unittest.main()