Unverified commit 3ca713ee, authored by houj04, committed by GitHub

rmsprop for xpu. test=kunlun (#44175)

* rmsprop for xpu. test=kunlun

* minor fix (follow comments). test=kunlun
Parent 9a3054c6
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <gflags/gflags.h>
#include <iostream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
const float* tensor_data = tensor->data<float>();
framework::Tensor cpu_tensor;
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place())) {
paddle::framework::TensorCopySync(
*tensor, platform::CPUPlace(), &cpu_tensor);
tensor_data = cpu_tensor.data<float>();
}
return tensor_data[0];
}
using framework::OpKernelType;
using framework::Tensor;
template <typename DeviceContext, typename T>
class RmspropOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using paddle::framework::LoDTensor;
// check Param & Grad tensor type
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<LoDTensor>(),
true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<LoDTensor>(),
true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
// inputs
auto& param = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Param"), "Input", "Param", "Rmsprop");
auto& meanSquare = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("MeanSquare"), "Input", "MeanSquare", "Rmsprop");
auto& grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Grad"), "Input", "Grad", "Rmsprop");
auto& mom = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment"), "Input", "Moment", "Rmsprop");
auto* learning_rate = ctx.Input<Tensor>("LearningRate");
PADDLE_ENFORCE_EQ(learning_rate->dims().size(),
1,
platform::errors::InvalidArgument(
"learining rate should have dimension = 1."
" But received learning rate dim [%s] ",
learning_rate->dims().size()));
T lr = static_cast<T>(GetAttrFromTensor(learning_rate));
// constants
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
T decay = static_cast<T>(ctx.Attr<float>("decay"));
T momentum = static_cast<T>(ctx.Attr<float>("momentum"));
// outputs
auto& param_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Rmsprop");
auto& mom_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("MomentOut"), "Output", "MomentOut", "Rmsprop");
auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("MeanSquareOut"),
"Output",
"MeanSquareOut",
"Rmsprop");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
/// RMSProp optimization algorithm:
///
/// ms_out[i] = rho * ms[i] + (1 - rho) * (g[i] * g[i]);
///
/// mom_out[i] = momentum * mom[i] + lr *
/// (g[i] / ((float)sqrt(ms_out[i] + epsilon)));
///
/// p_out[i] = p[i] - mom_out[i];
///
/// XDNN signature:
/// DLL_EXPORT int rmsprop(Context* ctx, const T* g, const T* p,
/// const float* ms, const float* mom, T* p_out, float* ms_out,
/// float* mom_out, float epsilon, float rho, float momentum,
/// float lr, int n)
int r = xpu::rmsprop(dev_ctx.x_context(),
grad.template data<T>(),
param.template data<T>(),
meanSquare.template data<T>(),
mom.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()),
mom_sqrt_out.template mutable_data<T>(ctx.GetPlace()),
mom_out.template mutable_data<T>(ctx.GetPlace()),
epsilon,
decay,
momentum,
lr,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
rmsprop,
ops::RmspropOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
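For reference, the update rule implemented by the xpu::rmsprop call above can be written as a short NumPy sketch. This is a minimal illustration of the formulas in the kernel comment, not the actual XDNN code; the function and variable names here are hypothetical:

    import numpy as np

    def rmsprop_reference(p, ms, g, mom, epsilon, rho, momentum, lr):
        # Decaying average of the squared gradient.
        ms_out = rho * ms + (1.0 - rho) * g * g
        # Gradient scaled by the root of the accumulator, with momentum.
        mom_out = momentum * mom + lr * g / np.sqrt(ms_out + epsilon)
        # Take the step.
        p_out = p - mom_out
        return ms_out, mom_out, p_out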
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <gflags/gflags.h>
#include <iostream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
const float* tensor_data = tensor->data<float>();
framework::Tensor cpu_tensor;
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place())) {
paddle::framework::TensorCopySync(
*tensor, platform::CPUPlace(), &cpu_tensor);
tensor_data = cpu_tensor.data<float>();
}
return tensor_data[0];
}
using framework::OpKernelType;
using framework::Tensor;
template <typename DeviceContext, typename T>
class RmspropOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using paddle::framework::LoDTensor;
// check Param & Grad tensor type
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<LoDTensor>(),
true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<LoDTensor>(),
true,
platform::errors::InvalidArgument(
"Tensor holds the wrong type,Expected Var(%s)'s "
"type is LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
// inputs
auto& param = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Param"), "Input", "Param", "Rmsprop");
auto& meanSquare = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("MeanSquare"), "Input", "MeanSquare", "Rmsprop");
auto& grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Grad"), "Input", "Grad", "Rmsprop");
auto& mom = GET_DATA_SAFELY(
ctx.Input<LoDTensor>("Moment"), "Input", "Moment", "Rmsprop");
auto* learning_rate = ctx.Input<Tensor>("LearningRate");
PADDLE_ENFORCE_EQ(learning_rate->dims().size(),
1,
platform::errors::InvalidArgument(
"learining rate should have dimension = 1."
" But received learning rate dim [%s] ",
learning_rate->dims().size()));
T lr = static_cast<T>(GetAttrFromTensor(learning_rate));
// constants
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
T decay = static_cast<T>(ctx.Attr<float>("decay"));
T momentum = static_cast<T>(ctx.Attr<float>("momentum"));
bool centered = ctx.Attr<bool>("centered");
PADDLE_ENFORCE_EQ(centered,
false,
platform::errors::Unimplemented(
"centered=True is not supported in the xpu kernel of "
"rmsprop. use XPU_BLACK_LIST to disable this op."));
/*
TODO(houj04): when the XDNN api supports 'centered', add the input
MeanGrad and the output MeanGradOut:
    auto* mean_grad_input = ctx.Input<Tensor>("MeanGrad");
    auto* mean_grad_output = ctx.Output<Tensor>("MeanGradOut");
*/
// outputs
auto& param_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Rmsprop");
auto& mom_out = GET_DATA_SAFELY(
ctx.Output<LoDTensor>("MomentOut"), "Output", "MomentOut", "Rmsprop");
auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("MeanSquareOut"),
"Output",
"MeanSquareOut",
"Rmsprop");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// XDNN signature:
// int rmsprop(Context* ctx, const T* g, const T* p, const float* ms,
//     const float* mom, T* p_out, float* ms_out, float* mom_out,
//     float epsilon, float rho, float momentum, float lr, int n);
int r = xpu::rmsprop(dev_ctx.x_context(),
grad.template data<T>(),
param.template data<T>(),
meanSquare.template data<T>(),
mom.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()),
mom_sqrt_out.template mutable_data<T>(ctx.GetPlace()),
mom_out.template mutable_data<T>(ctx.GetPlace()),
epsilon,
decay,
momentum,
lr,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
rmsprop,
ops::RmspropOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
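The centered variant that this kernel rejects (and that the TODO above defers until XDNN supports it) differs only in the denominator: it additionally tracks a decaying mean of the raw gradient and subtracts its square. A hedged NumPy sketch, following the reference math used by the disabled Python test further below; again, the names are illustrative, not an XDNN API:

    import numpy as np

    def centered_rmsprop_reference(p, ms, mg, g, mom,
                                   epsilon, rho, momentum, lr):
        ms_out = rho * ms + (1.0 - rho) * g * g
        # Decaying mean of the gradient, used only by the centered variant.
        mg_out = rho * mg + (1.0 - rho) * g
        # Variance-like denominator: E[g^2] - (E[g])^2 + epsilon.
        mom_out = momentum * mom + lr * g / np.sqrt(
            ms_out - np.square(mg_out) + epsilon)
        p_out = p - mom_out
        return ms_out, mg_out, mom_out, p_out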
@@ -363,6 +363,7 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT32, XPUPlace()),
                     pOpKernelType(vartype::BOOL, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
@@ -36,4 +36,5 @@ no_check_set_white_list = [
     'eigvalsh',
     'class_center_sample',
     'einsum',
+    'rmsprop',
 ]
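Ops in this white list are permitted to skip comparing selected outputs via the no_check_set argument of OpTest. rmsprop needs the entry because the XPU kernel does not compute MeanGradOut while centered=True is unsupported; the new XPU test (shown further below) therefore checks output as:

    # Skip comparing MeanGradOut, which the XPU kernel does not compute.
    self.check_output_with_place(self.place, no_check_set=['MeanGradOut'])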
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
import paddle
'''
def create_selected_rows_and_tensor(scope, place, height, row_num,
                                    embedding_size):
    sr = scope.var("@selected_rows@").get_selected_rows()
    tensor = scope.var("grad").get_tensor()

    rows = np.random.random_integers(
        low=0, high=height - 1, size=[row_num, ]).astype('int64')
    sr_val = np.random.random(
        size=[row_num, embedding_size]).astype('float32')

    sr.set_height(height)
    sr.set_rows(rows)
    sr.get_tensor().set(sr_val, place)

    tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
    for i in range(row_num):
        row = rows[i]
        tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]

    tensor.set(tensor_val, place)
    return tensor_val, sr_val
'''
"""
class TestBase(XPUOpTest):
op_type = 'rmsprop'
def setup(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
np.random.seed(5) # fix seed
self.scope = fluid.global_scope()
self.place = place
self.param_name = 'param'
self.param = np.random.random(size).astype('float32')
self.mean_square_name = 'mean_square'
self.mean_square = np.random.uniform(
low=1, high=2, size=size).astype('float32')
self.mean_grad_name = 'mean_grad'
self.mean_grad = np.random.random(size).astype('float32')
self.lr_name = 'lr'
self.learning_rate = np.array([0.01]).astype('float32')
self.grad_name = 'grad'
self.is_sparse = is_sparse
self.grad = np.random.random(size).astype('float32')
grad_tensor = self.scope.var(self.grad_name).get_tensor()
grad_tensor.set(self.grad, place)
self.moment_name = 'moment'
self.moment = np.random.uniform(
low=0, high=1, size=size).astype('float32')
self.epsilon = epsilon
self.decay = 0.9
self.momentum = 0.1
self.centered = centered
self.ms_out = self.decay * self.mean_square + (1 - self.decay
) * self.grad * self.grad
if centered:
self.mg_out = self.decay * self.mean_grad + (1 - self.decay
) * self.grad
self.moment_out = self.momentum * self.moment + \
self.learning_rate * self.grad / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
else:
self.moment_out = self.momentum * self.moment + \
self.learning_rate * self.grad / np.sqrt(self.ms_out + self.epsilon)
self.param_out = self.param - self.moment_out
# create and initialize Param Variable
self.param_tensor = self.scope.var(self.param_name).get_tensor()
self.param_tensor.set(self.param, place)
self.mean_square_tensor = self.scope.var(
self.mean_square_name).get_tensor()
self.mean_square_tensor.set(self.mean_square, place)
lr = self.scope.var(self.lr_name).get_tensor()
lr.set(self.learning_rate, place)
self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
self.moment_tensor.set(self.moment, place)
if self.centered:
self.mean_grad_tensor = self.scope.var(
self.mean_grad_name).get_tensor()
self.mean_grad_tensor.set(self.mean_grad, place)
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
'Output (' + out_name + ') has diff at ' + str(place) + '\nExpect '
+ str(expect_t) + '\n' + 'But Got' + str(actual_t))
class TestRmspropOp(TestBase):
def check_with_place(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
self.setup(place, is_sparse, centered, size, row_num, epsilon)
self.run_and_check()
def run_and_check(self):
#grad_name = self.grad_sr_name if self.is_sparse else self.grad_name
grad_name = self.grad_name
kwargs = {
'Param': self.param_name,
'Grad': grad_name,
'MeanSquare': self.mean_square_name,
'Moment': self.moment_name,
'LearningRate': self.lr_name,
'ParamOut': self.param_name,
'MeanSquareOut': self.mean_square_name,
'MomentOut': self.moment_name,
'epsilon': self.epsilon,
'decay': self.decay,
'momentum': self.momentum,
'centered': self.centered
}
if self.centered:
kwargs['MeanGrad'] = self.mean_grad_name
kwargs['MeanGradOut'] = self.mean_grad_name
rmsprop_op = Operator('rmsprop', **kwargs)
atol = 1e-6
rmsprop_op.run(self.scope, self.place)
self.check(
np.array(self.mean_square_tensor),
self.ms_out,
self.place,
self.mean_square_name,
atol=atol)
self.check(
np.array(self.moment_tensor),
self.moment_out,
self.place,
self.moment_name,
atol=atol)
self.check(
np.array(self.param_tensor),
self.param_out,
self.place,
self.param_name,
atol=atol)
if self.centered:
self.check(
np.array(self.mean_grad_tensor), self.mg_out, self.place,
self.mean_grad_name)
def test_rmsprop(self):
places = [paddle.XPUPlace(0)]
size = (128, 320)
for place in places:
for centered in [False]:
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place, is_sparse=False, centered=centered, size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=512,
size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=60,
size=size, )
class TestRMSPropV2(XPUOpTest):
op_type = 'rmsprop'
def test_rmsprop_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype('float32')
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.RMSProp(
learning_rate=0.01,
parameters=linear.parameters(),
weight_decay=0.01)
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()
def test_rmsprop(self):
place = paddle.XPUPlace(0)
paddle.enable_static()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = paddle.mean(cost)
print(avg_cost.shape)
linear = paddle.nn.Linear(13, 5)
rms_optimizer = paddle.optimizer.RMSProp(
learning_rate=0.1, parameters=linear.parameters())
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
def test_raise_error(self):
self.assertRaises(ValueError, paddle.optimizer.RMSProp, None)
self.assertRaises(
ValueError, paddle.optimizer.RMSProp, learning_rate=0.1, rho=None)
self.assertRaises(
ValueError,
paddle.optimizer.RMSProp,
learning_rate=0.1,
epsilon=None)
self.assertRaises(
ValueError,
paddle.optimizer.RMSProp,
learning_rate=0.1,
momentum=None)
def test_rmsprop_op_invalid_input(self):
paddle.disable_static()
linear = paddle.nn.Linear(10, 10)
with self.assertRaises(ValueError):
adam = paddle.optimizer.RMSProp(
0.1, epsilon=-1, parameters=linear.parameters())
with self.assertRaises(ValueError):
adam = paddle.optimizer.RMSProp(
0.1, momentum=-1, parameters=linear.parameters())
with self.assertRaises(ValueError):
adam = paddle.optimizer.RMSProp(
0.1, rho=-1, parameters=linear.parameters())
"""
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.fluid.core as core
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
def calculate_rmsprop_by_numpy(param, grad, mean_square, moment,
                               learning_rate, epsilon, decay, momentum):
    mean_square_out = decay * mean_square + (1 - decay) * grad * grad
    moment_out = momentum * moment + learning_rate * grad / np.sqrt(
        mean_square_out + epsilon)
    param_out = param - moment_out
    return param_out, mean_square_out, moment_out
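# For reference, a single toy update step with the helper above (the values
# are illustrative only, not taken from the test cases below):
#
#   param_out, ms_out, mom_out = calculate_rmsprop_by_numpy(
#       param=np.ones(4, dtype=np.float32),
#       grad=np.full(4, 0.5, dtype=np.float32),
#       mean_square=np.full(4, 0.25, dtype=np.float32),
#       moment=np.zeros(4, dtype=np.float32),
#       learning_rate=0.01, epsilon=1e-4, decay=0.9, momentum=0.1)
#
#   # ms_out    = 0.9 * 0.25 + 0.1 * 0.5**2      = 0.25
#   # mom_out   = 0.01 * 0.5 / sqrt(0.25 + 1e-4) ~= 0.01
#   # param_out = 1.0 - mom_out                  ~= 0.99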
class XPUTestRMSPropOP(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'rmsprop'
        self.use_dynamic_create_class = False

    class TestRMSPropOPBase(XPUOpTest):
        def setUp(self):
            self.place = paddle.XPUPlace(0)
            self.xpu_version = core.get_xpu_device_version(0)
            self.init_dtype()
            self.set_case()

        def set_case(self):
            self.op_type = 'rmsprop'
            self.dtype = self.in_type
            self.init_config()

            self.param = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.grad = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.mean_square = np.random.uniform(
                0, 1, self.input_shape).astype(self.dtype)
            self.moment = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)

            self.mean_grad = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.mean_grad_out = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)

            param_out, mean_square_out, moment_out = calculate_rmsprop_by_numpy(
                param=self.param,
                grad=self.grad,
                mean_square=self.mean_square,
                moment=self.moment,
                learning_rate=self.learning_rate,
                epsilon=self.epsilon,
                decay=self.decay,
                momentum=self.momentum)

            self.inputs = {
                'Param': self.param,
                'Grad': self.grad,
                'MeanSquare': self.mean_square,
                'Moment': self.moment,
                'LearningRate': self.learning_rate,
                'MeanGrad': self.mean_grad,
                'MeanGradOut': self.mean_grad_out,
            }
            self.attrs = {
                'use_xpu': True,
                'epsilon': self.epsilon,
                'decay': self.decay,
                'momentum': self.momentum,
                # TODO(houj04): add more test cases when the XDNN api
                # supports 'centered = True'.
                'centered': False,
            }
            self.outputs = {
                'ParamOut': param_out,
                'MomentOut': moment_out,
                'MeanSquareOut': mean_square_out,
                'MeanGradOut': self.mean_grad_out
            }

        def init_dtype(self):
            self.dtype = np.float32

        def test_check_output(self):
            self.check_output_with_place(
                self.place, no_check_set=['MeanGradOut'])

        def init_config(self):
            self.input_shape = [864]
            self.learning_rate = np.array([0.001]).astype(self.dtype)
            self.epsilon = 1e-4
            self.decay = 0.9
            self.momentum = 0.1

    class XPUTestRMSProp1(TestRMSPropOPBase):
        def init_config(self):
            self.input_shape = [2, 768]
            self.learning_rate = np.array([0.002]).astype(self.dtype)
            self.epsilon = 1e-4
            self.decay = 0.9
            self.momentum = 0.1

    class XPUTestRMSProp2(TestRMSPropOPBase):
        def init_config(self):
            self.input_shape = [3, 8, 4096]
            self.learning_rate = np.array([0.005]).astype(self.dtype)
            self.epsilon = 1e-6
            self.decay = 0.95
            self.momentum = 0

    class XPUTestRMSProp3(TestRMSPropOPBase):
        def init_config(self):
            self.input_shape = [1024]
            self.learning_rate = np.array([0.01]).astype(self.dtype)
            self.epsilon = 1e-5
            self.decay = 0.99
            self.momentum = 0.02

    class XPUTestRMSProp4(TestRMSPropOPBase):
        def init_config(self):
            self.input_shape = [2, 2, 255]
            self.learning_rate = np.array([0.0005]).astype(self.dtype)
            self.epsilon = 1e-3
            self.decay = 0.8
            self.momentum = 0.002


support_types = get_xpu_op_support_types('rmsprop')
for stype in support_types:
    create_test_class(globals(), XPUTestRMSPropOP, stype)

if __name__ == "__main__":
    unittest.main()