# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

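"""Check that prim custom-VJP rules take effect during decomposition: with
prim backward enabled, gelu's backward is lowered to primitive ops (no
pd.gelu_grad op), while the forward pd.gelu is decomposed only once it is
explicitly whitelisted."""
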
import unittest

import numpy as np

import paddle
from paddle import _ir_ops, nn
from paddle.autograd.ir_backward import grad
from paddle.decomposition import decompose
from paddle.framework import core

paddle.enable_static()


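# A small matmul -> gelu -> matmul network built from low-level _ir_ops so
# the program contains a pd.gelu op for the decomposition checks below.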
class SimpNet(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x, linear1_weight, linear2_weight):
        x2 = _ir_ops.matmul(x, linear1_weight, False, False)
        x3 = _ir_ops.gelu(x2, False)
        res = _ir_ops.matmul(x3, linear2_weight, False, False)
        return res


class TestPrimMode(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [2, 1024, 1024]
        self.shape_y = [2, 1024, 1024]
        self.shape_l1_w = [2, 1024, 4096]
        self.shape_l2_w = [2, 4096, 1024]
        self.x = np.random.random(self.shape_x).astype("float32")
        self.y = np.random.random(self.shape_y).astype("float32")
        self.l1_w = np.random.random(self.shape_l1_w).astype("float32")
        self.l2_w = np.random.random(self.shape_l2_w).astype("float32")

    def base_net(self, flag=None):
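        """Build and run the network, optionally with prim backward enabled.

        When flag == "backward", gradients are built with the composite
        (prim) backward rules, and the program's op set is checked before
        and after decomposing the forward gelu.
        """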
        if flag == "backward":
            core._set_prim_backward_enabled(True)
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            net = SimpNet()
            x = paddle.static.data('x', self.shape_x, dtype='float32')
            y = paddle.static.data('y', self.shape_y, dtype='float32')
            x.stop_gradient = False
            y.stop_gradient = False
            l1_w = paddle.static.data('l1_w', self.shape_l1_w, dtype='float32')
            l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
            divide_out = paddle.divide(x, y)
            res = net(divide_out, l1_w, l2_w)
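            # Lower ops with registered decomposition rules into primitive
            # ops (a no-op here unless the corresponding prim flag is on).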
            [res2] = decompose(
                main_program,
                [res],
            )
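            # Build the backward graph; with prim backward enabled, gelu's
            # custom VJP emits primitive ops instead of a pd.gelu_grad op.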
            gradients = grad(res2, (x, y))

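            # The backward pass was decomposed: the forward gelu is still a
            # single op, but no pd.gelu_grad op was emitted.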
            if flag == "backward":
                whole_ops_before = [
                    op.name() for op in main_program.block().ops
                ]
                self.assertIn("pd.gelu", whole_ops_before)
                self.assertNotIn("pd.gelu_grad", whole_ops_before)
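                # Decompose the remaining forward gelu by whitelisting it.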
                core._set_prim_forward_enabled(True)
                [res2] = decompose(main_program, [res2], whitelist={"pd.gelu"})
                whole_ops_after = [op.name() for op in main_program.block().ops]
                assert "pd.gelu" not in whole_ops_after
                core._set_prim_forward_enabled(False)

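            # Execute the program and fetch the forward result together
            # with the gradients w.r.t. x and y.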
            exe = paddle.static.Executor()
            outs = exe.run(
                feed={
                    'x': self.x,
                    'y': self.y,
                    'l1_w': self.l1_w,
                    'l2_w': self.l2_w,
                },
                fetch_list=[res2, gradients[0], gradients[1]],
            )

        if flag == "backward":
            core._set_prim_backward_enabled(False)
        return outs

    def test_prim_custom_vjp(self):
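        # The decomposed run must reproduce the baseline outputs and grads.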
        res_ref = self.base_net()
        res = self.base_net("backward")
        for ref, actual in zip(res_ref, res):
            np.testing.assert_allclose(ref, actual, rtol=1e-6)


if __name__ == "__main__":
    unittest.main()