# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
import unittest

import numpy as np
import paddle
from paddle.incubate.autograd import primapi

import config
import utils


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
    (('matmul', paddle.matmul,
      (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
     ('multiply', paddle.multiply,
      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
     ('add', paddle.add,
      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
     ('input_not_sequence', paddle.tanh,
      (np.random.rand(5, 5), ), None, 'float64'),
     ('input_gradients_not_none', paddle.matmul,
      (np.random.rand(3, 3), np.random.rand(3, 3)),
      (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64')))
class TestForwardGrad(unittest.TestCase):
    """Compare ``primapi.forward_grad`` against ``autograd.jvp``.

    Each case provides ``fun``, its numpy inputs ``xs``, an optional tangent
    ``v`` (``None`` in most cases) and a ``dtype`` string. These are read
    here as class attributes; presumably ``utils.parameterize`` generates
    one class per case tuple (TODO confirm against ``utils``).
    """

    @classmethod
    def setUpClass(cls):
        # Cast the inputs once to the case dtype.
        cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
        # Tolerances are looked up by dtype string in config.TOLERANCE.
        first_order = config.TOLERANCE.get(str(cls.dtype)).get(
            "first_order_grad")
        cls._rtol = first_order.get("rtol")
        cls._atol = first_order.get("atol")

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_forward_grad(self):
        """forward_grad under prim must match jvp computed with prim off."""

        def expected():
            # Reference result: jvp on the original operators, prim disabled.
            paddle.incubate.autograd.disable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                _, ys_grad = paddle.incubate.autograd.jvp(
                    self.fun, static_xs, static_v)
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.enable_prim()
            return out

        def actual():
            # Result under test: forward_grad built with prim enabled, then
            # lowered back to original ops via prim2orig before execution.
            paddle.incubate.autograd.enable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                # A single (non-sequence) input is passed positionally as-is.
                ys = self.fun(*static_xs) if isinstance(
                    static_xs, typing.Sequence) else self.fun(static_xs)
                ys_grad = paddle.incubate.autograd.forward_grad(
                    ys, static_xs, static_v)
                paddle.incubate.autograd.prim2orig(mp.block(0))
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.disable_prim()
            return out

        # Distinct names so the helper functions are not shadowed by results.
        actual_out = actual()
        expected_out = expected()
        self.assertEqual(type(actual_out), type(expected_out))
        np.testing.assert_allclose(np.concatenate(actual_out),
                                   np.concatenate(expected_out),
                                   rtol=self._rtol,
                                   atol=self._atol)

    def test_prim_disabled(self):
        """forward_grad must raise RuntimeError when prim is disabled."""
        paddle.incubate.autograd.disable_prim()
        sp = paddle.static.Program()
        mp = paddle.static.Program()
        with self.assertRaises(RuntimeError):
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                ys = self.fun(*static_xs) if isinstance(
                    static_xs, typing.Sequence) else self.fun(static_xs)
                ys_grad = primapi.forward_grad(ys, static_xs, static_v)
                paddle.incubate.autograd.prim2orig(mp.block(0))
            exe = paddle.static.Executor()
            exe.run(sp)
            exe.run(mp, feed=feed, fetch_list=ys_grad)
        paddle.incubate.autograd.enable_prim()

    def test_illegal_param(self):
        """Non-variable ``targets`` or ``inputs`` must raise TypeError."""
        paddle.incubate.autograd.enable_prim()
        # Create the placeholder variables in a throwaway program so they do
        # not accumulate in the global default main program across cases.
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            with self.assertRaises(TypeError):
                primapi.forward_grad(1, paddle.static.data('inputs',
                                                           shape=[1]))

            with self.assertRaises(TypeError):
                primapi.forward_grad(paddle.static.data('targets', shape=[1]),
                                     1)
        paddle.incubate.autograd.disable_prim()


class TestGrad(unittest.TestCase):
    """Higher-order ``paddle.incubate.autograd.grad`` tests in static mode."""

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    @staticmethod
    def _place():
        """Return CUDAPlace(0) if Paddle is compiled with CUDA, else CPU."""
        if paddle.device.is_compiled_with_cuda():
            return paddle.CUDAPlace(0)
        return paddle.CPUPlace()

    def test_third_order(self):
        """Third derivative of x**4 at x=2 is 24*x = 48."""
        paddle.incubate.autograd.enable_prim()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data(name='x', shape=[1], dtype='float32')
            x2 = paddle.multiply(x, x)
            x3 = paddle.multiply(x2, x)
            x4 = paddle.multiply(x3, x)

            grad1, = paddle.incubate.autograd.grad([x4], [x])
            grad2, = paddle.incubate.autograd.grad([grad1], [x])
            grad3, = paddle.incubate.autograd.grad([grad2], [x])

            paddle.incubate.autograd.prim2orig(main.block(0))

        feed = {x.name: np.array([2.]).astype('float32')}
        fetch_list = [grad3.name]
        result = [np.array([48.])]

        exe = paddle.static.Executor(self._place())
        exe.run(startup)
        outs = exe.run(main, feed=feed, fetch_list=fetch_list)
        # The original used a bare np.allclose whose result was discarded,
        # so the check never failed; assert it for real.
        np.testing.assert_allclose(outs, result, rtol=1e-5)
        paddle.incubate.autograd.disable_prim()

    def test_fourth_order(self):
        """Fourth derivative of sqrt(x**5 + x**4) at x=2."""
        paddle.incubate.autograd.enable_prim()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data(name='x', shape=[1], dtype='float32')
            x2 = paddle.multiply(x, x)
            x3 = paddle.multiply(x2, x)
            x4 = paddle.multiply(x3, x)
            x5 = paddle.multiply(x4, x)
            out = paddle.sqrt(x5 + x4)

            grad1, = paddle.incubate.autograd.grad([out], [x])
            grad2, = paddle.incubate.autograd.grad([grad1], [x])
            grad3, = paddle.incubate.autograd.grad([grad2], [x])
            grad4, = paddle.incubate.autograd.grad([grad3], [x])

            paddle.incubate.autograd.prim2orig(main.block(0))

        feed = {
            x.name: np.array([2.]).astype('float32'),
        }
        fetch_list = [grad4.name]
        # (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5)
        result = [np.array([-0.27263762711])]

        exe = paddle.static.Executor(self._place())
        exe.run(startup)
        outs = exe.run(main, feed=feed, fetch_list=fetch_list)
        # Assert (the original discarded np.allclose's boolean result).
        np.testing.assert_allclose(outs, result, rtol=1e-5)
        paddle.incubate.autograd.disable_prim()

    def test_disable_prim(self):
        """With prim disabled, autograd.grad must match static.gradients."""

        def actual(x: np.ndarray):
            # First and second order grads through autograd.grad, prim off.
            paddle.incubate.autograd.disable_prim()
            main = paddle.static.Program()
            startup = paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                var_x.stop_gradient = False
                y = paddle.tanh(var_x)
                y_grad = paddle.incubate.autograd.grad(y, var_x)
                y_second_grad = paddle.incubate.autograd.grad(y_grad, var_x)
            exe = paddle.static.Executor()
            exe.run(startup)
            return exe.run(main,
                           feed={'x': x},
                           fetch_list=[y_grad, y_second_grad])

        def expect(x: np.ndarray):
            # Same computation using paddle.static.gradients as reference.
            paddle.incubate.autograd.disable_prim()
            main = paddle.static.Program()
            startup = paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                var_x.stop_gradient = False
                y = paddle.tanh(var_x)
                y_grad = paddle.static.gradients(y, var_x)
                y_second_grad = paddle.static.gradients(y_grad, var_x)
            exe = paddle.static.Executor()
            exe.run(startup)
            return exe.run(main,
                           feed={'x': x},
                           fetch_list=[y_grad, y_second_grad])

        x = np.random.randn(100, 200)
        for i, j in zip(actual(x), expect(x)):
            np.testing.assert_allclose(i, j)


# Run the test cases when this module is executed directly.
if __name__ == '__main__':
    unittest.main()