未验证 提交 6f0c3d1f 编写于 作者: Y yinhaofeng 提交者: GitHub

xpu adam op (#28031)

* lookup_table_xpu op report errors;test=kunlun

* add adam xpu op;test=kunlun

* reset lookup

* change adam wrong;test=kunlun
上级 a5c95cd5
......@@ -36,6 +36,10 @@ static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
tensor_data = cpu_tensor.data<float>();
}
if (platform::is_xpu_place(tensor->place())) {
TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
tensor_data = cpu_tensor.data<float>();
}
return tensor_data[0];
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include <gflags/gflags.h>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#ifdef PADDLE_WITH_XPU
template <typename DeviceContext, typename T>
class AdamOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Adam");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
"Moment1", "Adam");
auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
"Moment2", "Adam");
auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
"LearningRate", "Adam");
auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
"Beta1Pow", "Adam");
auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
"Beta2Pow", "Adam");
auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
"Output", "ParamOut", "Adam");
auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
"Output", "Moment1Out", "Adam");
auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Adam");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta1 pow "
"output size is 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta2 pow "
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
}
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
"Grad", "Adam");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::adam(
dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
mom2.template data<T>(), param.template data<T>(),
beta1_pow.template data<T>(), beta2_pow.template data<T>(), beta1,
beta2, epsilon, lr.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
const float* ptr0 = beta1_pow.template data<T>();
float* ptr1 = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
float cpudata;
xpu_memcpy(&cpudata, ptr0, sizeof(float), XPU_DEVICE_TO_HOST);
cpudata = cpudata * beta1;
xpu_memcpy(ptr1, &cpudata, sizeof(float), XPU_HOST_TO_DEVICE);
const float* ptr2 = beta2_pow.template data<T>();
float* ptr3 = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
float cpudata1;
xpu_memcpy(&cpudata1, ptr2, sizeof(float), XPU_DEVICE_TO_HOST);
cpudata1 = cpudata1 * beta2;
xpu_memcpy(ptr3, &cpudata1, sizeof(float), XPU_HOST_TO_DEVICE);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
r));
} else {
PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument(
"Variable type not supported by adam_op"));
}
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
adam, ops::AdamOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
class TestAdamOp1(OpTest):
def setUp(self):
'''Test Adam Op with supplied attributes
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
class TestAdamOp2(OpTest):
def setUp(self):
'''Test Adam Op with supplied attributes
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32")
}
attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
class TestAdamOpMultipleSteps(OpTest):
def setUp(self):
'''Test Adam Operator with supplied attributes
'''
self.op_type = "adam"
self.num_steps = 10
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.001
self.beta1 = 0.9
self.beta2 = 0.999
epsilon = 1e-8
self.beta1_pow = self.beta1**10
self.beta2_pow = self.beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([self.beta1_pow]).astype("float32"),
'Beta2Pow': np.array([self.beta2_pow]).astype("float32")
}
self.attrs = {
'epsilon': epsilon,
'beta1': self.beta1,
'beta2': self.beta2
}
def test_check_output(self):
for _ in range(self.num_steps):
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, self.attrs)
beta1_pow_out = self.inputs['Beta1Pow'] * self.beta1
beta2_pow_out = self.inputs['Beta2Pow'] * self.beta2
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out
}
# Verify output for this step
self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
# Output of this step becomes input for next step
self.inputs['Param'] = param_out
self.inputs['Moment1'] = moment1_out
self.inputs['Moment2'] = moment2_out
# Update powers of Beta1 and Beta2 for next time step
self.inputs['Beta1Pow'] = beta1_pow_out
self.inputs['Beta2Pow'] = beta2_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
-1, 1, (102, 105)).astype("float32")
def adam_step(inputs, attributes):
'''
Simulate one step of the adam optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
epsilon = attributes['epsilon']
if 'beta1' in attributes:
beta1 = attributes['beta1']
else:
beta1 = inputs['Beta1Tensor'][0]
if 'beta2' in attributes:
beta2 = attributes['beta2']
else:
beta2 = inputs['Beta2Tensor'][0]
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
return param_out, moment1_out, moment2_out
class TestAdamOpBetaVariable(OpTest):
def setUp(self):
'''Test Adam Op with beta as Variable
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
beta1 = 0.85
beta2 = 0.95
learning_rate = 0.001
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
"Beta1Tensor": np.array([beta1]).astype("float32"),
"Beta2Tensor": np.array([beta2]).astype("float32"),
}
attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册