提交 fa12e516 编写于 作者: K Kavya Srinet

Adding the default attribute test case

上级 94855f4a
......@@ -89,13 +89,13 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1e-10);
.SetDefault(1.0e-10f);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9);
.SetDefault(0.9f);
AddAttr<float>("momentum", "(float, default 0.0) Constant value")
.SetDefault(0.0);
.SetDefault(0.0f);
AddComment(R"DOC(
RMSprop
......
......@@ -28,9 +28,9 @@ template <typename Place, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<Tensor>("ParamOut");
auto moment_out = ctx.Output<Tensor>("MomentOut");
auto mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -3,7 +3,10 @@ import numpy as np
from op_test import OpTest
class TestRmspropOp(OpTest):
class TestRmspropOp1(OpTest):
''' Test RMSProp with explicit inputs
'''
def setUp(self):
self.op_type = "rmsprop"
......@@ -42,5 +45,45 @@ class TestRmspropOp(OpTest):
self.check_output()
class TestRmspropOp2(OpTest):
'''Test RMSProp with defaukt values for attributes
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1.0e-10
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册