#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator
from op_test import OpTest


class TestSGDOp(OpTest):
    """Dense SGD check: expects ParamOut == Param - LearningRate * Grad."""

    def setUp(self):
        self.op_type = "sgd"
        # Random dense parameter and gradient of identical shape.
        param = np.random.random((102, 105)).astype("float32")
        grad = np.random.random((102, 105)).astype("float32")
        learning_rate = np.array([0.1]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'LearningRate': learning_rate
        }
        # Reference output computed directly with numpy broadcasting.
        self.outputs = {'ParamOut': param - learning_rate * grad}

    def test_check_output(self):
        self.check_output()

class TestSparseSGDOp(unittest.TestCase):
    """SGD check for a sparse (SelectedRows) gradient: only the rows
    present in the gradient are updated; all other rows stay unchanged."""

    def check_with_place(self, place):
        """Run the sgd op on `place` and verify the in-place update."""
        scope = core.Scope()

        # Build a SelectedRows gradient covering rows 0, 4 and 7 of a
        # height-10 parameter; each row holds 12 elements.
        height = 10
        rows = [0, 4, 7]
        row_numel = 12

        grad_var = scope.var('Grad').get_selected_rows()
        grad_var.set_height(height)
        grad_var.set_rows(rows)
        grad_np = np.ones((len(rows), row_numel)).astype("float32")
        # Two entries differ from 1.0 so the update is row/column specific.
        grad_np[0, 0] = 2.0
        grad_np[2, 8] = 4.0

        grad_var.get_tensor().set(grad_np, place)

        # Dense parameter tensor initialized to a constant 5.0.
        param_tensor = scope.var('Param').get_tensor()
        param_np = np.full((height, row_numel), 5.0).astype("float32")
        param_tensor.set(param_np, place)

        # Scalar learning rate of 2.0.
        lr_tensor = scope.var('LearningRate').get_tensor()
        lr_np = np.full((1), 2.0).astype("float32")
        lr_tensor.set(lr_np, place)

        # Run the sgd operator; ParamOut aliases Param (in-place update).
        sgd_op = Operator(
            "sgd",
            Param='Param',
            Grad='Grad',
            ParamOut='Param',
            LearningRate='LearningRate')
        sgd_op.run(scope, place)

        # Pull the updated parameter back as a numpy array and compare.
        result = np.array(param_tensor)

        # rows[0] = 0: 5.0 - 2.0 * 2.0
        self.assertAlmostEqual(1.0, result[rows[0], 0])
        # rows[0] = 0: 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result[rows[0], 2])
        # row 1 not in the gradient: 5.0 - 2.0 * 0.0
        self.assertAlmostEqual(5.0, result[1, 0])
        # rows[1] = 4: 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result[rows[1], 10])
        # row 5 not in the gradient: 5.0 - 2.0 * 0.0
        self.assertAlmostEqual(5.0, result[5, 8])
        # rows[2] = 7: 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result[rows[2], 1])
        # rows[2] = 7: 5.0 - 2.0 * 4.0
        self.assertAlmostEqual(-3.0, result[rows[2], 8])

    def test_sparse_sgd(self):
        """Exercise check_with_place on CPU and, when built with GPU
        support, on the first CUDA device as well."""
        places = [core.CPUPlace()]
        if core.is_compile_gpu():
            places.append(core.CUDAPlace(0))
        for place in places:
            self.check_with_place(place)

# Run every test case in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()