Unverified commit 6e03f790, authored by Qiao Longfei, committed by GitHub

Add centered mode rmsprop (#13161)

* rmsprop optimizer supports v1 mode

* typo

* optimize code

* refine code

* optimize unit test

* update test_rmsprop_op.py

* update formula of rmsprop

* optimize document

* update API.spec for RMSPropOptimizer

* add default value to check_output_with_place equal_nan
Parent 9df2d8b5
@@ -376,7 +376,7 @@ paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'l
 paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power'], varargs=None, keywords='kwargs', defaults=(0.0, 0.0, -0.5))
 paddle.fluid.optimizer.FtrlOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0))
+paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0, False))
 paddle.fluid.optimizer.RMSPropOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho'], varargs=None, keywords='kwargs', defaults=(1e-06, 0.95))
 paddle.fluid.optimizer.AdadeltaOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
...
@@ -36,9 +36,13 @@ class RmspropOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(param_out) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
-                   "Output(Momentum_out) of RmspropOp should not be null.");
+                   "Output(MomentOut) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
                    "Output(MeanSquareOut) of RmspropOp should not be null.");
+    if (ctx->Attrs().Get<bool>("centered")) {
+      PADDLE_ENFORCE(ctx->HasOutput("MeanGradOut"),
+                     "Output(MeanGradOut) of RmspropOp should not be null.");
+    }
 
     auto param_dim = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
@@ -58,6 +62,9 @@ class RmspropOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("MomentOut", param_dim);
     ctx->SetOutputDim("MeanSquareOut", param_dim);
+    if (ctx->Attrs().Get<bool>("centered")) {
+      ctx->SetOutputDim("MeanGradOut", param_dim);
+    }
   }
 };
@@ -70,6 +77,10 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("MeanSquare",
              "(Tensor, default Tensor<float>)"
              " The mean square value that gets updated.");
+    AddInput("MeanGrad",
+             "(Tensor, default Tensor<float>)"
+             " The moving average of gradient")
+        .AsDispensable();
     AddInput("LearningRate",
              "(Tensor, default Tensor<float>) "
              "The learning rate should be a tensor of size 1.");
@@ -82,6 +93,8 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
     AddOutput("MomentOut", "(Tensor) Output updated moment.");
     AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value.");
+    AddOutput("MeanGradOut",
+              "(Tensor) Output moving average of gradient updated value.");
 
     AddAttr<float>("epsilon",
                    "(float, default 1e-10) Constant "
@@ -93,6 +106,8 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(0.9f);
     AddAttr<float>("momentum", "(float, default 0.0) Constant value.")
         .SetDefault(0.0f);
+    AddAttr<bool>("centered", "(bool, default false) use centered rmsprop.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Rmsprop Optimizer.
 
@@ -103,6 +118,14 @@ MomentOut = momentum * Moment +
 ParamOut = Param - MomentOut
 $$
 
+if centered is true:
+
+mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
+mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+    sqrt(mean_square - mean_grad**2 + epsilon)
+param -= mom
+
 The original slides that proposed Rmsprop: Slide 29 of
 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
...
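To make the update rules in the operator comment concrete, here is a minimal NumPy sketch of one RMSProp step in both plain and centered mode. The function name rmsprop_step and its arguments are illustrative placeholders that mirror the operator's inputs and outputs; this is only the arithmetic, not PaddlePaddle API.

import numpy as np

def rmsprop_step(param, grad, mean_square, moment, mean_grad=None,
                 learning_rate=0.01, decay=0.9, momentum=0.0,
                 epsilon=1e-6, centered=False):
    # MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad^2
    mean_square = decay * mean_square + (1 - decay) * grad * grad
    if centered:
        # MeanGradOut = decay * MeanGrad + (1 - decay) * Grad
        mean_grad = decay * mean_grad + (1 - decay) * grad
        denom = np.sqrt(mean_square - np.square(mean_grad) + epsilon)
    else:
        denom = np.sqrt(mean_square + epsilon)
    # MomentOut = momentum * Moment + learning_rate * Grad / denom
    moment = momentum * moment + learning_rate * grad / denom
    # ParamOut = Param - MomentOut
    param = param - moment
    return param, moment, mean_square, mean_grad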
@@ -41,6 +41,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
     float epsilon = ctx.Attr<float>("epsilon");
     float rho = ctx.Attr<float>("decay");
     float momentum = ctx.Attr<float>("momentum");
+    bool centered = ctx.Attr<bool>("centered");
 
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
     auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
@@ -53,12 +54,24 @@ class RmspropOpKernel : public framework::OpKernel<T> {
     auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
 
-    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
+    Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
 
     ms_out.device(place) = rho * ms + (1 - rho) * g * g;
-    mom_out.device(place) =
-        momentum * mom +
-        lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+    if (centered) {
+      auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
+      auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+      mean_grad_out->mutable_data<T>(ctx.GetPlace());
+      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
+
+      mg_out.device(place) = rho * mg + (1 - rho) * g;
+      mom_out.device(place) = momentum * mom +
+                              lr.broadcast(grad_dsize) * g /
+                                  (ms_out - mg_out.square() + epsilon).sqrt();
+    } else {
+      mom_out.device(place) =
+          momentum * mom +
+          lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+    }
     p_out.device(place) = p - mom_out;
   }
 };
...
@@ -897,7 +897,20 @@ class RMSPropOptimizer(Optimizer):
 
         r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2
 
-        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{v(w,t) +
+        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) +
+            \\epsilon}} \\nabla Q_{i}(w)
+
+        w & = w - v(w, t)
+
+    if centered is True:
+
+    ..  math::
+
+        r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2
+
+        g(w, t) & = \\rho g(w, t-1) + (1 - \\rho)\\nabla Q_{i}(w)
+
+        v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{r(w,t) - (g(w, t))^2 +
             \\epsilon}} \\nabla Q_{i}(w)
 
         w & = w - v(w, t)
@@ -915,6 +928,10 @@ class RMSPropOptimizer(Optimizer):
             avoid division by zero, set 1e-6 by default.
         momentum(float): :math:`\\beta` in equation is the momentum term,
             set 0.0 by default.
+        centered(bool): If True, gradients are normalized by the estimated variance of
+            the gradient; if False, by the uncentered second moment. Setting this to
+            True may help with training, but is slightly more expensive in terms of
+            computation and memory. Defaults to False.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -928,12 +945,14 @@ class RMSPropOptimizer(Optimizer):
     _momentum_acc_str = "momentum"
     _mean_square_acc_str = "mean_square"
+    _mean_grad_acc_str = "mean_grad"
 
     def __init__(self,
                  learning_rate,
                  rho=0.95,
                  epsilon=1.0e-6,
                  momentum=0.0,
+                 centered=False,
                  **kwargs):
         super(RMSPropOptimizer, self).__init__(
             learning_rate=learning_rate, **kwargs)
@@ -950,6 +969,7 @@ class RMSPropOptimizer(Optimizer):
         self._rho = rho
         self._epsilon = epsilon
         self._momentum = momentum
+        self._centered = centered
 
     def _create_accumulators(self, block, parameters):
         if not isinstance(block, framework.Block):
@@ -958,6 +978,7 @@ class RMSPropOptimizer(Optimizer):
         for p in parameters:
             self._add_accumulator(self._momentum_acc_str, p)
             self._add_accumulator(self._mean_square_acc_str, p)
+            self._add_accumulator(self._mean_grad_acc_str, p)
 
     def _append_optimize_op(self, block, param_and_grad):
         if not isinstance(block, framework.Block):
@@ -967,6 +988,8 @@ class RMSPropOptimizer(Optimizer):
                                                param_and_grad[0])
         mean_square_acc = self._get_accumulator(self._mean_square_acc_str,
                                                 param_and_grad[0])
+        mean_grad_acc = self._get_accumulator(self._mean_grad_acc_str,
+                                              param_and_grad[0])
         rmsprop_op = block.append_op(
             type=self.type,
             inputs={
@@ -974,17 +997,20 @@ class RMSPropOptimizer(Optimizer):
                 "Grad": param_and_grad[1],
                 "Moment": momentum_acc,
                 "MeanSquare": mean_square_acc,
+                "MeanGrad": mean_grad_acc,
                 "LearningRate": self._create_param_lr(param_and_grad),
             },
             outputs={
                 "ParamOut": param_and_grad[0],
                 "MomentOut": momentum_acc,
-                "MeanSquareOut": mean_square_acc
+                "MeanSquareOut": mean_square_acc,
+                "MeanGradOut": mean_grad_acc
             },
             attrs={
                 "epsilon": self._epsilon,
                 "decay": self._rho,
-                "momentum": self._momentum
+                "momentum": self._momentum,
+                "centered": self._centered
             })
 
         return rmsprop_op
...
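For reference, a minimal usage sketch of the new centered flag through the Python API. It assumes a loss variable has already been built in the current fluid program; the network definition is omitted and the hyperparameter values are illustrative.

import paddle.fluid as fluid

# `loss` is assumed to be the mean loss of a network defined earlier
optimizer = fluid.optimizer.RMSPropOptimizer(
    learning_rate=0.01,
    rho=0.95,
    epsilon=1e-6,
    momentum=0.0,
    centered=True)  # the flag added by this commit
optimizer.minimize(loss)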
@@ -291,7 +291,7 @@ class OpTest(unittest.TestCase):
             return_numpy=False)
         return outs, fetch_list
 
-    def check_output_with_place(self, place, atol):
+    def check_output_with_place(self, place, atol, equal_nan=False):
         outs, fetch_list = self._calc_output(place)
         for out_name, out_dup in Operator.get_op_outputs(self.op_type):
             if out_name not in self.outputs:
@@ -321,7 +321,7 @@ class OpTest(unittest.TestCase):
                         if isinstance(expect, tuple) else expect
                     self.assertTrue(
                         np.allclose(
-                            actual_t, expect_t, atol=atol),
+                            actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                         "Output (" + sub_out_name + ") has diff at " +
                         str(place))
                     if isinstance(expect, tuple):
@@ -337,7 +337,7 @@ class OpTest(unittest.TestCase):
                 expect_t = expect[0] if isinstance(expect, tuple) else expect
                 self.assertTrue(
                     np.allclose(
-                        actual_t, expect_t, atol=atol),
+                        actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                     "Output (" + out_name + ") has diff at " + str(place) +
                     "\nExpect " + str(expect_t) + "\n" + "But Got" +
                     str(actual_t))
@@ -360,10 +360,10 @@ class OpTest(unittest.TestCase):
             places.append(core.CUDAPlace(0))
         return places
 
-    def check_output(self, atol=1e-5):
+    def check_output(self, atol=1e-5, equal_nan=False):
        places = self._get_places()
         for place in places:
-            self.check_output_with_place(place, atol)
+            self.check_output_with_place(place, atol, equal_nan)
 
     def check_output_customized(self, checker):
         places = self._get_places()
...
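The new equal_nan argument is passed straight through to np.allclose, which by default treats NaN as unequal to everything, including another NaN. A small standalone sketch of the difference:

import numpy as np

a = np.array([1.0, np.nan])
b = np.array([1.0, np.nan])

print(np.allclose(a, b))                  # False: NaN != NaN by default
print(np.allclose(a, b, equal_nan=True))  # True: matching NaNs compare equal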
@@ -15,90 +15,164 @@
 from __future__ import print_function
 
 import unittest
 import numpy as np
-from op_test import OpTest
-
-
-class TestRmspropOp1(OpTest):
-    ''' Test RMSProp with explicit inputs
-    '''
-
-    def setUp(self):
-        self.op_type = "rmsprop"
-
-        param = np.random.random((123, 321)).astype("float32")
-        mean_square = np.random.random((123, 321)).astype("float32")
-        learning_rate = np.array([0.01]).astype("float32")
-        grad = np.random.random((123, 321)).astype("float32")
-        moment = np.zeros((123, 321)).astype("float32")
-
-        epsilon = 1e-6
-        decay = 0.9
-        momentum = 0.0
-
-        self.inputs = {
-            'Param': param,
-            'MeanSquare': mean_square,
-            'LearningRate': learning_rate,
-            'Grad': grad,
-            'Moment': moment,
-        }
-
-        self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
-
-        ms_out = decay * mean_square + (1 - decay) * grad * grad
-        moment_out = momentum * moment + \
-            learning_rate * grad / np.sqrt(ms_out + epsilon)
-        param_out = param - moment_out
-
-        self.outputs = {
-            'ParamOut': param_out,
-            'MomentOut': moment_out,
-            'MeanSquareOut': ms_out
-        }
-
-    def test_check_output(self):
-        self.check_output()
-
-
-class TestRmspropOp2(OpTest):
-    '''Test RMSProp with default values for attributes
-    '''
-
-    def setUp(self):
-        self.op_type = "rmsprop"
-
-        param = np.random.random((123, 321)).astype("float32")
-        mean_square = np.random.random((123, 321)).astype("float32")
-        learning_rate = np.array([0.01]).astype("float32")
-        grad = np.random.random((123, 321)).astype("float32")
-        moment = np.zeros((123, 321)).astype("float32")
-
-        epsilon = 1.0e-10
-        decay = 0.9
-        momentum = 0.0
-
-        self.inputs = {
-            'Param': param,
-            'MeanSquare': mean_square,
-            'LearningRate': learning_rate,
-            'Grad': grad,
-            'Moment': moment,
-        }
-
-        ms_out = decay * mean_square + (1 - decay) * grad * grad
-        moment_out = momentum * moment + \
-            learning_rate * grad / np.sqrt(ms_out + epsilon)
-        param_out = param - moment_out
-
-        self.outputs = {
-            'ParamOut': param_out,
-            'MomentOut': moment_out,
-            'MeanSquareOut': ms_out
-        }
-
-    def test_check_output(self):
-        self.check_output()
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+
+
+class TestBase(unittest.TestCase):
+    def setup(self, centered, epsilon=1e-6):
+        np.random.seed(5)  # fix seed
+
+        self.param_name = "param"
+        self.param = np.random.random((123, 321)).astype("float32")
+
+        self.mean_square_name = "mean_square"
+        self.mean_square = np.random.random((123, 321)).astype("float32")
+
+        self.mean_grad_name = "mean_grad"
+        self.mean_grad = np.random.random((123, 321)).astype("float32")
+
+        self.lr_name = "lr"
+        self.learning_rate = np.array([0.01]).astype("float32")
+
+        self.grad_name = "grad"
+        self.grad = np.random.random((123, 321)).astype("float32")
+
+        self.moment_name = "moment"
+        self.moment = np.zeros((123, 321)).astype("float32")
+
+        self.epsilon = epsilon
+        self.decay = 0.9
+        self.momentum = 0.0
+        self.centered = centered
+
+        self.ms_out = self.decay * self.mean_square + (1 - self.decay
+                                                       ) * self.grad * self.grad
+        if centered:
+            self.mg_out = self.decay * self.mean_grad + (1 - self.decay
+                                                         ) * self.grad
+            self.moment_out = self.momentum * self.moment + \
+                self.learning_rate * self.grad / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
+        else:
+            self.moment_out = self.momentum * self.moment + \
+                self.learning_rate * self.grad / np.sqrt(self.ms_out + self.epsilon)
+
+        self.param_out = self.param - self.moment_out
+
+    def check(self,
+              actual_t,
+              expect_t,
+              place,
+              out_name,
+              atol=1e-5,
+              equal_nan=False):
+        self.assertTrue(
+            np.allclose(
+                actual_t, expect_t, atol=atol, equal_nan=equal_nan),
+            "Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+            + str(expect_t) + "\n" + "But Got" + str(actual_t))
+
+
+class TestRmspropOp(TestBase):
+    def check_with_place(self, place, centered, epsilon):
+        self.setup(centered, epsilon)
+        scope = core.Scope()
+
+        # create and initialize Param Variable
+        param = scope.var(self.param_name).get_tensor()
+        param.set(self.param, place)
+
+        mean_square = scope.var(self.mean_square_name).get_tensor()
+        mean_square.set(self.mean_square, place)
+
+        lr = scope.var(self.lr_name).get_tensor()
+        lr.set(self.learning_rate, place)
+
+        grad = scope.var(self.grad_name).get_tensor()
+        grad.set(self.grad, place)
+
+        moment = scope.var(self.moment_name).get_tensor()
+        moment.set(self.moment, place)
+
+        # create and run the rmsprop operator
+        if self.centered:
+            mean_grad = scope.var(self.mean_grad_name).get_tensor()
+            mean_grad.set(self.mean_grad, place)
+
+            rmsprop_op = Operator(
+                "rmsprop",
+                Param=self.param_name,
+                Grad=self.grad_name,
+                MeanSquare=self.mean_square_name,
+                MeanGrad=self.mean_grad_name,
+                Moment=self.moment_name,
+                LearningRate=self.lr_name,
+                ParamOut=self.param_name,
+                MeanSquareOut=self.mean_square_name,
+                MomentOut=self.moment_name,
+                MeanGradOut=self.mean_grad_name,
+                epsilon=self.epsilon,
+                decay=self.decay,
+                momentum=self.momentum,
+                centered=True)
+        else:
+            rmsprop_op = Operator(
+                "rmsprop",
+                Param=self.param_name,
+                Grad=self.grad_name,
+                MeanSquare=self.mean_square_name,
+                Moment=self.moment_name,
+                LearningRate=self.lr_name,
+                ParamOut=self.param_name,
+                MeanSquareOut=self.mean_square_name,
+                MomentOut=self.moment_name,
+                epsilon=self.epsilon,
+                decay=self.decay,
+                momentum=self.momentum,
+                centered=False)
+
+        rmsprop_op.run(scope, place)
+
+        atol = 1e-5
+        equal_nan = False
+
+        if self.centered:
+            atol = 1e-3
+            equal_nan = True
+
+        self.check(
+            np.array(mean_square), self.ms_out, place, self.mean_square_name)
+        self.check(
+            np.array(moment),
+            self.moment_out,
+            place,
+            self.moment_name,
+            atol=atol,
+            equal_nan=equal_nan)
+        self.check(
+            np.array(param),
+            self.param_out,
+            place,
+            self.param_name,
+            atol=atol,
+            equal_nan=equal_nan)
+
+        if self.centered:
+            self.check(
+                np.array(mean_grad), self.mg_out, place, self.mean_grad_name)
+
+    def test_rmsprop(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        for place in places:
+            self.check_with_place(place, False, 1e-6)
+            self.check_with_place(place, False, 1e-10)
+            self.check_with_place(place, True, 1e-6)
+            self.check_with_place(place, True, 1e-10)
+
 
 if __name__ == "__main__":
...