# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def apply_to_static(net, use_cinn):
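    # Convert the dygraph net to static graph with to_static; build_cinn_pass
    # controls whether the resulting program is compiled by the CINN backend.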
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        y = self.fc(x)
        out = F.softmax(y)
        return out


class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn. Thus we need to
    set this flag as False to avoid prim_backward.
    core.set_prim_backward(False)
    """

    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
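        # Enable prim only for the forward pass (see the class docstring);
        # the backward pass keeps the original gradient ops.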
        core._set_prim_forward_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
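        # Op types of the static program that to_static built for this input.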
        fwd_ops = [
            op.type
            for op in net.forward.get_concrete_program(self.x)[1]
            .train_program.block(0)
            .ops
        ]
        # Ensure that softmax is split into small ops
        self.assertTrue('softmax' not in fwd_ops)

    def test_cinn_prim_forward(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)

        for i in range(len(dy_res)):
            np.testing.assert_allclose(
                cinn_res[i], dy_res[i], rtol=1e-7, atol=1e-7
            )


class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn
    vs. dygraph.
    """

    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
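        # Enable prim for both the forward and the backward pass.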
        core._set_prim_all_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [
            op.type
            for op in net.forward.get_concrete_program(self.x)[1]
            .train_program.block(0)
            .ops
        ]
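        # Op types of the whole cached program, including the grad ops.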
        all_ops = [
            op.type
            for op in net.forward.program_cache.last()[-1][-1]
            .train_program.block(0)
            .ops
        ]
        # Ensure that softmax is split into small ops
        self.assertTrue('softmax' not in fwd_ops)
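        # With prim backward enabled, every grad op except matmul_v2_grad
        # should have been decomposed into primitive ops.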
        for op in all_ops:
            if op != "matmul_v2_grad":
                self.assertTrue("_grad" not in op)

    def test_cinn_prim(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)

        for i in range(len(dy_res)):
            np.testing.assert_allclose(
                cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
            )


if __name__ == '__main__':
    unittest.main()