Unverified commit 3ca713ee, authored by houj04, committed by GitHub

rmsprop for xpu. test=kunlun (#44175)

* rmsprop for xpu. test=kunlun

* minor fix (follow comments). test=kunlun
Parent 9a3054c6
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220706")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220707")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include <gflags/gflags.h>

#include <iostream>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

namespace paddle {
namespace operators {

static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
  const float* tensor_data = tensor->data<float>();
  framework::Tensor cpu_tensor;
  if (platform::is_gpu_place(tensor->place()) ||
      platform::is_xpu_place(tensor->place())) {
    paddle::framework::TensorCopySync(
        *tensor, platform::CPUPlace(), &cpu_tensor);
    tensor_data = cpu_tensor.data<float>();
  }
  return tensor_data[0];
}

using framework::OpKernelType;
using framework::Tensor;

template <typename DeviceContext, typename T>
class RmspropOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using paddle::framework::LoDTensor;

    // check Param & Grad tensor type
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<LoDTensor>(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor holds the wrong type. Expected Var(%s)'s "
                          "type is LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    const auto* grad_var = ctx.InputVar("Grad");
    PADDLE_ENFORCE_EQ(grad_var->IsType<LoDTensor>(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor holds the wrong type. Expected Var(%s)'s "
                          "type is LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Grad").front(),
                          framework::ToTypeName(grad_var->Type())));

    // inputs
    auto& param = GET_DATA_SAFELY(
        ctx.Input<LoDTensor>("Param"), "Input", "Param", "Rmsprop");
    auto& meanSquare = GET_DATA_SAFELY(
        ctx.Input<LoDTensor>("MeanSquare"), "Input", "MeanSquare", "Rmsprop");
    auto& grad = GET_DATA_SAFELY(
        ctx.Input<LoDTensor>("Grad"), "Input", "Grad", "Rmsprop");
    auto& mom = GET_DATA_SAFELY(
        ctx.Input<LoDTensor>("Moment"), "Input", "Moment", "Rmsprop");

    auto* learning_rate = ctx.Input<Tensor>("LearningRate");
    PADDLE_ENFORCE_EQ(learning_rate->dims().size(),
                      1,
                      platform::errors::InvalidArgument(
                          "learning rate should have dimension = 1."
                          " But received learning rate dim [%s] ",
                          learning_rate->dims().size()));
    T lr = static_cast<T>(GetAttrFromTensor(learning_rate));

    // constants
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    T decay = static_cast<T>(ctx.Attr<float>("decay"));
    T momentum = static_cast<T>(ctx.Attr<float>("momentum"));

    bool centered = ctx.Attr<bool>("centered");
    PADDLE_ENFORCE_EQ(centered,
                      false,
                      platform::errors::Unimplemented(
                          "centered=True is not supported in the xpu kernel of "
                          "rmsprop. Use XPU_BLACK_LIST to disable this op."));
    /*
    TODO(houj04): when XDNN api supports 'center', add input of
    mean_grad_input and output of mean_grad_output.
    auto* mean_grad_input = ctx.Input<Tensor>("MeanGrad");
    auto* mean_grad_output = ctx.Output<Tensor>("MeanGradOut");
    */

    // outputs
    auto& param_out = GET_DATA_SAFELY(
        ctx.Output<LoDTensor>("ParamOut"), "Output", "ParamOut", "Rmsprop");
    auto& mom_out = GET_DATA_SAFELY(
        ctx.Output<LoDTensor>("MomentOut"), "Output", "MomentOut", "Rmsprop");
    auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("MeanSquareOut"),
                                         "Output",
                                         "MeanSquareOut",
                                         "Rmsprop");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // int rmsprop(Context* ctx, const T* g, const T* p, const float* ms,
    //             const float* mom, T* p_out, float* ms_out, float* mom_out,
    //             float epsilon, float rho, float momentum, float lr, int n);
    int r = xpu::rmsprop(dev_ctx.x_context(),
                         grad.template data<T>(),
                         param.template data<T>(),
                         meanSquare.template data<T>(),
                         mom.template data<T>(),
                         param_out.template mutable_data<T>(ctx.GetPlace()),
                         mom_sqrt_out.template mutable_data<T>(ctx.GetPlace()),
                         mom_out.template mutable_data<T>(ctx.GetPlace()),
                         epsilon,
                         decay,
                         momentum,
                         lr,
                         param.numel());

    PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    rmsprop,
    ops::RmspropOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
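For reference, the update this kernel computes (kept from the reference comment in the previous version of this file, and mirrored by calculate_rmsprop_by_numpy in the new test below) is the non-centered RMSProp step, where \rho is the decay attribute and \mu the momentum attribute:

    ms_t  = \rho \, ms_{t-1} + (1 - \rho) \, g_t^2
    mom_t = \mu \, mom_{t-1} + lr \cdot g_t / \sqrt{ms_t + \epsilon}
    p_t   = p_{t-1} - mom_t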
@@ -363,6 +363,7 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT32, XPUPlace()),
                     pOpKernelType(vartype::BOOL, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
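The kernel's Unimplemented error above points at XPU_BLACK_LIST as the escape hatch when an unsupported variant is hit. A minimal sketch of that knob, assuming the variable holds op names and is read from the environment when Paddle initializes:

    import os

    # Assumption: XPU_BLACK_LIST is consulted during kernel selection, so it
    # must be set before the framework starts up; listing 'rmsprop' here would
    # keep execution off the XPU kernel registered in this commit.
    os.environ['XPU_BLACK_LIST'] = 'rmsprop'

    import paddle  # imported only after the variable is set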
@@ -36,4 +36,5 @@ no_check_set_white_list = [
     'eigvalsh',
     'class_center_sample',
     'einsum',
+    'rmsprop',
 ]
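Whitelisting 'rmsprop' here allows its tests to pass a no_check_set: the new XPU test below skips MeanGradOut, which the kernel leaves untouched while centered mode is unimplemented.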
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import sys
import unittest

import numpy as np

sys.path.append("..")

import paddle
import paddle.fluid.core as core

from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper

paddle.enable_static()


def calculate_rmsprop_by_numpy(param, grad, mean_square, moment, learning_rate,
                               epsilon, decay, momentum):
    mean_square_out = decay * mean_square + (1 - decay) * grad * grad
    moment_out = momentum * moment + learning_rate * grad / np.sqrt(
        mean_square_out + epsilon)
    param_out = param - moment_out
    return param_out, mean_square_out, moment_out


class XPUTestRMSPropOP(XPUOpTestWrapper):

    def __init__(self):
        self.op_name = 'rmsprop'
        self.use_dynamic_create_class = False

    class TestRMSPropOPBase(XPUOpTest):

        def setUp(self):
            self.place = paddle.XPUPlace(0)
            self.xpu_version = core.get_xpu_device_version(0)
            self.init_dtype()
            self.set_case()

        def set_case(self):
            self.op_type = 'rmsprop'
            self.dtype = self.in_type
            self.init_config()

            self.param = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.grad = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.mean_square = np.random.uniform(
                0, 1, self.input_shape).astype(self.dtype)
            self.moment = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)

            self.mean_grad = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)
            self.mean_grad_out = np.random.uniform(
                -1, 1, self.input_shape).astype(self.dtype)

            param_out, mean_square_out, moment_out = calculate_rmsprop_by_numpy(
                param=self.param,
                grad=self.grad,
                mean_square=self.mean_square,
                moment=self.moment,
                learning_rate=self.learning_rate,
                epsilon=self.epsilon,
                decay=self.decay,
                momentum=self.momentum)
            self.inputs = {
                'Param': self.param,
                'Grad': self.grad,
                'MeanSquare': self.mean_square,
                'Moment': self.moment,
                'LearningRate': self.learning_rate,
                'MeanGrad': self.mean_grad,
                'MeanGradOut': self.mean_grad_out,
            }
            self.attrs = {
                'use_xpu': True,
                'epsilon': self.epsilon,
                'decay': self.decay,
                'momentum': self.momentum,
                # TODO(houj04): when XDNN api supports 'center = True', add more test cases
                'centered': False,
            }
            self.outputs = {
                'ParamOut': param_out,
                'MomentOut': moment_out,
                'MeanSquareOut': mean_square_out,
                'MeanGradOut': self.mean_grad_out
            }

        def init_dtype(self):
            self.dtype = np.float32

        def test_check_output(self):
            self.check_output_with_place(self.place,
                                         no_check_set=['MeanGradOut'])

        def init_config(self):
            self.input_shape = [864]
            self.learning_rate = np.array([0.001]).astype(self.dtype)
            self.epsilon = 1e-4
            self.decay = 0.9
            self.momentum = 0.1

    class XPUTestRMSProp1(TestRMSPropOPBase):

        def init_config(self):
            self.input_shape = [2, 768]
            self.learning_rate = np.array([0.002]).astype(self.dtype)
            self.epsilon = 1e-4
            self.decay = 0.9
            self.momentum = 0.1

    class XPUTestRMSProp2(TestRMSPropOPBase):

        def init_config(self):
            self.input_shape = [3, 8, 4096]
            self.learning_rate = np.array([0.005]).astype(self.dtype)
            self.epsilon = 1e-6
            self.decay = 0.95
            self.momentum = 0

    class XPUTestRMSProp3(TestRMSPropOPBase):

        def init_config(self):
            self.input_shape = [1024]
            self.learning_rate = np.array([0.01]).astype(self.dtype)
            self.epsilon = 1e-5
            self.decay = 0.99
            self.momentum = 0.02

    class XPUTestRMSProp4(TestRMSPropOPBase):

        def init_config(self):
            self.input_shape = [2, 2, 255]
            self.learning_rate = np.array([0.0005]).astype(self.dtype)
            self.epsilon = 1e-3
            self.decay = 0.8
            self.momentum = 0.002


support_types = get_xpu_op_support_types('rmsprop')
for stype in support_types:
    create_test_class(globals(), XPUTestRMSPropOP, stype)

if __name__ == "__main__":
    unittest.main()
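For completeness, a minimal dygraph usage sketch, adapted from the RMSProp test that this commit replaces; the 'xpu' device selection and an XPU-enabled (WITH_XPU) build are assumptions on top of this commit:

    import numpy as np
    import paddle

    # Sketch: assumes a WITH_XPU build with at least one XPU device visible;
    # selecting the 'xpu' device routes the rmsprop op to the kernel above.
    paddle.set_device('xpu')
    paddle.disable_static()

    value = np.arange(26).reshape(2, 13).astype('float32')
    a = paddle.to_tensor(value)
    linear = paddle.nn.Linear(13, 5)
    opt = paddle.optimizer.RMSProp(learning_rate=0.01,
                                   parameters=linear.parameters(),
                                   weight_decay=0.01)

    out = linear(a)
    out.backward()
    opt.step()
    opt.clear_gradients()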