# test_ir_vjp.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import ir
from paddle.fluid.core import call_vjp, has_vjp

# Switch Paddle to static-graph mode: every test below builds
# paddle.static.Program objects and translates them to the new IR.
paddle.enable_static()


def get_ir_program():
    """Build a small static program and return its new-IR translation.

    The program reads a ``[4, 4]`` float32 input ``x`` (with gradients
    enabled), applies ``paddle.tanh`` to it, and then creates a ``[4, 4]``
    constant filled with ``2.0``. The ``tanh`` op and the constant op are
    therefore the last two ops of the translated block.

    Returns:
        The program produced by ``ir.translate_to_new_ir``.
    """
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        inp = paddle.static.data('x', [4, 4], 'float32')
        inp.stop_gradient = False
        paddle.tanh(inp)
        paddle.tensor.fill_constant(shape=[4, 4], dtype='float32', value=2.0)
    return ir.translate_to_new_ir(main.desc)


class TestTanhVjp(unittest.TestCase):
    """Tests for the VJP rule of ``pd.tanh`` in the new IR.

    The program from ``get_ir_program`` ends with a ``pd.tanh`` op followed by
    a ``pd.full`` op. The ``pd.full`` result is used as the incoming output
    gradient of the tanh op.
    """

    @staticmethod
    def _build():
        """Return ``(program, tanh_op, out_grads)`` for a freshly built program."""
        newir_program = get_ir_program()
        ops = newir_program.block().ops
        tanh_op = ops[-2]
        fill_constant_op = ops[-1]
        # call_vjp takes gradients as a nested list: one inner list per output.
        out_grads = [[fill_constant_op.result(0)]]
        return newir_program, tanh_op, out_grads

    def test_tanh_vjp1(self):
        newir_program, tanh_op, out_grads = self._build()
        stop_gradients = [[False]]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)

        grad_op = grad_outs[0][0].get_defining_op()
        self.assertEqual(grad_op.name(), "pd.tanh_grad")
        # The grad op's first input must come from the forward pd.tanh op and
        # its second input from the pd.full op that supplied the out-grad.
        self.assertEqual(
            grad_op.operands()[0].source().get_defining_op().name(),
            "pd.tanh",
        )
        self.assertEqual(
            grad_op.operands()[1].source().get_defining_op().name(),
            "pd.full",
        )
        # Total op count once the grad op has been inserted into the block.
        self.assertEqual(len(newir_program.block().ops), 4)

    def test_tanh_vjp2(self):
        newir_program, tanh_op, out_grads = self._build()
        stop_gradients = [[True]]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
        # With the input marked as stop-gradient, no gradient is produced.
        self.assertIsNone(grad_outs[0][0])


class TestMeanVjp(unittest.TestCase):
    """Tests for the VJP rule of ``pd.mean`` in the new IR."""

    @staticmethod
    def _build():
        """Build ``mean(x)`` followed by a constant op and translate to new IR.

        Returns:
            ``(program, mean_op, out_grads)``. Here ``out_grads`` wraps the
            result of the trailing ``pd.full`` op, for use as the output
            gradient of the mean op.
        """
        main_program, start_program = (
            paddle.static.Program(),
            paddle.static.Program(),
        )
        with paddle.static.program_guard(main_program, start_program):
            x = paddle.static.data('x', [4, 4], 'float32')
            x.stop_gradient = False
            paddle.mean(x, axis=[0, 1])
            paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0)
        newir_program = ir.translate_to_new_ir(main_program.desc)

        ops = newir_program.block().ops
        mean_op = ops[-2]
        fill_constant_op = ops[-1]
        # call_vjp takes gradients as a nested list: one inner list per output.
        out_grads = [[fill_constant_op.result(0)]]
        return newir_program, mean_op, out_grads

    def test_mean_vjp1(self):
        newir_program, mean_op, out_grads = self._build()
        stop_gradients = [[False]]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(mean_op, out_grads, stop_gradients)

        grad_op = grad_outs[0][0].get_defining_op()
        self.assertEqual(grad_op.name(), "pd.mean_grad")
        # First input is the translated data input, second is the out-grad
        # produced by the pd.full op.
        self.assertEqual(
            grad_op.operands()[0].source().get_defining_op().name(),
            "builtin.get_parameter",
        )
        self.assertEqual(
            grad_op.operands()[1].source().get_defining_op().name(),
            "pd.full",
        )
        # Total op count once the grad op has been inserted into the block.
        self.assertEqual(len(newir_program.block().ops), 4)

    def test_mean_vjp2(self):
        newir_program, mean_op, out_grads = self._build()
        stop_gradients = [[True]]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
        # With the input marked as stop-gradient, no gradient is produced.
        self.assertIsNone(grad_outs[0][0])


class TesthasVjp(unittest.TestCase):
    """Check ``has_vjp`` on ops with and without a registered VJP rule."""

    def test_has_vjp(self):
        main_program, start_program = (
            paddle.static.Program(),
            paddle.static.Program(),
        )
        with paddle.static.program_guard(main_program, start_program):
            x = paddle.static.data('x', [4, 4], 'float32')
            x.stop_gradient = False
            paddle.mean(x, axis=[0, 1])
            paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0)
        newir_program = ir.translate_to_new_ir(main_program.desc)

        ops = newir_program.block().ops
        mean_op = ops[-2]
        fill_constant_op = ops[-1]

        # The constant-fill op is expected to have no VJP; pd.mean should.
        self.assertFalse(has_vjp(fill_constant_op))
        self.assertTrue(has_vjp(mean_op))


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()