distillation_strategy.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.strategy import Strategy
from ....framework import Program, Variable, program_guard
from .... import Executor
import logging

__all__ = ['DistillationStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class DistillationStrategy(Strategy):
    def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
        """
        Args:
            distillers(list): A list of distillers used to combine the student graph and the
                              teacher graph by adding distillation losses.
            start_epoch(int): The epoch at which the student graph and the teacher graph are
                              merged for distillation training. Default: 0
            end_epoch(int): The epoch at which distillation training finishes. Default: 0
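
        Examples:
            A minimal construction sketch. ``SomeDistiller`` and its
            arguments below are illustrative placeholders, not a real class
            in this package; see the distiller module for the concrete
            distiller implementations and their constructor arguments.

            .. code-block:: python

                # Hypothetical distiller that adds a distillation loss between
                # one teacher variable and one student variable.
                distillers = [SomeDistiller('teacher_feature', 'student_feature')]

                # Merge the graphs and train with the distillation loss from
                # epoch 0 up to (but not including) epoch 10.
                strategy = DistillationStrategy(
                    distillers=distillers, start_epoch=0, end_epoch=10)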
            
        """
        super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
        self.distillers = distillers

    def restore_from_checkpoint(self, context):
        # load from checkpoint
        if context.epoch_id > 0:
            if self.start_epoch < context.epoch_id < self.end_epoch:
                _logger.info('Restore DistillationStrategy')
                self._create_distillation_graph(context)
                _logger.info('Restore DistillationStrategy finish.')

    def on_epoch_begin(self, context):
        if self.start_epoch == context.epoch_id:
            _logger.info('DistillationStrategy::on_epoch_begin.')
            self._create_distillation_graph(context)
            _logger.info('DistillationStrategy set optimize_graph.')

    def _create_distillation_graph(self, context):
        """
        Step 1: Merge the student graph and the teacher graph into a distillation graph.
        Step 2: Add distillation losses into the distillation graph via the distillers.
        Step 3: Append backward ops and optimize ops to the distillation graph for training.
        """
        # step 1
        teacher = context.teacher_graphs[0]
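        # Freeze the teacher variables so that no gradients are computed for
        # (or propagated into) the teacher network during distillation
        # training.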
        for var in teacher.program.list_vars():
            var.stop_gradient = True
        graph = context.train_graph.clone()
        graph.merge(teacher)
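        # Preserve the original student loss under a separate output key;
        # the distillers below update the 'loss' output node to include the
        # distillation loss.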
        graph.out_nodes['student_loss'] = graph.out_nodes['loss']

        # step 2
        for distiller in self.distillers:
            graph = distiller.distiller_loss(graph)

        # step 3
        startup_program = Program()
        with program_guard(graph.program, startup_program):
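            # Give the distillation optimizer its own name prefix so that the
            # variables it creates do not clash with those of the optimizer
            # used in other training stages.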
            context.distiller_optimizer._name = 'distillation_optimizer'

            # The learning rate variable may have been created in another
            # program. Update the optimizer's learning rate map so that the
            # variable is accessible from the current program.
            optimizer = context.distiller_optimizer
            if isinstance(optimizer._learning_rate, Variable):
                optimizer._learning_rate_map[
                    graph.program] = optimizer._learning_rate

            # 'loss' now refers to the combined distillation loss produced in
            # step 2, so backward and optimize ops are appended for it.
            optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)

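        # Run the startup program so that the variables created by
        # optimizer.minimize() above (e.g. optimizer state) are initialized
        # in the current scope.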
        exe = Executor(context.place)
        exe.run(startup_program, scope=context.scope)

        # Back up the original optimize graph for fine-tuning after distillation.
        context.put('distillation_backup_optimize_graph',
                    context.optimize_graph)
        context.optimize_graph = graph

    def on_epoch_end(self, context):
        if context.epoch_id == (self.end_epoch - 1):
            _logger.info('DistillationStrategy::on_epoch_end.')
            # Restore optimize_graph for fine-tuning or other strategies in the next stage.
            context.optimize_graph = context.get(
                'distillation_backup_optimize_graph')
            _logger.info('DistillationStrategy restored context.optimize_graph.')