# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import logging
import numpy as np
import os
import paddle
import shutil
import tempfile
import unittest

paddle.enable_static()

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger("paddle_with_cinn")


def set_cinn_flag(val):
    """Try to set FLAGS_use_cinn; return False if Paddle is not compiled with CINN."""
    cinn_compiled = False
    try:
        paddle.set_flags({'FLAGS_use_cinn': val})
        cinn_compiled = True
    except ValueError:
        logger.warning("The used paddle is not compiled with CINN.")
    return cinn_compiled


def reader(limit):
    """Yield `limit` random (input, label) pairs."""
    for _ in range(limit):
        yield np.random.random([1, 28]).astype('float32'), \
            np.random.randint(0, 2, size=[1]).astype('int64')


def rand_data(img, label, loop_num=10):
    """Build a feed list of `loop_num` random samples keyed by tensor names."""
    feed = []
    data = reader(loop_num)
    for _ in range(loop_num):
        d, l = next(data)
        feed.append({img: d, label: l})
    return feed


def build_program(main_program, startup_program):
    """Build a small static network (add + relu + cross_entropy) optimized with Adam."""
    with paddle.static.program_guard(main_program, startup_program):
        img = paddle.static.data(name='img', shape=[1, 28], dtype='float32')
        param = paddle.create_parameter(
            name="bias",
            shape=[1, 28],
            dtype="float32",
            attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(
                np.random.rand(1, 28).astype(np.float32))))
        label = paddle.static.data(name="label", shape=[1], dtype='int64')

        hidden = paddle.add(img, param)
        prediction = paddle.nn.functional.relu(hidden)

        loss = paddle.nn.functional.cross_entropy(input=prediction, label=label)
        avg_loss = paddle.mean(loss)
        adam = paddle.optimizer.Adam(learning_rate=0.001)
        adam.minimize(avg_loss)
    return img, label, avg_loss


def train(dot_save_dir, prefix, seed=1234):
    """Train for 100 steps on random data and return the per-step loss values."""
    np.random.seed(seed)
    paddle.seed(seed)
    if paddle.is_compiled_with_cuda():
        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})

    startup_program = paddle.static.Program()
    main_program = paddle.static.Program()
    img, label, loss = build_program(main_program, startup_program)

    place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
    ) else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, prefix)
    compiled_program = paddle.static.CompiledProgram(
        main_program, build_strategy).with_data_parallel(loss_name=loss.name)

    iters = 100
    feed = rand_data(img.name, label.name, iters)
    loss_values = []
    for step in range(iters):
        loss_v = exe.run(compiled_program,
                         feed=feed[step],
                         fetch_list=[loss],
                         return_merged=False)
        loss_values.append(loss_v[0][0][0])
    return loss_values


@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
    def setUp(self):
        self.tmpdir = tempfile.mkdtemp(prefix="dots_")

    def tearDown(self):
        shutil.rmtree(self.tmpdir)

    def test_run_with_cinn(self):
        # Losses from the CINN-enabled run should match the plain Paddle run.
        cinn_losses = train(self.tmpdir, "cinn")
        set_cinn_flag(False)
        pd_losses = train(self.tmpdir, "paddle")
        self.assertTrue(np.allclose(cinn_losses, pd_losses, atol=1e-5))


if __name__ == '__main__':
    unittest.main()