#  Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import tempfile
import unittest
from functools import partial

import numpy as np
import paddle
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramCache
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
from paddle.jit import to_static
from paddle.optimizer.lr import LRScheduler

# Seed fed to both paddle and numpy before every training run, so the IPU
# and CPU models start from identical initial weights.
SEED: int = 2022


class SimpleLayer(paddle.nn.Layer):
    """Tiny conv + flatten network whose ``forward`` runs through ``to_static``.

    Args:
        loss_op: optional callable invoked as ``loss_op(x, target)``. When it
            is falsy, ``paddle.fluid.layers.cross_entropy`` is used instead.
        use_softmax: apply ``softmax`` to the flattened features before the
            loss is computed.
        use_reduction: average the loss with ``paddle.mean``.
        use_identity_loss: wrap the loss with
            ``paddle.incubate.identity_loss(loss, 1)``
            (NOTE(review): reduction code ``1`` is presumably "mean" -- confirm
            against the Paddle documentation).
    """

    def __init__(self,
                 loss_op=None,
                 use_softmax=True,
                 use_reduction=True,
                 use_identity_loss=True):
        super(SimpleLayer, self).__init__()
        self.loss_op = loss_op
        # 3 -> 1 channel, 2x2 kernel, stride 1 (no padding given), so a
        # (N, 3, H, W) input becomes (N, 1, H - 1, W - 1).
        self.conv = paddle.nn.Conv2D(in_channels=3,
                                     out_channels=1,
                                     kernel_size=2,
                                     stride=1)
        self.use_softmax = use_softmax
        self.use_reduction = use_reduction
        self.use_identity_loss = use_identity_loss

    @to_static()
    def forward(self, x, target=None):
        """Run the network and, when ``target`` is given, the loss.

        Returns:
            The flattened conv features when ``target`` is None. Otherwise
            a ``(x, loss)`` tuple, where ``x`` is the flattened output
            (softmaxed if ``use_softmax``).
        """
        x = self.conv(x)
        # Keep the batch axis, collapse everything else into one axis.
        x = paddle.fluid.layers.flatten(x, axis=1)
        if target is not None:
            if self.use_softmax:
                x = paddle.fluid.layers.softmax(x)
            if self.loss_op:
                loss = self.loss_op(x, target)
            else:
                loss = paddle.fluid.layers.cross_entropy(x, target)
            if self.use_reduction:
                loss = paddle.mean(loss)
            if self.use_identity_loss:
                loss = paddle.incubate.identity_loss(loss, 1)
            return x, loss
        return x


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestBase(IPUOpTest):
    """Train ``SimpleLayer`` on IPU and on CPU and check the losses agree.

    Subclasses customise the loss through ``set_op_attrs``, the inputs
    through ``set_data_feed`` and the network through ``create_model``.
    """

    def setUp(self):
        # Models are built and trained in dynamic-graph mode; ``to_static``
        # converts ``SimpleLayer.forward`` on the fly.
        paddle.disable_static()
        self.set_op_attrs()
        self.set_data_feed()

    def set_op_attrs(self):
        """Select the loss callable handed to ``SimpleLayer``."""
        self.loss_op = paddle.fluid.layers.cross_entropy

    def set_data_feed(self):
        """Create the single batch reused for every training step."""
        self.data = paddle.uniform((32, 3, 10, 10), dtype='float32')
        self.label = paddle.randint(0, 10, shape=[32], dtype='int64')

    def create_model(self, use_ipu=False):
        """Build the model under test.

        On IPU the loss is wrapped with ``identity_loss``; on CPU it is
        reduced with ``paddle.mean`` instead.
        """
        return SimpleLayer(loss_op=self.loss_op,
                           use_softmax=True,
                           use_reduction=not use_ipu,
                           use_identity_loss=use_ipu)

    def _test(self, use_ipu=False):
        """Train for 100 steps and return the per-step losses.

        Args:
            use_ipu: train through ``IpuStrategy``, where a single model call
                performs forward, backward and the optimizer update. Otherwise
                use the regular dygraph training loop.

        Returns:
            A numpy array holding one loss entry per step.
        """
        paddle.seed(SEED)
        np.random.seed(SEED)
        model = self.create_model(use_ipu)
        optim = paddle.optimizer.Adam(learning_rate=0.01,
                                      parameters=model.parameters())

        prev_device = paddle.get_device()
        ipu_strategy = None
        result = []
        try:
            if use_ipu:
                paddle.set_device('ipu')
                ipu_strategy = paddle.static.IpuStrategy()
                ipu_strategy.set_graph_config(num_ipus=1,
                                              is_training=True,
                                              micro_batch_size=1,
                                              enable_manual_shard=False)
                ipu_strategy.set_optimizer(optim)

            for _ in range(100):
                # On IPU, calling model() alone does forward/backward/update.
                _, loss = model(self.data, self.label)
                if not use_ipu:
                    loss.backward()
                    optim.step()
                    optim.clear_grad()
                # Snapshot the value now instead of keeping the tensor alive.
                result.append(loss.numpy())
        finally:
            # Always undo the IpuStrategy monkey-patches and the device switch,
            # even if training fails. Otherwise later tests (including the CPU
            # baseline run) would execute in a patched, IPU-targeted state.
            if ipu_strategy is not None:
                ipu_strategy.release_patch()
            if use_ipu:
                paddle.set_device(prev_device)

        return np.array(result)

    def test_training(self):
        ipu_loss = self._test(True).flatten()
        cpu_loss = self._test(False).flatten()

        # Same tolerance as np.allclose(atol=1e-4) with its default rtol, but
        # reports the mismatching elements on failure.
        np.testing.assert_allclose(ipu_loss, cpu_loss, rtol=1e-05, atol=1e-4)


class TestSaveLoad(TestBase):
    """Like ``TestBase``, but checkpoint and reload mid-training.

    The model is trained for 100 steps, the model and optimizer state dicts
    are saved and reloaded, and training continues for another 100 steps.
    """

    @classmethod
    def setUpClass(cls):
        super(TestSaveLoad, cls).setUpClass()
        cls.save_path = tempfile.TemporaryDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.save_path.cleanup()
        super(TestSaveLoad, cls).tearDownClass()

    def _test(self, use_ipu=False):
        """Train, save/load the state dicts, train again; return all losses.

        Args:
            use_ipu: train through ``IpuStrategy`` instead of the regular
                dygraph training loop.

        Returns:
            A numpy array with one loss entry per step (200 steps in total).
        """
        paddle.seed(SEED)
        np.random.seed(SEED)
        model = self.create_model(use_ipu)
        optim = paddle.optimizer.Adam(learning_rate=0.01,
                                      parameters=model.parameters())

        # ``save_path`` is a TemporaryDirectory object; its ``.name`` is the
        # actual directory. Formatting the object itself would produce the
        # bogus relative path "<TemporaryDirectory '...'>/...".
        device_tag = 'ipu' if use_ipu else 'cpu'
        save_dir = self.save_path.name
        model_path = os.path.join(
            save_dir, 'model_state_dict_{}.pdparams'.format(device_tag))
        optim_path = os.path.join(
            save_dir, 'optim_state_dict_{}.pdopt'.format(device_tag))

        result = []

        def run_epochs():
            # Train for 100 steps, recording each step's loss.
            for _ in range(100):
                # On IPU, calling model() alone does forward/backward/update.
                _, loss = model(self.data, self.label)
                if not use_ipu:
                    loss.backward()
                    optim.step()
                    optim.clear_grad()
                result.append(loss.numpy())

        prev_device = paddle.get_device()
        ipu_strategy = None
        try:
            if use_ipu:
                paddle.set_device('ipu')
                ipu_strategy = paddle.static.IpuStrategy()
                ipu_strategy.set_graph_config(num_ipus=1,
                                              is_training=True,
                                              micro_batch_size=1,
                                              enable_manual_shard=False)
                ipu_strategy.set_optimizer(optim)

            run_epochs()

            if use_ipu:
                # Presumably syncs the IPU-side weights back into the host
                # parameters so state_dict() sees the trained values -- confirm.
                paddle.fluid.core.IpuBackend.get_instance().weights_to_host()

            paddle.save(model.state_dict(), model_path)
            paddle.save(optim.state_dict(), optim_path)

            model.set_state_dict(paddle.load(model_path))
            optim.set_state_dict(paddle.load(optim_path))

            run_epochs()
        finally:
            # Always undo the IpuStrategy patches and the device switch.
            if ipu_strategy is not None:
                ipu_strategy.release_patch()
            if use_ipu:
                paddle.set_device(prev_device)

        return np.array(result)


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestPatch(IPUOpTest):
    """Check that ``IpuStrategy.release_patch`` restores patched methods."""

    def setUp(self):
        paddle.disable_static()

    def test(self, use_ipu=False):
        # ``use_ipu`` is unused; it is kept only to leave the signature intact.
        old_getter = ProgramCache.__getitem__
        old_step = LRScheduler.step

        # NOTE(review): constructing an IpuStrategy in dygraph mode is assumed
        # to patch these two methods; releasing must put the originals back.
        ipu_strategy = paddle.static.IpuStrategy()
        ipu_strategy.release_patch()

        self.assertIs(ProgramCache.__getitem__, old_getter)
        self.assertIs(LRScheduler.step, old_step)


class TestWithoutIdentityLoss1(TestBase):
    """Default cross-entropy loss, reduced with ``paddle.mean`` on both
    devices instead of being wrapped with ``identity_loss`` on IPU."""

    def create_model(self, use_ipu=False):
        return SimpleLayer(loss_op=self.loss_op,
                           use_softmax=True,
                           use_reduction=True,
                           use_identity_loss=False)


class TestWithoutIdentityLoss2(TestBase):
    """``softmax_with_cross_entropy`` loss with no ``identity_loss``.

    SimpleLayer's own softmax is disabled (presumably because this loss op
    applies softmax itself), and the loss is averaged with ``paddle.mean``.
    """

    def set_op_attrs(self):
        fused_loss = paddle.fluid.layers.softmax_with_cross_entropy
        self.loss_op = fused_loss

    def set_data_feed(self):
        image_shape = (32, 3, 10, 10)
        # Labels carry a trailing unit axis here ([32, 1]), unlike TestBase.
        label_shape = [32, 1]
        self.data = paddle.uniform(image_shape, dtype='float32')
        self.label = paddle.randint(0, 10, shape=label_shape, dtype='int64')

    def create_model(self, use_ipu=False):
        # Identical network on both devices, independent of ``use_ipu``.
        return SimpleLayer(use_identity_loss=False,
                           use_reduction=True,
                           use_softmax=False,
                           loss_op=self.loss_op)


class TestWithoutIdentityLoss3(TestBase):
    """Unreduced ``kldiv_loss`` averaged by SimpleLayer, no ``identity_loss``."""

    def set_op_attrs(self):
        kldiv = paddle.fluid.layers.kldiv_loss
        self.loss_op = partial(kldiv, reduction="none")

    def set_data_feed(self):
        image_shape = (32, 3, 10, 10)
        # One float target per flattened output feature: the 2x2/stride-1
        # single-channel conv maps 10x10 to 9x9, i.e. 81 features.
        target_shape = [32, 81]
        self.data = paddle.uniform(image_shape, dtype='float32')
        self.label = paddle.rand(shape=target_shape, dtype='float32')

    def create_model(self, use_ipu=False):
        # Same configuration on both devices, independent of ``use_ipu``.
        return SimpleLayer(use_identity_loss=False,
                           use_reduction=True,
                           use_softmax=True,
                           loss_op=self.loss_op)


if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()