#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import textwrap
import unittest

import astor
import numpy as np
from ifelse_simple_func import (
    dyfunc_with_if_else_early_return1,
    dyfunc_with_if_else_early_return2,
)

import paddle
import paddle.jit.dy2static as _jst
from paddle import fluid
from paddle.jit.api import to_static
from paddle.jit.dy2static.utils import func_to_source_code
from paddle.utils import gast

np.random.seed(0)


# TODO(Aurelius): Currently, `declarative` doesn't support decorating a function
# that contains layers with initialization operations, like `fc = linear(10, 3)`,
# because the initialization ops would be added into the program and executed
# many times. The parameters are assumed to be initialized outside the function.
def simple_func(x, weight_numpy):
    x = fluid.dygraph.to_variable(x)
    w = fluid.dygraph.to_variable(weight_numpy)
    y = paddle.matmul(x, w)
    z = paddle.mean(y)
    return z


@to_static
def decorated_simple_func(x, weight_numpy):
    x = fluid.dygraph.to_variable(x)
    w = fluid.dygraph.to_variable(weight_numpy)
    y = paddle.matmul(x, w)
    z = paddle.mean(y)
    return z


def get_source_code(func):
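    # Normalize the retrieved source: dedent it, round-trip it through a gast
    # AST, and regenerate text with astor so that formatting differences do
    # not affect string comparisons.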
    raw_code = inspect.getsource(func)
    code = textwrap.dedent(raw_code)
    root = gast.parse(code)
    source_code = astor.to_source(gast.gast_to_ast(root))
    return source_code


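# Hand-written equivalents of the code that dy2static generates for a dynamic
# if/else function: each branch becomes a true_fn/false_fn closure, while
# get_args/set_args shuttle the variables modified inside the branches.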
class StaticCode1:
    def dyfunc_with_if_else(x_v, label=None):
        loss = _jst.UndefinedVar('loss')
        __return_1 = _jst.UndefinedVar('__return_1')
        __return_0 = _jst.UndefinedVar('__return_0')
        __return_value_0 = None

        def get_args_0():
            nonlocal x_v
            return (x_v,)

        def set_args_0(__args):
            nonlocal x_v
            (x_v,) = __args

        def true_fn_0():
            nonlocal x_v
            x_v = x_v - 1
            return

        def false_fn_0():
            nonlocal x_v
            x_v = x_v + 1
            return

        _jst.IfElse(
            paddle.mean(x_v)[0] > 5,
            true_fn_0,
            false_fn_0,
            get_args_0,
            set_args_0,
            ('x_v',),
            push_pop_names=None,
        )

        def get_args_1():
            nonlocal __return_0, __return_1, __return_value_0, loss
            return __return_0, __return_1, __return_value_0, loss

        def set_args_1(__args):
            nonlocal __return_0, __return_1, __return_value_0, loss
            __return_0, __return_1, __return_value_0, loss = __args

        def true_fn_1():
            nonlocal __return_0, __return_1, __return_value_0, loss
            loss = paddle.nn.functional.cross_entropy(
                x_v, label, reduction='none', use_softmax=False
            )
            __return_0 = _jst.create_bool_as_type(label is not None, True)
            __return_value_0 = loss
            return

        def false_fn_1():
            nonlocal __return_0, __return_1, __return_value_0, loss
            __return_1 = _jst.create_bool_as_type(label is not None, True)
            __return_value_0 = x_v
            return

        _jst.IfElse(
            label is not None,
            true_fn_1,
            false_fn_1,
            get_args_1,
            set_args_1,
            ('__return_0', '__return_1', '__return_value_0', 'loss'),
            push_pop_names=None,
        )
        return __return_value_0


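# Same transformed code as StaticCode1, but with the generated temporaries
# advanced to the next suffixes (true_fn_2/3, __return_2/3, __return_value_1).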
class StaticCode2:
    # TODO: Transform return statement
    def dyfunc_with_if_else(x_v, label=None):
        loss = _jst.UndefinedVar('loss')
        __return_3 = _jst.UndefinedVar('__return_3')
        __return_2 = _jst.UndefinedVar('__return_2')
        __return_value_1 = None

        def get_args_2():
            nonlocal x_v
            return (x_v,)

        def set_args_2(__args):
            nonlocal x_v
            (x_v,) = __args

        def true_fn_2():
            nonlocal x_v
            x_v = x_v - 1
            return

        def false_fn_2():
            nonlocal x_v
            x_v = x_v + 1
            return

        _jst.IfElse(
            paddle.mean(x_v)[0] > 5,
            true_fn_2,
            false_fn_2,
            get_args_2,
            set_args_2,
            ('x_v',),
            push_pop_names=None,
        )

        def get_args_3():
            nonlocal __return_2, __return_3, __return_value_1, loss
            return __return_2, __return_3, __return_value_1, loss

        def set_args_3(__args):
            nonlocal __return_2, __return_3, __return_value_1, loss
            __return_2, __return_3, __return_value_1, loss = __args

        def true_fn_3():
            nonlocal __return_2, __return_3, __return_value_1, loss
            loss = paddle.nn.functional.cross_entropy(
                x_v, label, reduction='none', use_softmax=False
            )
            __return_2 = _jst.create_bool_as_type(label is not None, True)
            __return_value_1 = loss
            return

        def false_fn_3():
            nonlocal __return_2, __return_3, __return_value_1, loss
            __return_3 = _jst.create_bool_as_type(label is not None, True)
            __return_value_1 = x_v
            return

        _jst.IfElse(
            label is not None,
            true_fn_3,
            false_fn_3,
            get_args_3,
            set_args_3,
            ('__return_2', '__return_3', '__return_value_1', 'loss'),
            push_pop_names=None,
        )
        return __return_value_1


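# Creating a sub-layer inside the decorated forward is unsupported (see the
# TODO near the top of this file); TestEnableDeclarative.test_raise_error
# expects this to fail with a ValueError.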
class NetWithError(paddle.nn.Layer):
    @to_static
    def forward(self, x):
        linear = paddle.nn.Linear(32, 64)
        y = linear(x)
        return y


class TestEnableDeclarative(unittest.TestCase):
    def setUp(self):
        self.x = np.random.randn(30, 10, 32).astype('float32')
        self.weight = np.random.randn(32, 64).astype('float32')

    def test_raise_error(self):
        with fluid.dygraph.guard():
            paddle.jit.enable_to_static(True)
            net = NetWithError()
            with self.assertRaises(ValueError):
                net(fluid.dygraph.to_variable(self.x))

    def test_enable_disable_declarative(self):
        paddle.jit.enable_to_static(True)
        with fluid.dygraph.guard():
            static_output = decorated_simple_func(self.x, self.weight)

        paddle.jit.enable_to_static(False)
        with fluid.dygraph.guard():
            dygraph_output = decorated_simple_func(self.x, self.weight)
            np.testing.assert_allclose(
                static_output.numpy(),
                dygraph_output.numpy(),
                rtol=1e-05,
                atol=1e-4,
            )


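# Minimal Layer with a trivial forward, used by TestParameterRecorder below.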
class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1


class SwitchModeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()

    @paddle.jit.to_static
    def forward(self, x):
        return x + 1

    @paddle.jit.to_static
    def foo(self):
        return True


@paddle.jit.to_static
def switch_mode_function():
    return True


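# `train()`/`eval()` on a to_static-decorated free function should toggle both
# the function's own `_training` flag and the training mode of the cached
# partial program; calling them on a Layer's methods should raise instead.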
class TestFunctionTrainEvalMode(unittest.TestCase):
    def test_switch_mode(self):
        paddle.disable_static()
        switch_mode_function.eval()
        switch_mode_function()
        self.assertEqual(switch_mode_function._training, False)
        _, partial_layer = switch_mode_function.program_cache.last()[-1]
        self.assertEqual(partial_layer.training, False)

        switch_mode_function.train()
        switch_mode_function()
        self.assertEqual(switch_mode_function._training, True)
        _, partial_layer = switch_mode_function.program_cache.last()[-1]
        self.assertEqual(partial_layer.training, True)

    def test_raise_error(self):
        paddle.disable_static()
        net = SwitchModeNet()

        self.assertEqual(net.training, True)
        with self.assertRaises(RuntimeError):
            net.forward.eval()

        net.eval()
        self.assertEqual(net.training, False)
        with self.assertRaises(RuntimeError):
            net.foo.train()


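# to_static must correctly transform functions that return early from within
# an if-branch (compare the `__return_*` bookkeeping in StaticCode1/2 above).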
class TestIfElseEarlyReturn(unittest.TestCase):
    def test_ifelse_early_return1(self):
        answer = np.zeros([2, 2]) + 1
        static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
        out = static_func()
        np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05)

    def test_ifelse_early_return2(self):
        answer = np.zeros([2, 2]) + 3
        static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2)
        out = static_func()
        np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05)


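# func_to_source_code should strip comments when retrieving a function's source.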
class TestRemoveCommentInDy2St(unittest.TestCase):
    def func_with_comment(self):
        # Comment1
        x = paddle.to_tensor([1, 2, 3])
        # Comment2
        # Comment3
        y = paddle.to_tensor([4, 5, 6])

    def test_remove_comment(self):
        code_string = func_to_source_code(self.func_with_comment)
        self.assertNotIn('#', code_string)


class Obj:
    def __init__(self):
        pass

    def func(self, x):
        return x + 1


obj = Obj()


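# A plain (non-Layer) class whose forward routes through a to_static-decorated
# inner function, passing the instance and a bound method of another object.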
class Net2:
    def __init__(self):
        super().__init__()
        self.layer1 = paddle.nn.Linear(10, 10)

    def forward(self, data):
        @paddle.jit.to_static
        def func(ins, x, loss_fn):
            x = ins.layer1(x)
            return loss_fn(x)

        def func1(x):
            return func(self, x, obj.func)

        return func1(data)


class TestParameterRecorder(unittest.TestCase):
    def test_recorder(self):
        """function calls nn.Layer case."""
        net = Net()
        x = paddle.randn([5, 10])
        out = net.forward(x)


if __name__ == '__main__':
    unittest.main()