# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import numpy
import unittest
import tempfile
import warnings

import paddle
import paddle.nn.functional as F
from paddle.fluid.framework import _enable_legacy_dygraph

_enable_legacy_dygraph()


class SimpleNet(paddle.nn.Layer):
    """Small classifier: Conv2D -> BatchNorm -> ReLU -> AvgPool2D -> Flatten -> Linear.

    ``forward`` returns the intermediate conv output together with the final
    class scores, so callers can check the shape/layout of the conv result.

    Args:
        data_format: Accepted for interface compatibility but currently
            unused; every layer is constructed with its default layout.
        class_num: Number of output classes of the final Linear layer.
    """

    def __init__(self, data_format="NCHW", class_num=2):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, (3, 3))
        self.bn = paddle.nn.BatchNorm(num_channels=8)
        self.relu = paddle.nn.ReLU()
        self.pool = paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        self.flatten = paddle.nn.Flatten()
        # 392 = 8 channels * 7 * 7: a 16x16 input becomes 14x14 after the
        # unpadded 3x3 conv and 7x7 after the stride-2 pool.
        self.fc = paddle.nn.Linear(392, class_num)

    def forward(self, image):
        features = self.conv(image)
        hidden = self.bn(features)
        for stage in (self.relu, self.pool, self.flatten, self.fc):
            hidden = stage(hidden)
        return features, hidden


class LayoutAutoTune(unittest.TestCase):
    """Shape checks for layout autotuning under AMP O2.

    When layout autotuning is active, conv outputs are expected in NHWC;
    otherwise they stay in NCHW.
    """

    def use_autotune(self):
        """Apply a layout-autotune config and report whether it is enabled.

        With CUDA, layout autotuning is enabled via a config dict. Without
        CUDA, it is disabled by passing the path of a temporary JSON config
        file to ``set_config``.

        Returns:
            The result of ``paddle.fluid.core.use_layout_autotune()`` after
            the config has been applied.
        """
        if paddle.is_compiled_with_cuda():
            paddle.incubate.autotune.set_config(
                config={"layout": {
                    "enable": True
                }})
            return paddle.fluid.core.use_layout_autotune()

        config = {"layout": {"enable": False}}
        # The file is closed before set_config reads it (on Windows an open
        # NamedTemporaryFile cannot be reopened), so delete=False is needed
        # and removal must be done by hand -- even if set_config raises.
        tfile = tempfile.NamedTemporaryFile(mode="w+", delete=False)
        try:
            json.dump(config, tfile)
            tfile.close()
            paddle.incubate.autotune.set_config(tfile.name)
        finally:
            tfile.close()  # idempotent; covers a failure in json.dump
            os.remove(tfile.name)
        return paddle.fluid.core.use_layout_autotune()

    # Backward-compatible alias for the original, misspelled helper name.
    use_autoune = use_autotune

    def train(self, data_format):
        """Run two AMP O2 training steps of SimpleNet and return the last outputs.

        Args:
            data_format: "NHWC" feeds a [1, 16, 16, 3] input; any other value
                feeds a [1, 3, 16, 16] input. The model itself is always
                built with data_format="NCHW".

        Returns:
            Tuple ``(conv_out, predict)`` from the final training step.
        """
        model = SimpleNet(data_format="NCHW", class_num=2)
        if data_format == "NHWC":
            data = paddle.rand([1, 16, 16, 3])
        else:
            data = paddle.rand([1, 3, 16, 16])
        label_data = paddle.randint(0, 1, shape=[1, 1], dtype="int64")
        optimizer = paddle.optimizer.SGD(learning_rate=0.0001,
                                         parameters=model.parameters())
        scaler = paddle.amp.GradScaler()
        for _ in range(2):
            with paddle.amp.auto_cast(level="O2"):
                conv_out, predict = model(data)
                loss = F.cross_entropy(predict, label=label_data)
                loss = loss.mean()

            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.minimize(optimizer, scaled)
        return conv_out, predict

    def test_enable_autotune(self):
        """Conv output is NHWC when layout autotune is on, NCHW otherwise."""
        if self.use_autotune():
            expected_conv_shape = [1, 14, 14, 8]
        else:
            expected_conv_shape = [1, 8, 14, 14]
        conv_out, predict = self.train(data_format="NCHW")
        self.assertEqual(conv_out.shape, expected_conv_shape)
        self.assertEqual(predict.shape, [1, 2])

    def test_transpose_op_transposer(self):
        """A transpose after an autotuned conv sees the NCHW tensor."""
        if not self.use_autotune():
            return
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        data = paddle.rand([1, 3, 16, 14])
        optimizer = paddle.optimizer.SGD(learning_rate=0.0001,
                                         parameters=conv.parameters())
        scaler = paddle.amp.GradScaler()
        with paddle.amp.auto_cast(level="O2"):
            conv_out = conv(data)
            # conv_out.shape = [1, 14, 12, 8] with NHWC
            # layout tuner will transpose conv_out to
            # [1, 8, 14, 12] with NCHW before the following transpose op.
            out = paddle.transpose(conv_out, perm=[0, 3, 1, 2])
            loss = out.mean()
        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.minimize(optimizer, scaled)

        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [1, 12, 8, 14])

    def test_flatten_op_transposer(self):
        """A flatten over C and H after an autotuned conv sees the NCHW tensor."""
        if not self.use_autotune():
            return
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
        data = paddle.rand([1, 3, 16, 14])
        with paddle.amp.auto_cast(level="O2"):
            conv_out = conv(data)
            # conv_out.shape = [1, 14, 12, 8] with NHWC
            # layout tuner will transpose conv_out to
            # [1, 8, 14, 12] with NCHW before the following flatten op
            # because it flattens the C and H dimensions.
            out = flatten(conv_out)

        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [1, 112, 12])


class TestAutoTuneAPI(unittest.TestCase):
    """Tests for ``paddle.incubate.autotune.set_config``."""

    def test_set_config_warnings(self):
        """A config file with ``"enable": 1`` should record exactly one warning."""
        with warnings.catch_warnings(record=True) as caught:
            config = {"layout": {"enable": 1}}
            # On linux, we can open the file again to read the content
            # without closing the file, but on windows system, there is
            # no permission to open it again without closing it.
            tfile = tempfile.NamedTemporaryFile(mode="w+", delete=False)
            try:
                json.dump(config, tfile)
                tfile.close()
                paddle.incubate.autotune.set_config(tfile.name)
            finally:
                # delete=False means the file must be removed by hand, even
                # when json.dump or set_config raises.
                tfile.close()  # idempotent
                os.remove(tfile.name)
            # assertEqual reports the actual count on failure.
            self.assertEqual(len(caught), 1)


if __name__ == '__main__':
    # Run every TestCase in this module when executed directly as a script.
    unittest.main(verbosity=1)