#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import logging

import paddle.fluid as fluid
from dataset.data_loader import create_data
from utils.get_args import configs


class gan_compression:
    """Run the GAN-compression pipeline as an ordered sequence of stages.

    ``cfgs.task`` is a ``'+'``-separated list of stage names (``'mobile'``,
    ``'distiller'``, ``'supernet'``). Stages are trained in order; the
    ``state_dict`` of every sub-layer named in a model's ``model_names`` at
    the final epoch of one stage is passed to the next stage's
    ``model.setup``.
    """

    def __init__(self, cfgs, **kwargs):
        """Resolve the execution device and record it on ``cfgs``.

        Args:
            cfgs: configuration object. It must provide ``gpu_num`` and is
                updated in place with ``use_gpu``, ``use_parallel`` and
                ``place``.
            **kwargs: extra attributes set verbatim on this instance.
        """
        self.cfgs = cfgs
        use_gpu, use_parallel = self._get_device()

        if not use_gpu:
            place = fluid.CPUPlace()
        elif not use_parallel:
            place = fluid.CUDAPlace(0)
        else:
            # Multi-GPU: bind to the device id reported by fluid's parallel Env.
            place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

        setattr(self.cfgs, 'use_gpu', use_gpu)
        setattr(self.cfgs, 'use_parallel', use_parallel)
        setattr(self.cfgs, 'place', place)

        for k, v in kwargs.items():
            setattr(self, k, v)

    def _get_device(self):
        """Return ``(use_gpu, use_parallel)`` derived from ``cfgs.gpu_num``.

        A ``gpu_num`` of 0 selects the CPU. Any other value selects the GPU,
        and a value greater than 1 also enables data-parallel training.
        """
        num = self.cfgs.gpu_num
        use_gpu = num != 0
        use_parallel = use_gpu and num > 1
        return use_gpu, use_parallel

    @staticmethod
    def _create_model_fn(step):
        """Return the model factory for the training stage named *step*.

        Raises:
            NotImplementedError: if *step* is not a supported stage name.
        """
        if step == 'mobile':
            from models import create_model
        elif step == 'distiller':
            from distillers import create_distiller as create_model
        elif step == 'supernet':
            from supernets import create_supernet as create_model
        else:
            raise NotImplementedError(
                "unsupported training step: {}".format(step))
        return create_model

    def _is_master(self):
        """Return True if this process should evaluate and save checkpoints.

        Single-process runs always return True. Data-parallel runs return
        True only on the process whose ``local_rank`` is 0 (presumably so
        workers do not write duplicate checkpoints).
        """
        if not self.cfgs.use_parallel:
            return True
        return fluid.dygraph.parallel.Env().local_rank == 0

    @staticmethod
    def _should_save_epoch(epoch_id, epochs, save_freq, is_master):
        """Return True if epoch *epoch_id* should be evaluated and saved.

        Checkpoint epochs are every ``save_freq``-th epoch plus the final
        epoch. Saving happens only when *is_master* is true.
        """
        # Parenthesized on purpose: `a or b and c` would parse as
        # `a or (b and c)` and let non-master workers save periodic epochs.
        is_checkpoint_epoch = (epoch_id % save_freq == 0 or
                               epoch_id == epochs - 1)
        return is_master and is_checkpoint_epoch

    def start_train(self):
        """Train every stage listed in ``cfgs.task`` in order.

        For each stage this method:
          * builds the model;
          * seeds it with the weights collected at the end of the previous
            stage;
          * trains it for ``cfgs.<stage>_epoch`` epochs;
          * evaluates and saves it on checkpoint epochs.

        Raises:
            NotImplementedError: if a stage name is not supported.
        """
        steps = self.cfgs.task.split('+')

        model_weight = {}
        for idx, step in enumerate(steps):
            create_model = self._create_model_fn(step)

            print(
                "============================= start train {} ==============================".
                format(step))

            fluid.enable_imperative(place=self.cfgs.place)

            # The parallel context is prepared once, before the first stage.
            if self.cfgs.use_parallel and idx == 0:
                strategy = fluid.dygraph.parallel.prepare_context()
                setattr(self.cfgs, 'strategy', strategy)

            model = create_model(self.cfgs)
            model.setup(model_weight)
            # model_weight has been handed to setup(); reset it so it only
            # holds this stage's final-epoch weights.
            model_weight = {}

            _train_dataloader, _ = create_data(self.cfgs)

            epochs = getattr(self.cfgs, '{}_epoch'.format(step))
            # Rank does not change during training, so resolve it once.
            is_master = self._is_master()

            for epoch_id in range(epochs):
                for batch_id, data in enumerate(_train_dataloader()):
                    start_time = time.time()
                    model.set_input(data)
                    model.optimize_parameter()
                    batch_time = time.time() - start_time
                    if batch_id % self.cfgs.print_freq == 0:
                        message = 'epoch: %d, batch: %d batch_time: %fs ' % (
                            epoch_id, batch_id, batch_time)
                        for k, v in model.get_current_lr().items():
                            message += '%s: %f ' % (k, v)
                        message += '\n'
                        for k, v in model.get_current_loss().items():
                            message += '%s: %.3f ' % (k, v)
                        logging.info(message)

                # Keep the final-epoch weights for the next stage's setup().
                if epoch_id == epochs - 1:
                    for name in model.model_names:
                        model_weight[name] = model._sub_layers[
                            name].state_dict()

                if self._should_save_epoch(epoch_id, epochs,
                                           self.cfgs.save_freq, is_master):
                    model.evaluate_model(epoch_id)
                    model.save_network(epoch_id)
            print("=" * 80)


if __name__ == '__main__':
    # Script entry point: load the configuration, echo it, then run every
    # stage listed in its `task` field.
    config_parser = configs()
    all_cfgs = config_parser.get_all_config()
    config_parser.print_configs(all_cfgs)
    gan_compression(all_cfgs).start_train()