diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 285362fcd5e0ffec16b718d578e3cf90813449b7..c8b28aed24e8cf2886d527fb45560d06c48b4fad 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -36,6 +36,10 @@ static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
     TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
     tensor_data = cpu_tensor.data<float>();
   }
+  if (platform::is_xpu_place(tensor->place())) {
+    TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+    tensor_data = cpu_tensor.data<float>();
+  }
   return tensor_data[0];
 }
diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..05b4544c02a1231f7f6f275f13a978e66705819b
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc
@@ -0,0 +1,136 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/optimizers/adam_op.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+#ifdef PADDLE_WITH_XPU
+template <typename DeviceContext, typename T>
+class AdamOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "Tensor holds the wrong type, Expected Var(%s)'s "
+                          "type is LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
+    using paddle::framework::LoDTensor;
+
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+
+    auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
+                                  "Param", "Adam");
+    // auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
+    auto* grad_var = ctx.InputVar("Grad");
+    auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
+                                 "Moment1", "Adam");
+    auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
+                                 "Moment2", "Adam");
+    auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
+                               "LearningRate", "Adam");
+    auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"),
+                                      "Input", "Beta1Pow", "Adam");
+    auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"),
+                                      "Input", "Beta2Pow", "Adam");
+
+    auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
+                                      "Output", "ParamOut", "Adam");
+    auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
+                                     "Output", "Moment1Out", "Adam");
+    auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
+                                     "Output", "Moment2Out", "Adam");
+
+    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
+    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
+    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "Tensor holds the wrong size, Expected beta1 pow "
+                          "output size is 1, but received "
+                          "value is:%d.",
+                          beta1_pow_out->numel()));
+
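+    // The beta-pow accumulators must be scalars (numel == 1): after the
+    // xpu::adam call below they are read back and advanced on the host
+    // via xpu_memcpy.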
+    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "Tensor holds the wrong size, Expected beta2 pow "
+                          "output size is 1, but received "
+                          "value is:%d.",
+                          beta2_pow_out->numel()));
+
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    if (ctx.HasInput("Beta1Tensor")) {
+      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
+      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
+    }
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    if (ctx.HasInput("Beta2Tensor")) {
+      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
+      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
+    }
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
+                                   "Grad", "Adam");
+
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      int r = xpu::adam(
+          dev_ctx.x_context(), grad.template data<T>(),
+          mom1.template data<T>(), mom2.template data<T>(),
+          param.template data<T>(), beta1_pow.template data<T>(),
+          beta2_pow.template data<T>(), beta1, beta2, epsilon,
+          lr.template data<T>(),
+          mom1_out.template mutable_data<T>(ctx.GetPlace()),
+          mom2_out.template mutable_data<T>(ctx.GetPlace()),
+          param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
+
+      const float* ptr0 = beta1_pow.template data<float>();
+      float* ptr1 = beta1_pow_out->mutable_data<float>(ctx.GetPlace());
+      float cpudata;
+      xpu_memcpy(&cpudata, ptr0, sizeof(float), XPU_DEVICE_TO_HOST);
+      cpudata = cpudata * beta1;
+      xpu_memcpy(ptr1, &cpudata, sizeof(float), XPU_HOST_TO_DEVICE);
+
+      const float* ptr2 = beta2_pow.template data<float>();
+      float* ptr3 = beta2_pow_out->mutable_data<float>(ctx.GetPlace());
+      float cpudata1;
+      xpu_memcpy(&cpudata1, ptr2, sizeof(float), XPU_DEVICE_TO_HOST);
+      cpudata1 = cpudata1 * beta2;
+      xpu_memcpy(ptr3, &cpudata1, sizeof(float), XPU_HOST_TO_DEVICE);
+
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::External(
+                            "XPU API return wrong value[%d], please check "
+                            "whether Baidu Kunlun Card is properly installed.",
+                            r));
+    } else {
+      PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument(
+                                  "Variable type not supported by adam_op"));
+    }
+  }
+};
+#endif
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+#ifdef PADDLE_WITH_XPU
+REGISTER_OP_XPU_KERNEL(
+    adam, ops::AdamOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+#endif
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..147824f341be43df33699b3a918880979c24485d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
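+
+# These tests exercise the XPU adam kernel added in adam_op_xpu.cc: each case
+# builds the op inputs with NumPy, computes the expected result with the
+# adam_step() reference below, and compares against the kernel output on
+# paddle.XPUPlace(0) via check_output_with_place.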
+
+from __future__ import print_function
+import sys
+sys.path.append("..")
+import unittest
+import numpy as np
+from op_test import OpTest
+from paddle.fluid import core
+from paddle.fluid.op import Operator
+import paddle.fluid as fluid
+import paddle
+
+
+class TestAdamOp1(OpTest):
+    def setUp(self):
+        '''Test Adam Op with supplied attributes
+        '''
+        self.op_type = "adam"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+
+        learning_rate = 0.004
+        beta1 = 0.78
+        beta2 = 0.836
+        epsilon = 1e-4
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32")
+        }
+
+        self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
+
+        param_out, moment1_out, \
+            moment2_out = adam_step(self.inputs, self.attrs)
+
+        self.outputs = {
+            'Moment1Out': moment1_out,
+            'Moment2Out': moment2_out,
+            'ParamOut': param_out,
+            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
+
+
+class TestAdamOp2(OpTest):
+    def setUp(self):
+        '''Test Adam Op with supplied attributes
+        '''
+        self.op_type = "adam"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+
+        learning_rate = 0.001
+        beta1 = 0.9
+        beta2 = 0.999
+        epsilon = 1e-8
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32")
+        }
+
+        attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
+
+        param_out, moment1_out, \
+            moment2_out = adam_step(self.inputs, attributes)
+
+        self.outputs = {
+            'Moment1Out': moment1_out,
+            'Moment2Out': moment2_out,
+            'ParamOut': param_out,
+            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
+
+
+class TestAdamOpMultipleSteps(OpTest):
+    def setUp(self):
+        '''Test Adam Operator with supplied attributes
+        '''
+        self.op_type = "adam"
+        self.num_steps = 10
+
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+
+        learning_rate = 0.001
+        self.beta1 = 0.9
+        self.beta2 = 0.999
+        epsilon = 1e-8
+        self.beta1_pow = self.beta1**10
+        self.beta2_pow = self.beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([self.beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([self.beta2_pow]).astype("float32")
+        }
+
+        self.attrs = {
+            'epsilon': epsilon,
+            'beta1': self.beta1,
+            'beta2': self.beta2
+        }
+
+    def test_check_output(self):
+        for _ in range(self.num_steps):
+            param_out, moment1_out, \
+                moment2_out = adam_step(self.inputs, self.attrs)
+
+            beta1_pow_out = self.inputs['Beta1Pow'] * self.beta1
+            beta2_pow_out = self.inputs['Beta2Pow'] * self.beta2
+            self.outputs = {
+                'Moment1Out': moment1_out,
+                'Moment2Out': moment2_out,
+                'ParamOut': param_out,
+                'Beta1PowOut': beta1_pow_out,
+                'Beta2PowOut': beta2_pow_out
+            }
+
+            # Verify output for this step
+            self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
+
+            # Output of this step becomes input for next step
+            self.inputs['Param'] = param_out
+            self.inputs['Moment1'] = moment1_out
+            self.inputs['Moment2'] = moment2_out
+
+            # Update powers of Beta1 and Beta2 for next time step
+            self.inputs['Beta1Pow'] = beta1_pow_out
+            self.inputs['Beta2Pow'] = beta2_pow_out
+
+            # Randomize gradient for next step
+            self.inputs['Grad'] = np.random.uniform(
+                -1, 1, (102, 105)).astype("float32")
+
+
+def adam_step(inputs, attributes):
+    '''
+    Simulate one step of the adam optimizer
+    :param inputs: dict of inputs
+    :param attributes: dict of attributes
+    :return tuple: tuple of output param, moment1 and moment2
+    '''
+    param = inputs['Param']
+    grad = inputs['Grad']
+    moment1 = inputs['Moment1']
+    moment2 = inputs['Moment2']
+    lr = inputs['LearningRate']
+    beta1_pow = inputs['Beta1Pow']
+    beta2_pow = inputs['Beta2Pow']
+
+    epsilon = attributes['epsilon']
+
+    if 'beta1' in attributes:
+        beta1 = attributes['beta1']
+    else:
+        beta1 = inputs['Beta1Tensor'][0]
+    if 'beta2' in attributes:
+        beta2 = attributes['beta2']
+    else:
+        beta2 = inputs['Beta2Tensor'][0]
+
+    moment1_out = beta1 * moment1 + (1 - beta1) * grad
+    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
+    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
+    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
+    return param_out, moment1_out, moment2_out
+
+
+class TestAdamOpBetaVariable(OpTest):
+    def setUp(self):
+        '''Test Adam Op with beta as Variable
+        '''
+        self.op_type = "adam"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+        beta1 = 0.85
+        beta2 = 0.95
+
+        learning_rate = 0.001
+        epsilon = 1e-8
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32"),
+            "Beta1Tensor": np.array([beta1]).astype("float32"),
+            "Beta2Tensor": np.array([beta2]).astype("float32"),
+        }
+
+        attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
+
+        param_out, moment1_out, \
+            moment2_out = adam_step(self.inputs, attributes)
+
+        self.outputs = {
+            'Moment1Out': moment1_out,
+            'Moment2Out': moment2_out,
+            'ParamOut': param_out,
+            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
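For reference, the update rule that adam_step() above and the XPU kernel in
adam_op_xpu.cc both implement can be written as a small self-contained NumPy
sketch (the function and variable names here are illustrative only):

    import numpy as np

    def adam_reference(param, grad, moment1, moment2, lr,
                       beta1_pow, beta2_pow,
                       beta1=0.9, beta2=0.999, epsilon=1e-8):
        # First and second moment estimates.
        moment1_out = beta1 * moment1 + (1 - beta1) * grad
        moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
        # Bias-corrected learning rate.
        lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
        param_out = param - lr_t * (moment1_out /
                                    (np.sqrt(moment2_out) + epsilon))
        # The beta-pow accumulators are advanced outside the device kernel,
        # mirroring the host-side xpu_memcpy update in adam_op_xpu.cc.
        return (param_out, moment1_out, moment2_out,
                beta1_pow * beta1, beta2_pow * beta2)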