# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

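# These tests compare the functional autograd APIs (Jacobian, Hessian, jvp and
# vjp) computed in static-graph mode with primitive operators enabled against
# the results produced by the original operators.
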
import unittest

import config
import numpy as np
import utils

import paddle


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestJacobianPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
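        # Per-dtype tolerances for first-order gradients come from the shared
        # config module.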
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_jacobian_prim(self):
        def wrapper(fun, args):
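            # Build a static program that materializes the full Jacobian of
            # `fun`, lower primitive operators back to original operators with
            # prim2orig when prim mode is on, then run it on an Executor.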
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:]
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [jac] = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[jac],
            )
            return jac

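        # Evaluate the Jacobian once with primitive operators and once with
        # the original operators, then check both results agree within the
        # configured tolerances.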
        paddle.incubate.autograd.enable_prim()
        prim_jac = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_jac = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_jac, prim_jac, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(1),), 'float32'),
        (
            'binary_float32',
            paddle.multiply,
            (np.random.rand(1), np.random.rand(1)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(1),), 'float64'),
        (
            'binary_float64',
            paddle.multiply,
            (np.random.rand(1), np.random.rand(1)),
            'float64',
        ),
    ),
)
class TestHessianPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('second_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('second_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_hessian_prim(self):
        def wrapper(fun, args):
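            # Same flow as the Jacobian test, but materializes the full
            # Hessian of `fun` before lowering primitive operators with
            # prim2orig.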
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:]
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [hessian] = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[hessian],
            )
            return hessian

        paddle.incubate.autograd.enable_prim()
        prim_hessian = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_hessian = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_hessian, prim_hessian, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestJvpPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_jvp_prim(self):
        def wrapper(fun, args):
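            # Build a static program for the Jacobian-vector product of `fun`
            # (jvp with its default tangents), lower primitive operators with
            # prim2orig when prim mode is on, then run it on an Executor.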
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                _, jvp_res = paddle.incubate.autograd.jvp(fun, static_args)
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            jvp_res = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[jvp_res],
            )
            return jvp_res

        paddle.incubate.autograd.enable_prim()
        prim_jvp = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_jvp = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_jvp, prim_jvp, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestVjpPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_vjp_prim(self):
        def wrapper(fun, args):
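            # Build a static program for the vector-Jacobian product of `fun`
            # (vjp with its default cotangents), lower primitive operators
            # with prim2orig when prim mode is on, then run it on an Executor.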
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                _, vjp_res = paddle.incubate.autograd.vjp(fun, static_args)
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            vjp_res = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[vjp_res],
            )
            return vjp_res

        paddle.incubate.autograd.enable_prim()
        prim_vjp = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_vjp = wrapper(self.fun, self.args)

        for orig, prim in zip(orig_vjp, prim_vjp):
            np.testing.assert_allclose(
                orig, prim, rtol=self._rtol, atol=self._atol
            )


if __name__ == "__main__":
    unittest.main()