# test_sgd_op.py (509 bytes)
# Blame: lines 1-10 committed by Qiao Longfei.
import unittest
import numpy
from op_test_util import OpTestMeta


class TestSGD(unittest.TestCase):
    """Check the ``sgd`` operator against a NumPy reference update.

    The expected output is one plain SGD step:
    ``param_out = param - learning_rate * grad``.
    """

    # NOTE(review): a ``__metaclass__`` class attribute is only honoured by
    # Python 2. Under Python 3 it is an ordinary attribute and OpTestMeta
    # would NOT be applied (that needs ``class TestSGD(..., metaclass=...)``).
    # Kept as-is to stay compatible with the interpreter this file targets.
    __metaclass__ = OpTestMeta

    def setUp(self):
        """Build random param/grad inputs and the expected updated param.

        Sets ``type``, ``inputs``, ``attrs`` and ``outputs`` on the instance.
        Presumably OpTestMeta reads these to run the operator and compare
        its result with ``outputs``; confirm in op_test_util.
        """
        self.type = "sgd"

        # Param and grad share one shape, so the reference update below
        # is purely elementwise.
        shape = (102, 105)
        w = numpy.random.random(shape).astype("float32")
        g = numpy.random.random(shape).astype("float32")
        lr = 0.1

        self.inputs = {'param': w, 'grad': g}
        self.attrs = {'learning_rate': lr}
        # Reference result: a single vanilla SGD step.
        self.outputs = {'param_out': w - lr * g}
# Blame: lines 18-21 committed by Qiao Longfei.


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when the file
    # is executed directly; "__main__" is unittest.main's default module.
    unittest.main(module="__main__")