# File: dist_fleet_ctr_ps_gpu.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Distributed CTR model for testing the Fleet API with PS-GPU.
"""

from __future__ import print_function

import shutil
import tempfile
import time

import paddle
import paddle.fluid as fluid
import os
import numpy as np

import ctr_dataset_reader
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
from paddle.distributed.fleet.base.util_factory import fleet_util

# Pin the random seed of the default startup and main programs (in that
# order) so that test runs are repeatable.
for _get_program in (fluid.default_startup_program,
                     fluid.default_main_program):
    _get_program().random_seed = 1


class TestDistGpuPsCTR2x2(TestDistCTR2x2):
    """
    For test CTR model, using Fleet api & PS-GPU
    """

    def check_model_right(self, dirname):
        """
        Parse the saved ``__model__`` file and dump it as text.

        The binary ``__model__`` file in ``dirname`` is read and parsed with
        ``fluid.Program.parse_from_string``; the string form of the resulting
        program is written next to it as ``__model__.proto``. No assertion is
        made here: the check is only that reading and parsing do not raise.

        Args:
            dirname(str): directory that contains the ``__model__`` file
        """
        model_filename = os.path.join(dirname, "__model__")

        with open(model_filename, "rb") as f:
            program_desc_str = f.read()

        program = fluid.Program.parse_from_string(program_desc_str)
        with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
            wn.write(str(program))

    def _create_gpu_executor(self):
        """
        Build an executor on the GPU selected by ``FLAGS_selected_gpus``.

        The environment variable is read as a single integer device id and
        defaults to ``0`` when unset.
        """
        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        return fluid.Executor(fluid.CUDAPlace(device_id))

    def _save_and_check_model(self, fleet, exe):
        """
        Save the model into a temporary directory and run the model check.

        Every worker saves the inference model and runs
        ``check_model_right`` on it; only the first worker additionally saves
        the persistables. The temporary directory is removed even when one of
        these steps raises, so a failing run leaves no directory behind.

        Args:
            fleet(Fleet api): the fleet object used for saving
            exe(Executor): executor passed through to the save calls
        """
        model_dir = tempfile.mkdtemp()
        try:
            fleet.save_inference_model(
                exe, model_dir, [feed.name for feed in self.feeds],
                self.avg_cost)
            self.check_model_right(model_dir)
            if fleet.is_first_worker():
                fleet.save_persistables(executor=exe, dirname=model_dir)
        finally:
            shutil.rmtree(model_dir)

    def do_pyreader_training(self, fleet):
        """
        do training using pyreader, fetching the loss after every batch
        Args:
            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
        """
        exe = self._create_gpu_executor()
        fleet.init_worker()
        exe.run(fleet.startup_program)

        batch_size = 4
        train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
        self.reader.decorate_sample_list_generator(train_reader)

        for epoch_id in range(1):
            self.reader.start()
            try:
                # The loop has no normal exit: it ends when the executor
                # raises EOFException, which is handled below.
                while True:
                    loss_val = exe.run(program=fleet.main_program,
                                       fetch_list=[self.avg_cost.name])
                    loss_val = np.mean(loss_val)
                    # Sum the local losses over all trainers, then divide by
                    # the number of gathered entries to get the mean loss.
                    reduce_output = fleet_util.all_reduce(
                        np.array(loss_val), mode="sum")
                    loss_all_trainer = fleet_util.all_gather(float(loss_val))
                    loss_val = float(reduce_output) / len(loss_all_trainer)
                    message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
                                                                      loss_val)
                    fleet_util.print_on_rank(message, 0)
            except fluid.core.EOFException:
                self.reader.reset()

        self._save_and_check_model(fleet, exe)
        fleet.stop_worker()

    def do_dataset_training(self, fleet):
        """
        do training using a QueueDataset fed through a pipe command
        Args:
            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role

        The model is saved and checked only when the environment variable
        ``SAVE_MODEL`` equals ``"1"``.
        """
        # Only the training file path is needed here; the two leading return
        # values of prepare_data() are not used by this method.
        _, _, train_file_path = ctr_dataset_reader.prepare_data()

        exe = self._create_gpu_executor()

        fleet.init_worker()
        exe.run(fleet.startup_program)

        thread_num = 2
        batch_size = 128
        # The same training file is listed once per thread.
        filelist = [train_file_path] * thread_num

        # config dataset
        dataset = paddle.distributed.QueueDataset()
        dataset._set_batch_size(batch_size)
        dataset._set_use_var(self.feeds)
        pipe_command = 'python ctr_dataset_reader.py'
        dataset._set_pipe_command(pipe_command)

        dataset.set_filelist(filelist)
        dataset._set_thread(thread_num)

        for epoch_id in range(1):
            dataset.set_filelist(filelist)
            exe.train_from_dataset(
                program=fleet.main_program,
                dataset=dataset,
                fetch_list=[self.avg_cost],
                fetch_info=["cost"],
                print_period=2,
                debug=int(os.getenv("Debug", "0")))

        if os.getenv("SAVE_MODEL") == "1":
            self._save_and_check_model(fleet, exe)

        fleet.stop_worker()


# Script entry point: hand this test class to runtime_main (imported from
# test_dist_fleet_base) when the file is executed directly.
if __name__ == "__main__":
    runtime_main(TestDistGpuPsCTR2x2)