# test_sgd_op.py 472 bytes
# Newer Older
# Q
# Qiao Longfei committed
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import unittest
import numpy
from op_test_util import OpTestMeta


class TestSGD(unittest.TestCase):
    """Declarative test case for the "sgd" operator.

    ``setUp`` only records the operator type, its inputs and the expected
    output on ``self``; no ``test_*`` method is defined here, so presumably
    ``OpTestMeta`` (from ``op_test_util``) builds the actual check from these
    attributes -- confirm against op_test_util.

    NOTE(review): a ``__metaclass__`` class attribute is honored only by
    Python 2. Python 3 ignores it, in which case ``OpTestMeta`` would never
    run for this class.
    """

    __metaclass__ = OpTestMeta

    def setUp(self):
        # The parameter and its gradient share one (rows, cols) shape.
        shape = (342, 345)
        lr = 0.1
        draw = numpy.random.random

        # Draw order matters for the RNG stream: parameter first, then gradient.
        weights = draw(shape).astype("float32")
        gradient = draw(shape).astype("float32")

        self.type = "sgd"
        self.param = weights
        self.grad = gradient
        self.learning_rate = lr
        # NumPy reference for one plain SGD step.
        self.param_out = weights - lr * gradient


if __name__ == "__main__":
    # Run the tests only when this file is executed directly, not when it is
    # imported (e.g. by a test collector). "__main__" is unittest.main's
    # default module; it is spelled out here only to make that explicit.
    unittest.main(module="__main__")