#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest


class TestSGDOp(OpTest):
    """Dense-gradient test for the ``sgd`` operator.

    Feeds random ``Param``/``Grad`` tensors and a scalar learning rate and
    checks ``ParamOut`` against the NumPy reference ``Param - lr * Grad``.
    """

    def setUp(self):
        self.op_type = "sgd"
        w = np.random.random((102, 105)).astype("float32")
        g = np.random.random((102, 105)).astype("float32")
        lr = np.array([0.1]).astype("float32")

        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
        # Plain SGD step; `lr` has shape (1,) and broadcasts over w and g.
        self.outputs = {'ParamOut': w - lr * g}

    def test_check_output(self):
        self.check_output()


class TestSparseSGDOp(unittest.TestCase):
    """Sparse-gradient test for the ``sgd`` operator.

    ``Grad`` is a SelectedRows holding values only for ``rows``; the operator
    is run with ``ParamOut`` bound to the same variable as ``Param``, and the
    updated parameter is compared against a NumPy reference.
    """

    def check_with_place(self, place):
        """Run the sparse sgd update on ``place`` and verify the result."""
        scope = core.Scope()

        # create and initialize Grad Variable
        height = 10
        rows = [0, 4, 7]
        row_numel = 12

        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(height)
        grad_selected_rows.set_rows(rows)
        np_array = np.ones((len(rows), row_numel)).astype("float32")
        np_array[0, 0] = 2.0
        np_array[2, 8] = 4.0

        grad_tensor = grad_selected_rows.get_tensor()
        grad_tensor.set(np_array, place)

        # create and initialize Param Variable
        param = scope.var('Param').get_tensor()
        param_array = np.full((height, row_numel), 5.0).astype("float32")
        param.set(param_array, place)

        # create and initialize LearningRate Variable
        lr = scope.var('LearningRate').get_tensor()
        lr_array = np.full((1,), 2.0).astype("float32")
        lr.set(lr_array, place)

        # create and run sgd operator; ParamOut names the same variable as
        # Param, so the result is read back from `param` below.
        sgd_op = Operator(
            "sgd",
            Param='Param',
            Grad='Grad',
            ParamOut='Param',
            LearningRate='LearningRate')
        sgd_op.run(scope, place)

        # get and compare result
        result_array = np.array(param)

        # Reference: each selected row rows[i] becomes
        # param - lr * np_array[i]; every other row keeps its initial 5.0.
        # Comparing the whole tensor also catches writes to unselected rows
        # or wrong columns, which element spot checks alone would miss.
        # (`rows` has no duplicates, so fancy-index -= is well defined.)
        expected = param_array.copy()
        expected[rows] -= lr_array[0] * np_array
        np.testing.assert_allclose(result_array, expected)

        # Worked examples of the reference above.
        # rows[0] = 0, 5.0 - 2.0 * 2.0
        self.assertAlmostEqual(1.0, result_array[rows[0], 0])
        # row 1 is not selected, 5.0 - 2.0 * 0.0
        self.assertAlmostEqual(5.0, result_array[1, 0])
        # rows[2] = 7, 5.0 - 2.0 * 4.0
        self.assertAlmostEqual(-3.0, result_array[rows[2], 8])

    def test_sparse_sgd(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        for place in places:
            self.check_with_place(place)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")