# test_sgd_op.py (472 bytes)
# Blame: committed by Qiao Longfei; the two random-data lines in setUp
# (original lines 11-12) were committed by qijun.
import unittest
import numpy
from op_test_util import OpTestMeta


class TestSGD(unittest.TestCase):
    """Check the "sgd" operator against a NumPy reference update.

    The expected output is one plain SGD step,
    ``param_out = param - learning_rate * grad``, computed on random
    float32 data of shape (102, 105).
    """

    # NOTE(review): the ``__metaclass__`` class attribute is only honored
    # by Python 2. Under Python 3 it is an ordinary attribute and
    # OpTestMeta is NOT applied. Confirm which interpreter runs this test.
    __metaclass__ = OpTestMeta

    def setUp(self):
        """Set the operator type, its inputs/attribute and expected output.

        Presumably OpTestMeta reads these instance attributes to build and
        run the operator and compare its result with ``param_out``. TODO:
        confirm against op_test_util.
        """
        self.type = "sgd"
        self.param = numpy.random.random((102, 105)).astype("float32")
        self.grad = numpy.random.random((102, 105)).astype("float32")
        self.learning_rate = 0.1
        # Reference result: a single vanilla SGD step.
        self.param_out = self.param - self.learning_rate * self.grad


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when the file
    # is executed directly (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")