Unverified commit 54b756e2, authored by ykkk2333, committed by GitHub

add xpu centered rmsprop (#48658)

* add stat tool

* add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun

* add xpu rmsprop centered, test=kunlun
Parent 3ba1237e
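For reference, these are the update rules the new kernel path implements (they match the NumPy reference in the test below; $\rho$ is the op attribute `decay`, $\mu$ is `momentum`, $\eta$ the learning rate):

$ms_t = \rho \, ms_{t-1} + (1 - \rho) \, g_t^2$
$mg_t = \rho \, mg_{t-1} + (1 - \rho) \, g_t$ (centered only)
$mom_t = \mu \, mom_{t-1} + \eta \, g_t / \sqrt{ms_t - mg_t^2 + \epsilon}$ (the uncentered variant drops the $mg_t^2$ term)
$p_t = p_{t-1} - mom_t$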
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20221124")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20221201")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
......
@@ -117,7 +117,8 @@ XPUOpMap& get_kl2_ops() {
     {"clip_by_norm",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"coalesce_tensor",
-     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"concat_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
......
@@ -37,12 +37,6 @@ void RmspropDenseKernel(const Context& dev_ctx,
                         DenseTensor* moment_out,
                         DenseTensor* mean_square_out,
                         DenseTensor* mean_grad_out) {
-  // check input
-  PADDLE_ENFORCE_EQ(centered,
-                    false,
-                    errors::Unimplemented(
-                        "centered=True is not supported in the xpu kernel of "
-                        "rmsprop. use XPU_BLACK_LIST to disable this op."));
   // copy learning_rate to cpu
   PADDLE_ENFORCE_EQ(
       learning_rate.dims().size(),
@@ -62,23 +56,56 @@ void RmspropDenseKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(moment_out);
   dev_ctx.template Alloc<T>(mean_square_out);
 
-  // int rmsprop(Context* ctx, const T* g, const T* p, const float* ms, const
-  // float* mom, T* p_out, float* ms_out, float* mom_out, float epsilon, float
-  // rho, float momentum, float lr, int n);
-  int r = xpu::rmsprop(dev_ctx.x_context(),
-                       grad.data<T>(),
-                       param.data<T>(),
-                       mean_square.data<T>(),
-                       moment.data<T>(),
-                       param_out->data<T>(),
-                       mean_square_out->data<T>(),
-                       moment_out->data<T>(),
-                       epsilon,
-                       decay,
-                       momentum,
-                       learning_rate_cpu,
-                       param.numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
+  if (centered) {
+    dev_ctx.template Alloc<T>(mean_grad_out);
+    auto mg_tensor = mean_grad.get_ptr();
+    if (mg_tensor) {
+      PADDLE_ENFORCE_EQ(
+          mg_tensor->Holder(),
+          mean_grad_out->Holder(),
+          phi::errors::InvalidArgument(
+              "MeanGrad and MeanGradOut must be the same Tensor"));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          mg_tensor,
+          mean_grad_out,
+          phi::errors::InvalidArgument(
+              "MeanGrad and MeanGradOut must be the same Tensor"));
+    }
+    int r = xpu::rmsprop(dev_ctx.x_context(),
+                         grad.data<T>(),
+                         param.data<T>(),
+                         mean_square.data<T>(),
+                         moment.data<T>(),
+                         param_out->data<T>(),
+                         mean_square_out->data<T>(),
+                         moment_out->data<T>(),
+                         epsilon,
+                         decay,
+                         momentum,
+                         learning_rate_cpu,
+                         param.numel(),
+                         centered,
+                         mg_tensor->data<T>(),
+                         mean_grad_out->data<T>());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "centered rmsprop");
+  } else {
+    int r = xpu::rmsprop(dev_ctx.x_context(),
+                         grad.data<T>(),
+                         param.data<T>(),
+                         mean_square.data<T>(),
+                         moment.data<T>(),
+                         param_out->data<T>(),
+                         mean_square_out->data<T>(),
+                         moment_out->data<T>(),
+                         epsilon,
+                         decay,
+                         momentum,
+                         learning_rate_cpu,
+                         param.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "uncentered rmsprop");
+  }
 }
 }  // namespace phi
......
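For context, a minimal sketch (not part of this commit) of how the centered path above is reached from Paddle's public optimizer API. It assumes a Paddle build with XPU support and an available XPU device 0:

import paddle

paddle.set_device('xpu:0')  # assumption: XPU-enabled build

linear = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.RMSProp(
    learning_rate=0.01,
    rho=0.9,        # maps to the kernel attribute `decay`
    momentum=0.1,
    epsilon=1e-6,
    centered=True,  # previously rejected by the XPU kernel, now supported
    parameters=linear.parameters(),
)

loss = paddle.mean(linear(paddle.rand([4, 8])))
loss.backward()
opt.step()          # dispatches to the centered xpu::rmsprop call above
opt.clear_grad()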
@@ -27,7 +27,9 @@ from xpu.get_test_cover_info import (
 )
 
 import paddle
+import paddle.fluid as fluid
 import paddle.fluid.core as core
+from paddle.fluid.op import Operator
 
 paddle.enable_static()
@@ -161,6 +163,185 @@ class XPUTestRMSPropOP(XPUOpTestWrapper):
             self.momentum = 0.002
 
+
+class TestBase(unittest.TestCase):
+    def setup(
+        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
+    ):
+        np.random.seed(5)  # fix seed
+
+        self.scope = fluid.global_scope()
+        self.place = place
+
+        self.param_name = "param"
+        self.param = np.random.random(size).astype("float32")
+
+        self.mean_square_name = "mean_square"
+        self.mean_square = np.random.uniform(low=1, high=2, size=size).astype(
+            "float32"
+        )
+        self.mean_grad_name = "mean_grad"
+        self.mean_grad = np.random.random(size).astype("float32")
+
+        self.lr_name = "lr"
+        self.learning_rate = np.array([0.01]).astype("float32")
+
+        self.grad_name = "grad"
+        self.is_sparse = is_sparse
+        self.grad = np.random.random(size).astype("float32")
+        grad_tensor = self.scope.var(self.grad_name).get_tensor()
+        grad_tensor.set(self.grad, place)
+
+        self.moment_name = "moment"
+        self.moment = np.random.uniform(low=0, high=1, size=size).astype(
+            "float32"
+        )
+
+        self.epsilon = epsilon
+        self.decay = 0.9
+        self.momentum = 0.1
+        self.centered = centered
+
+        self.ms_out = (
+            self.decay * self.mean_square
+            + (1 - self.decay) * self.grad * self.grad
+        )
+        if centered:
+            self.mg_out = (
+                self.decay * self.mean_grad + (1 - self.decay) * self.grad
+            )
+            self.moment_out = (
+                self.momentum * self.moment
+                + self.learning_rate
+                * self.grad
+                / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
+            )
+        else:
+            self.moment_out = (
+                self.momentum * self.moment
+                + self.learning_rate
+                * self.grad
+                / np.sqrt(self.ms_out + self.epsilon)
+            )
+
+        self.param_out = self.param - self.moment_out
+
+        # create and initialize Param Variable
+        self.param_tensor = self.scope.var(self.param_name).get_tensor()
+        self.param_tensor.set(self.param, place)
+
+        self.mean_square_tensor = self.scope.var(
+            self.mean_square_name
+        ).get_tensor()
+        self.mean_square_tensor.set(self.mean_square, place)
+
+        lr = self.scope.var(self.lr_name).get_tensor()
+        lr.set(self.learning_rate, place)
+
+        self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
+        self.moment_tensor.set(self.moment, place)
+
+        if self.centered:
+            self.mean_grad_tensor = self.scope.var(
+                self.mean_grad_name
+            ).get_tensor()
+            self.mean_grad_tensor.set(self.mean_grad, place)
+
+    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
+        np.testing.assert_allclose(
+            actual_t,
+            expect_t,
+            rtol=1e-05,
+            atol=atol,
+            err_msg='Output ('
+            + out_name
+            + ') has diff at '
+            + str(place)
+            + '\nExpect '
+            + str(expect_t)
+            + '\n'
+            + 'But Got'
+            + str(actual_t),
+        )
+
+
+class TestRmspropOp(TestBase):
+    def check_with_place(
+        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
+    ):
+        self.setup(place, is_sparse, centered, size, row_num, epsilon)
+        self.run_and_check()
+
+    def run_and_check(self):
+        grad_name = self.grad_name
+
+        kwargs = {
+            'Param': self.param_name,
+            'Grad': grad_name,
+            'MeanSquare': self.mean_square_name,
+            'Moment': self.moment_name,
+            'LearningRate': self.lr_name,
+            'ParamOut': self.param_name,
+            'MeanSquareOut': self.mean_square_name,
+            'MomentOut': self.moment_name,
+            'epsilon': self.epsilon,
+            'decay': self.decay,
+            'momentum': self.momentum,
+            'centered': self.centered,
+        }
+
+        if self.centered:
+            kwargs['MeanGrad'] = self.mean_grad_name
+            kwargs['MeanGradOut'] = self.mean_grad_name
+
+        rmsprop_op = Operator('rmsprop', **kwargs)
+        atol = 1e-6
+
+        rmsprop_op.run(self.scope, self.place)
+
+        self.check(
+            np.array(self.mean_square_tensor),
+            self.ms_out,
+            self.place,
+            self.mean_square_name,
+            atol=atol,
+        )
+        self.check(
+            np.array(self.moment_tensor),
+            self.moment_out,
+            self.place,
+            self.moment_name,
+            atol=atol,
+        )
+        self.check(
+            np.array(self.param_tensor),
+            self.param_out,
+            self.place,
+            self.param_name,
+            atol=atol,
+        )
+
+        if self.centered:
+            self.check(
+                np.array(self.mean_grad_tensor),
+                self.mg_out,
+                self.place,
+                self.mean_grad_name,
+            )
+
+    def test_rmsprop(self):
+        places = [core.XPUPlace(0)]
+
+        size = (128, 320)
+
+        for place in places:
+            for centered in [False, True]:
+                with fluid.scope_guard(core.Scope()):
+                    self.check_with_place(
+                        place, is_sparse=False, centered=centered, size=size
+                    )
+
+
 support_types = get_xpu_op_support_types('rmsprop')
 for stype in support_types:
     create_test_class(globals(), XPUTestRMSPropOP, stype)
......