test_eager_deletion_dynamic_rnn_base.py 2.6 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
16

S
sneaxiy 已提交
17 18 19 20
os.environ['CPU_NUM'] = '2'

import unittest

21 22
from fake_reader import fake_imdb_reader

S
sneaxiy 已提交
23
import paddle
24 25
from paddle import fluid
from paddle.fluid import core
S
sneaxiy 已提交
26 27


28
def train(network, use_cuda, batch_size=32, pass_num=2):
    """Build ``network`` in the current default programs and train it briefly.

    The function is a smoke test driver: it only checks that a few training
    iterations execute, it does not inspect or return the loss.

    Args:
        network: Callable invoked as ``network(data, label, word_dict_size)``
            that returns the cost variable to minimize.
        use_cuda: Run on ``fluid.CUDAPlace(0)`` when true, otherwise on
            ``fluid.CPUPlace()``. When true but Paddle was built without
            CUDA, a message is printed and the function returns early.
        batch_size: Number of samples per mini-batch passed to
            ``paddle.batch``.
        pass_num: Number of passes over the reader; each pass runs at most
            17 mini-batches.

    Returns:
        None.
    """
    if use_cuda and not core.is_compiled_with_cuda():
        print('Skip use_cuda=True because Paddle is not compiled with cuda')
        return

    word_dict_size = 5147
    # NOTE(review): presumably yields `batch_size * 40` fake samples, i.e.
    # 40 mini-batches per pass -- confirm against fake_reader.
    reader = fake_imdb_reader(word_dict_size, batch_size * 40)
    train_reader = paddle.batch(reader, batch_size=batch_size)

    data = paddle.static.data(
        name="words", shape=[-1, 1], dtype="int64", lod_level=1
    )
    label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")

    cost = network(data, label, word_dict_size)
    # NOTE(review): presumably keeps the fetched cost alive while eager
    # deletion is active -- confirm.
    cost.persistable = True
    optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    feeder = fluid.DataFeeder(feed_list=[data, label], place=place)

    exe = fluid.Executor(place)
    # Seed both programs before running startup so initialization is
    # reproducible across runs.
    fluid.default_startup_program().random_seed = 1
    fluid.default_main_program().random_seed = 1
    exe.run(fluid.default_startup_program())

    train_cp = fluid.default_main_program()
    fetch_list = [cost]

    for _ in range(pass_num):
        batch_id = 0
        # DataFeeder.feed converts ONE mini-batch into a feed dict, so it has
        # to be called per batch; its result is not a callable reader.
        for batch in train_reader():
            exe.run(
                train_cp,
                feed=feeder.feed(batch),
                # Fetch the cost only on every fourth batch; the remaining
                # runs execute with an empty fetch list.
                fetch_list=fetch_list if batch_id % 4 == 0 else [],
            )
            batch_id += 1
            if batch_id > 16:
                break


class TestBase(unittest.TestCase):
    """Base test case that trains ``self.net`` on each device configuration.

    ``setUp`` leaves ``self.net`` as ``None``, in which case ``test_network``
    does nothing. A subclass is expected to override ``setUp`` and assign a
    network-building callable to ``self.net``.
    """

    def setUp(self):
        # No network by default, so the base class itself runs as a no-op.
        self.net = None

    def test_network(self):
        """Run ``train`` for ``self.net`` with ``use_cuda`` True and False."""
        if self.net is None:
            return

        for use_cuda in [True, False]:
            # subTest labels a failure with the device configuration and lets
            # the remaining configuration still run.
            with self.subTest(use_cuda=use_cuda):
                print(f'network: {self.net.__name__}, use_cuda: {use_cuda}')
                # Fresh programs and a fresh scope per configuration, so the
                # two runs do not share graph state or variables.
                with fluid.program_guard(fluid.Program(), fluid.Program()):
                    with fluid.scope_guard(core.Scope()):
                        train(self.net, use_cuda)