Unverified Commit 3ca713ee authored by houj04, committed by GitHub

rmsprop for xpu. test=kunlun (#44175)

* rmsprop for xpu. test=kunlun

* minor fix (follow comments). test=kunlun
Parent 9a3054c6
...
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
...
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
...
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
...
@@ -90,6 +90,19 @@ class RmspropOpXPUKernel : public framework::OpKernel<T> {
     T decay = static_cast<T>(ctx.Attr<float>("decay"));
     T momentum = static_cast<T>(ctx.Attr<float>("momentum"));
+    bool centered = ctx.Attr<bool>("centered");
+    PADDLE_ENFORCE_EQ(centered,
+                      false,
+                      platform::errors::Unimplemented(
+                          "centered=True is not supported in the xpu kernel of "
+                          "rmsprop. use XPU_BLACK_LIST to disable this op."));
+    /*
+    TODO(houj04): when the XDNN api supports 'centered', add the input
+    mean_grad_input and the output mean_grad_output:
+      auto *mean_grad_input = ctx.Input<Tensor>("MeanGrad");
+      auto *mean_grad_output = ctx.Output<Tensor>("MeanGradOut");
+    */
     // outputs
     auto& param_out = GET_DATA_SAFELY(
         ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Rmsprop");
...
@@ -101,18 +114,9 @@ class RmspropOpXPUKernel : public framework::OpKernel<T> {
                                     "Rmsprop");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    ///// rmsprop optimization algorithm
-    ///
-    /// ms_out[i] = rho * ms[i] + (1 - rho) * (g[i] * g[i]);
-    ///
-    /// mom_out[i] = momentum * mom[i] + lr *
-    ///     (g[i] / ((float)sqrt(ms_out[i] + epsilon)));
-    ///
-    /// p_out[i] = p[i] - mom_out[i];
-    ///
-    /// DLL_EXPORT int rmsprop(Context* ctx, const float* p,
-    ///     const float* ms, const float* g, const float* mom,
-    ///     float epsilon, float rho, float momentum, float lr,
-    ///     float *ms_out, float *mom_out, float *p_out, int n)
+    // int rmsprop(Context* ctx, const T* g, const T* p, const float* ms, const
+    // float* mom, T* p_out, float* ms_out, float* mom_out, float epsilon, float
+    // rho, float momentum, float lr, int n);
     int r = xpu::rmsprop(dev_ctx.x_context(),
                          grad.template data<T>(),
                          param.template data<T>(),
...
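For reference, the update that the xpu::rmsprop call delegates to XDNN follows the formulas in the removed comment block above (and matches the NumPy reference added in the rewritten test below). A minimal NumPy sketch of the non-centered update, with illustrative names:

import numpy as np

def rmsprop_update(g, p, ms, mom, epsilon, rho, momentum, lr):
    # Running average of squared gradients: ms_out = rho * ms + (1 - rho) * g^2.
    ms_out = rho * ms + (1.0 - rho) * g * g
    # Momentum-smoothed, RMS-normalized step.
    mom_out = momentum * mom + lr * g / np.sqrt(ms_out + epsilon)
    # Parameter update: p_out = p - mom_out.
    p_out = p - mom_out
    return p_out, ms_out, mom_out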
...
@@ -363,6 +363,7 @@ XPUOpMap& get_kl2_ops() {
                      pOpKernelType(vartype::INT32, XPUPlace()),
                      pOpKernelType(vartype::BOOL, XPUPlace()),
                      pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...
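This FP32-only registration is presumably what the rewritten test queries at collection time; a hedged sketch of that relationship (the exact return format of get_xpu_op_support_types is assumed):

from xpu.get_test_cover_info import get_xpu_op_support_types

# Assumption: with the KL2 registration above, 'rmsprop' reports only float32,
# so create_test_class in the test below generates FP32 test variants only.
support_types = get_xpu_op_support_types('rmsprop')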
...
@@ -36,4 +36,5 @@ no_check_set_white_list = [
     'eigvalsh',
     'class_center_sample',
     'einsum',
+    'rmsprop',
 ]
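Note: 'rmsprop' joins this whitelist because, with centered=True unimplemented, the XPU kernel never writes MeanGradOut; this whitelist gates which ops may skip output checks, and the entry is what allows the rewritten test below to pass no_check_set=['MeanGradOut'] to check_output_with_place.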
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -13,288 +13,152 @@
 # limitations under the License.

 from __future__ import print_function

-import sys
-
-sys.path.append("..")
 import unittest
 import numpy as np
-import paddle.fluid.core as core
-from paddle.fluid.op import Operator
-from op_test_xpu import XPUOpTest
-import paddle.fluid as fluid
-import paddle
-'''
-def create_selected_rows_and_tensor(scope, place, height, row_num,
-                                    embedding_size):
-    sr = scope.var("@selected_rows@").get_selected_rows()
-    tensor = scope.var("grad").get_tensor()
-
-    rows = np.random.random_integers(
-        low=0, high=height - 1, size=[row_num, ]).astype('int64')
-    sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')
-
-    sr.set_height(height)
-    sr.set_rows(rows)
-    sr.get_tensor().set(sr_val, place)
-
-    tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
-    for i in range(row_num):
-        row = rows[i]
-        tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]
-
-    tensor.set(tensor_val, place)
-    return tensor_val, sr_val
-'''
-"""
-class TestBase(XPUOpTest):
-    op_type = 'rmsprop'
-
-    def setup(self,
-              place,
-              is_sparse,
-              centered,
-              size,
-              row_num=None,
-              epsilon=1e-6):
-        np.random.seed(5)  # fix seed
-
-        self.scope = fluid.global_scope()
-        self.place = place
-
-        self.param_name = 'param'
-        self.param = np.random.random(size).astype('float32')
-
-        self.mean_square_name = 'mean_square'
-        self.mean_square = np.random.uniform(
-            low=1, high=2, size=size).astype('float32')
-
-        self.mean_grad_name = 'mean_grad'
-        self.mean_grad = np.random.random(size).astype('float32')
-
-        self.lr_name = 'lr'
-        self.learning_rate = np.array([0.01]).astype('float32')
-
-        self.grad_name = 'grad'
-        self.is_sparse = is_sparse
-        self.grad = np.random.random(size).astype('float32')
-        grad_tensor = self.scope.var(self.grad_name).get_tensor()
-        grad_tensor.set(self.grad, place)
-
-        self.moment_name = 'moment'
-        self.moment = np.random.uniform(
-            low=0, high=1, size=size).astype('float32')
-
-        self.epsilon = epsilon
-        self.decay = 0.9
-        self.momentum = 0.1
-        self.centered = centered
-
-        self.ms_out = self.decay * self.mean_square + (1 - self.decay
-                                                       ) * self.grad * self.grad
-        if centered:
-            self.mg_out = self.decay * self.mean_grad + (1 - self.decay
-                                                         ) * self.grad
-            self.moment_out = self.momentum * self.moment + \
-                self.learning_rate * self.grad / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
-        else:
-            self.moment_out = self.momentum * self.moment + \
-                self.learning_rate * self.grad / np.sqrt(self.ms_out + self.epsilon)
-
-        self.param_out = self.param - self.moment_out
-
-        # create and initialize Param Variable
-        self.param_tensor = self.scope.var(self.param_name).get_tensor()
-        self.param_tensor.set(self.param, place)
-
-        self.mean_square_tensor = self.scope.var(
-            self.mean_square_name).get_tensor()
-        self.mean_square_tensor.set(self.mean_square, place)
-
-        lr = self.scope.var(self.lr_name).get_tensor()
-        lr.set(self.learning_rate, place)
-
-        self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
-        self.moment_tensor.set(self.moment, place)
-
-        if self.centered:
-            self.mean_grad_tensor = self.scope.var(
-                self.mean_grad_name).get_tensor()
-            self.mean_grad_tensor.set(self.mean_grad, place)
-
-    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
-        self.assertTrue(
-            np.allclose(
-                actual_t, expect_t, atol=atol),
-            'Output (' + out_name + ') has diff at ' + str(place) + '\nExpect '
-            + str(expect_t) + '\n' + 'But Got' + str(actual_t))
-
-
-class TestRmspropOp(TestBase):
-
-    def check_with_place(self,
-                         place,
-                         is_sparse,
-                         centered,
-                         size,
-                         row_num=None,
-                         epsilon=1e-6):
-        self.setup(place, is_sparse, centered, size, row_num, epsilon)
-        self.run_and_check()
-
-    def run_and_check(self):
-        #grad_name = self.grad_sr_name if self.is_sparse else self.grad_name
-        grad_name = self.grad_name
-
-        kwargs = {
-            'Param': self.param_name,
-            'Grad': grad_name,
-            'MeanSquare': self.mean_square_name,
-            'Moment': self.moment_name,
-            'LearningRate': self.lr_name,
-            'ParamOut': self.param_name,
-            'MeanSquareOut': self.mean_square_name,
-            'MomentOut': self.moment_name,
-            'epsilon': self.epsilon,
-            'decay': self.decay,
-            'momentum': self.momentum,
-            'centered': self.centered
-        }
-
-        if self.centered:
-            kwargs['MeanGrad'] = self.mean_grad_name
-            kwargs['MeanGradOut'] = self.mean_grad_name
-
-        rmsprop_op = Operator('rmsprop', **kwargs)
-        atol = 1e-6
-
-        rmsprop_op.run(self.scope, self.place)
-
-        self.check(
-            np.array(self.mean_square_tensor),
-            self.ms_out,
-            self.place,
-            self.mean_square_name,
-            atol=atol)
-
-        self.check(
-            np.array(self.moment_tensor),
-            self.moment_out,
-            self.place,
-            self.moment_name,
-            atol=atol)

-        self.check(
-            np.array(self.param_tensor),
-            self.param_out,
-            self.place,
-            self.param_name,
-            atol=atol)
-
-        if self.centered:
-            self.check(
-                np.array(self.mean_grad_tensor), self.mg_out, self.place,
-                self.mean_grad_name)
-
-    def test_rmsprop(self):
-        places = [paddle.XPUPlace(0)]
-
-        size = (128, 320)
-        for place in places:
-            for centered in [False]:
-                with fluid.scope_guard(core.Scope()):
-                    self.check_with_place(
-                        place, is_sparse=False, centered=centered, size=size)
-
-                with fluid.scope_guard(core.Scope()):
-                    self.check_with_place(
-                        place,
-                        is_sparse=True,
-                        centered=centered,
-                        row_num=512,
-                        size=size)
-
-                with fluid.scope_guard(core.Scope()):
-                    self.check_with_place(
-                        place,
-                        is_sparse=True,
-                        centered=centered,
-                        row_num=60,
-                        size=size, )
-
-
-class TestRMSPropV2(XPUOpTest):
-    op_type = 'rmsprop'
-
-    def test_rmsprop_dygraph(self):
-        paddle.disable_static()
-        value = np.arange(26).reshape(2, 13).astype('float32')
-        a = paddle.to_tensor(value)
-        linear = paddle.nn.Linear(13, 5)
-        # This can be any optimizer supported by dygraph.
-        adam = paddle.optimizer.RMSProp(
-            learning_rate=0.01,
-            parameters=linear.parameters(),
-            weight_decay=0.01)
-        out = linear(a)
-        out.backward()
-        adam.step()
-        adam.clear_gradients()
-
-    def test_rmsprop(self):
-        place = paddle.XPUPlace(0)
-        paddle.enable_static()
-        main = fluid.Program()
-        with fluid.program_guard(main):
-            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
-            y_predict = fluid.layers.fc(input=x, size=1, act=None)
-            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
-            avg_cost = paddle.mean(cost)
-
-            print(avg_cost.shape)
-            linear = paddle.nn.Linear(13, 5)
-            rms_optimizer = paddle.optimizer.RMSProp(
-                learning_rate=0.1, parameters=linear.parameters())
-            rms_optimizer.minimize(avg_cost)
-
-            fetch_list = [avg_cost]
-            train_reader = paddle.batch(
-                paddle.dataset.uci_housing.train(), batch_size=1)
-            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
-            exe = fluid.Executor(place)
-            exe.run(fluid.default_startup_program())
-            for data in train_reader():
-                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
-
-    def test_raise_error(self):
-        self.assertRaises(ValueError, paddle.optimizer.RMSProp, None)
-        self.assertRaises(
-            ValueError, paddle.optimizer.RMSProp, learning_rate=0.1, rho=None)
-        self.assertRaises(
-            ValueError,
-            paddle.optimizer.RMSProp,
-            learning_rate=0.1,
-            epsilon=None)
-        self.assertRaises(
-            ValueError,
-            paddle.optimizer.RMSProp,
-            learning_rate=0.1,
-            momentum=None)
-
-    def test_rmsprop_op_invalid_input(self):
-        paddle.disable_static()
-        linear = paddle.nn.Linear(10, 10)
-        with self.assertRaises(ValueError):
-            adam = paddle.optimizer.RMSProp(
-                0.1, epsilon=-1, parameters=linear.parameters())
-        with self.assertRaises(ValueError):
-            adam = paddle.optimizer.RMSProp(
-                0.1, momentum=-1, parameters=linear.parameters())
-        with self.assertRaises(ValueError):
-            adam = paddle.optimizer.RMSProp(
-                0.1, rho=-1, parameters=linear.parameters())
-"""
+import sys
+
+sys.path.append("..")
+
+import paddle
+import paddle.fluid.core as core
+
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+def calculate_rmsprop_by_numpy(param, grad, mean_square, moment, learning_rate,
+                               epsilon, decay, momentum):
+    mean_square_out = decay * mean_square + (1 - decay) * grad * grad
+    moment_out = momentum * moment + learning_rate * grad / np.sqrt(
+        mean_square_out + epsilon)
+    param_out = param - moment_out
+    return param_out, mean_square_out, moment_out
+
+
+class XPUTestRMSPropOP(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'rmsprop'
+        self.use_dynamic_create_class = False
+
+    class TestRMSPropOPBase(XPUOpTest):
+
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.xpu_version = core.get_xpu_device_version(0)
+            self.init_dtype()
+            self.set_case()
+
+        def set_case(self):
+            self.op_type = 'rmsprop'
+            self.dtype = self.in_type
+            self.init_config()
+
+            self.param = np.random.uniform(-1, 1,
+                                           self.input_shape).astype(self.dtype)
+            self.grad = np.random.uniform(-1, 1,
+                                          self.input_shape).astype(self.dtype)
+            self.mean_square = np.random.uniform(0, 1, self.input_shape).astype(
+                self.dtype)
+            self.moment = np.random.uniform(-1, 1,
+                                            self.input_shape).astype(self.dtype)
+
+            self.mean_grad = np.random.uniform(-1, 1, self.input_shape).astype(
+                self.dtype)
+            self.mean_grad_out = np.random.uniform(
+                -1, 1, self.input_shape).astype(self.dtype)
+
+            param_out, mean_square_out, moment_out = calculate_rmsprop_by_numpy(
+                param=self.param,
+                grad=self.grad,
+                mean_square=self.mean_square,
+                moment=self.moment,
+                learning_rate=self.learning_rate,
+                epsilon=self.epsilon,
+                decay=self.decay,
+                momentum=self.momentum)
+
+            self.inputs = {
+                'Param': self.param,
+                'Grad': self.grad,
+                'MeanSquare': self.mean_square,
+                'Moment': self.moment,
+                'LearningRate': self.learning_rate,
+                'MeanGrad': self.mean_grad,
+                'MeanGradOut': self.mean_grad_out,
+            }
+
+            self.attrs = {
+                'use_xpu': True,
+                'epsilon': self.epsilon,
+                'decay': self.decay,
+                'momentum': self.momentum,
+                'centered':
+                False,  # TODO(houj04): when XDNN api supports 'center = True', add more test cases
+            }
+
+            self.outputs = {
+                'ParamOut': param_out,
+                'MomentOut': moment_out,
+                'MeanSquareOut': mean_square_out,
+                'MeanGradOut': self.mean_grad_out
+            }
+
+        def init_dtype(self):
+            self.dtype = np.float32
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place,
+                                         no_check_set=['MeanGradOut'])
+
+        def init_config(self):
+            self.input_shape = [864]
+            self.learning_rate = np.array([0.001]).astype(self.dtype)
+            self.epsilon = 1e-4
+            self.decay = 0.9
+            self.momentum = 0.1
+
+    class XPUTestRMSProp1(TestRMSPropOPBase):
+
+        def init_config(self):
+            self.input_shape = [2, 768]
+            self.learning_rate = np.array([0.002]).astype(self.dtype)
+            self.epsilon = 1e-4
+            self.decay = 0.9
+            self.momentum = 0.1
+
+    class XPUTestRMSProp2(TestRMSPropOPBase):
+
+        def init_config(self):
+            self.input_shape = [3, 8, 4096]
+            self.learning_rate = np.array([0.005]).astype(self.dtype)
+            self.epsilon = 1e-6
+            self.decay = 0.95
+            self.momentum = 0
+
+    class XPUTestRMSProp3(TestRMSPropOPBase):
+
+        def init_config(self):
+            self.input_shape = [1024]
+            self.learning_rate = np.array([0.01]).astype(self.dtype)
+            self.epsilon = 1e-5
+            self.decay = 0.99
+            self.momentum = 0.02
+
+    class XPUTestRMSProp4(TestRMSPropOPBase):
+
+        def init_config(self):
+            self.input_shape = [2, 2, 255]
+            self.learning_rate = np.array([0.0005]).astype(self.dtype)
+            self.epsilon = 1e-3
+            self.decay = 0.8
+            self.momentum = 0.002
+
+
+support_types = get_xpu_op_support_types('rmsprop')
+for stype in support_types:
+    create_test_class(globals(), XPUTestRMSPropOP, stype)

 if __name__ == "__main__":
-    paddle.enable_static()
     unittest.main()
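Once the XDNN api gains centered support (see the TODOs above), the centered update from the removed TestBase reference would apply. A NumPy sketch based on that deleted code, with illustrative names; MeanGrad/MeanGradOut exist precisely to track the extra running mean used here:

import numpy as np

def centered_rmsprop_update(g, p, ms, mg, mom, epsilon, rho, momentum, lr):
    # Same squared-gradient average as the plain update.
    ms_out = rho * ms + (1.0 - rho) * g * g
    # Additional running mean of the gradient, used to center the denominator.
    mg_out = rho * mg + (1.0 - rho) * g
    # Denominator becomes a variance estimate: ms_out - mg_out^2.
    mom_out = momentum * mom + lr * g / np.sqrt(
        ms_out - np.square(mg_out) + epsilon)
    p_out = p - mom_out
    return p_out, ms_out, mg_out, mom_out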