# test_rmsprop_op.py -- unit tests for the "rmsprop" operator.
import unittest
import numpy as np
from op_test import OpTest


class TestRmspropOp1(OpTest):
    """Test RMSProp with explicit inputs.

    All three attributes (``epsilon``, ``decay``, ``momentum``) are passed to
    the operator explicitly through ``self.attrs``.
    """

    def setUp(self):
        """Build random inputs and the NumPy reference outputs for one step."""
        self.op_type = "rmsprop"

        # Input tensors: everything is float32; the learning rate is a
        # one-element array and the momentum buffer starts at zero.
        param = np.random.random((123, 321)).astype("float32")
        mean_square = np.random.random((123, 321)).astype("float32")
        learning_rate = np.array([0.01]).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        moment = np.zeros((123, 321)).astype("float32")

        epsilon = 1e-6
        decay = 0.9
        momentum = 0.0

        self.inputs = {
            'Param': param,
            'MeanSquare': mean_square,
            'LearningRate': learning_rate,
            'Grad': grad,
            'Moment': moment,
        }

        self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}

        # Reference computation of a single update step:
        #   ms_out     = decay * mean_square + (1 - decay) * grad^2
        #   moment_out = momentum * moment + lr * grad / sqrt(ms_out + epsilon)
        #   param_out  = param - moment_out
        ms_out = decay * mean_square + (1 - decay) * grad * grad
        moment_out = momentum * moment + \
            learning_rate * grad / np.sqrt(ms_out + epsilon)
        param_out = param - moment_out

        self.outputs = {
            'ParamOut': param_out,
            'MomentOut': moment_out,
            'MeanSquareOut': ms_out
        }

    def test_check_output(self):
        """Run the output check (``check_output`` is not defined in this
        class, so it comes from the ``OpTest`` base)."""
        self.check_output()


class TestRmspropOp2(OpTest):
    """Test RMSProp with default values for attributes.

    No ``self.attrs`` is set here, so the operator runs with whatever
    attribute values it falls back to on its own.
    """

    def setUp(self):
        """Build random inputs and the NumPy reference outputs for one step."""
        self.op_type = "rmsprop"

        # Input tensors: everything is float32; the learning rate is a
        # one-element array and the momentum buffer starts at zero.
        param = np.random.random((123, 321)).astype("float32")
        mean_square = np.random.random((123, 321)).astype("float32")
        learning_rate = np.array([0.01]).astype("float32")
        grad = np.random.random((123, 321)).astype("float32")
        moment = np.zeros((123, 321)).astype("float32")

        # Used only for the reference computation below; they are never
        # handed to the operator.
        # NOTE(review): presumably these mirror the operator's built-in
        # defaults -- confirm against the rmsprop op definition.
        epsilon = 1.0e-10
        decay = 0.9
        momentum = 0.0

        self.inputs = {
            'Param': param,
            'MeanSquare': mean_square,
            'LearningRate': learning_rate,
            'Grad': grad,
            'Moment': moment,
        }

        # Reference computation of a single update step:
        #   ms_out     = decay * mean_square + (1 - decay) * grad^2
        #   moment_out = momentum * moment + lr * grad / sqrt(ms_out + epsilon)
        #   param_out  = param - moment_out
        ms_out = decay * mean_square + (1 - decay) * grad * grad
        moment_out = momentum * moment + \
            learning_rate * grad / np.sqrt(ms_out + epsilon)
        param_out = param - moment_out

        self.outputs = {
            'ParamOut': param_out,
            'MomentOut': moment_out,
            'MeanSquareOut': ms_out
        }

    def test_check_output(self):
        """Run the output check (``check_output`` is not defined in this
        class, so it comes from the ``OpTest`` base)."""
        self.check_output()


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()