#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple

__all__ = [
    "BaseStrategy",
    "QuantAware",
    "Distillation",
    "MultiTeacherDistillation",
    "HyperParameterOptimization",
    "ChannelPrune",
    "UnstructurePrune",
    "TransformerPrune",
    "ASPPrune",
    "merge_config",
    "ProgramInfo",
    "TrainConfig",
    "SUPPORTED_CONFIG",
    "TRAIN_CONFIG_NAME",
    "QuantPost",
]

SUPPORTED_CONFIG = [
    "QuantAware",
    "Distillation",
    "MultiTeacherDistillation",
    "HyperParameterOptimization",
    "ChannelPrune",
    "UnstructurePrune",
    "TransformerPrune",
    "ASPPrune",
    "QuantPost",
]

TRAIN_CONFIG_NAME = "TrainConfig"


class BaseStrategy:
    def __init__(self, name):
        self.name = name


class QuantAware(BaseStrategy):
    def __init__(self,
                 quantize_op_types=[
                     'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul',
                     'matmul', 'matmul_v2'
                 ],
                 weight_bits=8,
                 activation_bits=8,
                 not_quant_pattern=['skip_quant'],
                 use_pact=False,
                 activation_quantize_type='moving_average_abs_max',
                 weight_quantize_type='channel_wise_abs_max',
                 dtype='int8',
                 window_size=10000,
                 moving_rate=0.9,
                 for_tensorrt=False,
                 onnx_format=True,
                 is_full_quantize=False):
        """
        Quantization Config.
        Args:
            quantize_op_types(list(str)): Operators whose type is in quantize_op_types will be quantized. Default: ['conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul', 'matmul_v2'].
            weight_bits(int): Weight quantization bit number. Default: 8.
            activation_bits(int): Activation quantization bit number. Default: 8.
            not_quant_pattern(list(str)): Operators whose name_scope is in the not_quant_pattern list will not be quantized. Default: ['skip_quant'].
            use_pact(bool): Whether to use PACT in quantization-aware training. Default: False.
            activation_quantize_type(str): Activation quantize type. Default: 'moving_average_abs_max'.
            weight_quantize_type(str): Weight quantize type. Default: 'channel_wise_abs_max'.
            dtype(str): Data type after quantization, such as 'uint8', 'int8'. Default: 'int8'.
            window_size(int): Window size for 'range_abs_max' quantization. Default: 10000.
            moving_rate(float): The decay coefficient of moving average. Default: 0.9.
            for_tensorrt(bool): If True, 'quantize_op_types' will be TENSORRT_OP_TYPES. Default: False.
            onnx_format(bool): Whether to export the quantized model in ONNX format. Default: True.
            is_full_quantize(bool): If True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES. Default: False.
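
        A minimal usage sketch (illustrative; the values below just make a few
        defaults explicit and are not tied to any particular model):

        .. code-block:: python

            quant_config = QuantAware(
                use_pact=True,
                weight_bits=8,
                activation_bits=8,
                weight_quantize_type='channel_wise_abs_max',
                onnx_format=True)
        ..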
        """
        super(QuantAware, self).__init__("QuantAware")
        self.quantize_op_types = quantize_op_types
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits
        self.not_quant_pattern = not_quant_pattern
        self.use_pact = use_pact
        self.activation_quantize_type = activation_quantize_type
        self.weight_quantize_type = weight_quantize_type
        self.dtype = dtype
        self.window_size = window_size
        self.moving_rate = moving_rate
        self.for_tensorrt = for_tensorrt
        self.onnx_format = onnx_format
        self.is_full_quantize = is_full_quantize


class Distillation(BaseStrategy):
    def __init__(self,
                 loss='l2',
                 node=[],
                 alpha=1.0,
                 teacher_model_dir=None,
                 teacher_model_filename=None,
                 teacher_params_filename=None):
        """
        Distillation Config.
        Args:
            loss(str|list(str)): Distillation loss; the supported loss types can be found in `<https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/dist/single_distiller_api.html>`_. If a list of losses is given, different nodes can use different distillation losses, and the length of loss must equal the length of node. Default: 'l2'.
            node(list(str)|list(list(str))): Distillation nodes; users can choose nodes from the model before compression. If a list of lists is given, every inner list uses the same distillation loss, and the length of the outer list must equal the length of loss. Default: [].
            alpha(float|list(float)): The weight (lambda) of the distillation loss. If a list of alpha is given, its length must equal the length of loss. Default: 1.0.
            teacher_model_dir(str, optional): The directory of the teacher inference model; the model and params saved by ``paddle.static.io.save_inference_model`` are under this path. If set to None, the teacher model will be the model before compression. Default: None.
            teacher_model_filename(str, optional): The filename of the teacher model. If parameters are saved in separate files, set it to None. Default: None.
            teacher_params_filename(str, optional): The filename of the teacher params. When all parameters are saved in a single file, set it to that filename; if parameters are saved in separate files, set it to None. Default: None.
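
        A minimal usage sketch (illustrative; the node name and teacher paths are
        placeholders that must match your own teacher model):

        .. code-block:: python

            distill_config = Distillation(
                loss='l2',
                node=['relu_30.tmp_0'],
                alpha=1.0,
                teacher_model_dir='./teacher_model',
                teacher_model_filename='model.pdmodel',
                teacher_params_filename='model.pdiparams')
        ..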
        """
        super(Distillation, self).__init__("Distillation")
        self.loss = loss
        self.node = node
        self.alpha = alpha
        self.teacher_model_dir = teacher_model_dir
        self.teacher_model_filename = teacher_model_filename
        self.teacher_params_filename = teacher_params_filename


class MultiTeacherDistillation:
    def __init__(self,
                 loss=[],
                 node=[],
                 alpha=[],
                 teacher_model_dir=[],
                 teacher_model_filename=[],
                 teacher_params_filename=[]):
        """
        Multi-Teacher Distillation Config.
        Args:
            loss(list(str)): The list of distillation losses; the supported loss types can be found in `<https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/dist/single_distiller_api.html>`_. There is a one-to-one correspondence between loss and teacher model. Default: [].
            node(list(list(str))): Distillation nodes; users can choose nodes from the model before compression. Every inner list uses the same distillation loss, and the length of the outer list must equal the length of loss. Default: [].
            alpha(list(float)): The list of weights (lambdas) of the distillation losses. There is a one-to-one correspondence between alpha and loss. Default: [].
            teacher_model_dir(list): The list of directories of teacher inference models; the model and params saved by ``paddle.static.io.save_inference_model`` are under each path. If not set, the teacher model will be the model before compression. Default: [].
            teacher_model_filename(list): The list of filenames of the teacher models. If parameters are saved in separate files, leave it empty. Default: [].
            teacher_params_filename(list): The list of filenames of the teacher params. When all parameters are saved in a single file, set each entry to that filename; if parameters are saved in separate files, leave it empty. Default: [].
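
        A minimal usage sketch (illustrative; teacher directories, filenames and
        node names are placeholders, one entry per teacher):

        .. code-block:: python

            mt_distill_config = MultiTeacherDistillation(
                loss=['l2', 'l2'],
                node=[['teacher1_out.tmp_0'], ['teacher2_out.tmp_0']],
                alpha=[1.0, 1.0],
                teacher_model_dir=['./teacher_1', './teacher_2'],
                teacher_model_filename=['model.pdmodel', 'model.pdmodel'],
                teacher_params_filename=['model.pdiparams', 'model.pdiparams'])
        ..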
        """
        self.loss = loss
        self.node = node
        self.alpha = alpha
        self.teacher_model_dir = teacher_model_dir
        self.teacher_model_filename = teacher_model_filename
        self.teacher_params_filename = teacher_params_filename


class HyperParameterOptimization(BaseStrategy):
    def __init__(self,
                 ptq_algo=["KL", "hist", "avg", "mse"],
                 bias_correct=[True, False],
                 weight_quantize_type=['channel_wise_abs_max'],
                 hist_percent=[0.98, 0.999],
                 batch_num=[10, 30],
                 max_quant_count=20):
        """
        HyperParameterOptimization Config.
        Args:
            ptq_algo(list(str)): Candidate Post-Training Quantization algorithms; the supported algorithms can be found in `<https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/quant/quantization_api.html#quant-post-static>`_.
            bias_correct(list(bool)): Candidate values for whether to use bias correction.
            weight_quantize_type(list(str)): Candidate quantization types for weight, which can be chosen from 'channel_wise_abs_max' or 'abs_max'.
            hist_percent(list(float)): The lower and upper bounds of the threshold percentile of the 'hist' algorithm for activations; the real percentile is uniformly sampled within these bounds.
            batch_num(list(int)): The lower and upper bounds of the batch number; the real batch number is uniformly sampled within these bounds.
            max_quant_count(int): The maximum number of model quantization trials. Default: 20.
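
        A minimal usage sketch (illustrative; the search space below just narrows
        the default candidates):

        .. code-block:: python

            hpo_config = HyperParameterOptimization(
                ptq_algo=['KL', 'hist'],
                bias_correct=[True, False],
                hist_percent=[0.98, 0.999],
                batch_num=[10, 30],
                max_quant_count=20)
        ..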
        """
        super(HyperParameterOptimization, self).__init__("HPO_PTQ")
        self.ptq_algo = ptq_algo
        self.bias_correct = bias_correct
        self.weight_quantize_type = weight_quantize_type
        self.hist_percent = hist_percent
        self.batch_num = batch_num
        self.max_quant_count = max_quant_count


class QuantPost(BaseStrategy):
    def __init__(self,
                 batch_size=32,
                 batch_nums=None,
                 epochs=20,
                 lr=0.1,
                 algo='hist',
                 hist_percent=0.999,
                 regions=None,
                 region_weights_names=None,
                 recon_level=None,
                 is_full_quantize=False,
                 bias_correction=False,
                 weight_quantize_type='channel_wise_abs_max',
                 activation_quantize_type='range_abs_max',
                 simulate_activation_quant=False,
                 skip_tensor_list=None,
                 onnx_format=True,
                 quantize_op_types=[
                     "conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"
                 ],
                 weight_bits=8,
                 activation_bits=8):
        """
        QuantPost Config.
        Args:
            batch_size(int, optional): The batch size of the DataLoader. Default: 32.
            batch_nums(int, optional): If batch_nums is not None, the number of calibration samples is 'batch_size*batch_nums'. If batch_nums is None, all data generated by sample_generator is used for calibration. Default: None.
            epochs(int, optional): The number of training epochs of the Reconstruction Quanter. Default: 20.
            lr(float, optional): The learning rate of the Reconstruction Quanter. Default: 0.1.
            algo(str, optional): Post-Training Quantization algorithm; the supported algorithms can be found in `<https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/quant/quantization_api.html#quant-post-static>`_. Default: 'hist'.
            hist_percent(float, optional): The percentile of histogram for algo hist. Default: 0.999.
            regions(list[list], optional): The list of regions; each region is a subgraph of the fp32 program and has exactly one input operation and one output operation. When recon_level is 'region-wise', the reconstruction loss of each region is minimized. Default: None.
            region_weights_names(list[list], optional): The weight names inside every region. Default: None.
            recon_level(str, optional): The type of reconstruction granularity. Currently support ['layer-wise', 'region-wise'] types. Only when recon_level isn't None can Reconstruction Quanter be used. Default: None. 
            is_full_quantize(bool): If True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES. Default: False.
            bias_correction(bool): Whether to use the bias correction method of https://arxiv.org/abs/1810.05723. Default: False.
            weight_quantize_type(str): Weight quantize type. Default: 'channel_wise_abs_max'.
            activation_quantize_type(str): Activation quantize type. Default: 'range_abs_max'.
            simulate_activation_quant(bool, optional): Whether to simulate the noise caused by activation quantization during the reconstruction process. Default: False.
            skip_tensor_list(list): The list of names of tensors to skip quantization. Default: None.
            onnx_format(bool): Whether to export the quantized model in ONNX format. Default: True.
            quantize_op_types(list(str)): Operators whose type is in quantize_op_types will be quantized. Default: ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'].
            weight_bits(int): Weight quantize bit num. Default: 8.
            activation_bits(int): Activation quantize bit num. Default: 8.
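
        A minimal usage sketch (illustrative; the calibration sizes are placeholders):

        .. code-block:: python

            ptq_config = QuantPost(
                batch_size=32,
                batch_nums=10,
                algo='avg',
                bias_correction=True,
                onnx_format=True)
        ..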
        """
        super(QuantPost, self).__init__("PTQ")
        self.batch_size = batch_size
        self.batch_nums = batch_nums
        self.epochs = epochs
        self.lr = lr
        self.algo = algo
        self.hist_percent = hist_percent
        self.regions = regions
        self.region_weights_names = region_weights_names
        self.recon_level = recon_level
        self.is_full_quantize = is_full_quantize
        self.bias_correction = bias_correction
        self.weight_quantize_type = weight_quantize_type
        self.activation_quantize_type = activation_quantize_type
        self.simulate_activation_quant = simulate_activation_quant
        self.skip_tensor_list = skip_tensor_list
        self.onnx_format = onnx_format
        self.quantize_op_types = quantize_op_types
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits


class ChannelPrune:
    def __init__(self,
                 pruned_ratio,
                 prune_params_name=None,
                 criterion='l1_norm'):
        """
        ChannelPrune Config.
        Args:
            pruned_ratio(float|list[float]): The ratios to be pruned.
            prune_params_name(list(str)): A list of parameter names to be pruned.
            criterion(str|function): The criterion used to sort channels for pruning, which can be chosen from ['l1_norm', 'bn_scale', 'geometry_median']. Default: 'l1_norm'.
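
        A minimal usage sketch (illustrative; the parameter name is a placeholder
        that must exist in your model):

        .. code-block:: python

            prune_config = ChannelPrune(
                pruned_ratio=0.25,
                prune_params_name=['conv2d_1.w_0'],
                criterion='l1_norm')
        ..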
        """
        self.pruned_ratio = pruned_ratio
        self.prune_params_name = prune_params_name
        self.criterion = criterion


class ASPPrune:
    def __init__(self, prune_params_name=None):
        """
        ASPPrune Config.
        Args:
            prune_params_name(list(str)): A list of parameter names to be pruned.
        """
        self.prune_params_name = prune_params_name


class TransformerPrune:
    def __init__(self, pruned_ratio):
        """
        TransformerPrune Config.
        Args:
            pruned_ratio(float): The ratio to be pruned for each fully-connected layer.
        """
        self.pruned_ratio = pruned_ratio


class UnstructurePrune:
    def __init__(self,
                 prune_strategy=None,
                 prune_mode='ratio',
                 threshold=0.01,
                 ratio=0.55,
                 gmp_config=None,
                 prune_params_type='conv1x1_only',
                 local_sparsity=False):
        """
        UnstructurePrune Config.
        Args:
            prune_strategy(str, optional): The pruning strategy; currently 'base' and 'gmp' are supported, and ``None`` means the base pruning strategy is used. Default: ``None``.
            prune_mode(str): The pruning mode: whether by ratio or by threshold. Default: 'ratio'.
            threshold(float): The threshold to set zeros, the abs(weights) lower than which will be zeros. Default: 0.01.
            ratio(float): The ratio to set zeros, the smaller portion will be zeros. Default: 0.55.
            gmp_config(dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below:
              .. code-block:: python
                     
                     {'stable_iterations': int} # the duration of stable phase in terms of global iterations
                     {'pruning_iterations': int} # the duration of pruning phase in terms of global iterations
                     {'tunning_iterations': int} # the duration of tunning phase in terms of global iterations
                     {'resume_iteration': int} # the global iteration you want to start training from
                     {'pruning_steps': int} # the total times you want to increase the ratio
                     {'initial_ratio': float} # the initial ratio value
              
              ..
            prune_params_type(str): Which kind of params should be pruned; we only support None (all params except norms) and 'conv1x1_only' for now. Default: 'conv1x1_only'.
            local_sparsity(bool): Whether to prune all the parameter matrix at the same ratio or not. Default: False.
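
        A minimal usage sketch (illustrative; the iteration counts in gmp_config are
        placeholders that depend on your total training schedule):

        .. code-block:: python

            prune_config = UnstructurePrune(
                prune_strategy='gmp',
                prune_mode='ratio',
                ratio=0.75,
                gmp_config={
                    'stable_iterations': 0,
                    'pruning_iterations': 4500,
                    'tunning_iterations': 4500,
                    'resume_iteration': 0,
                    'pruning_steps': 100,
                    'initial_ratio': 0.15,
                },
                prune_params_type='conv1x1_only',
                local_sparsity=True)
        ..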
        """
        self.prune_strategy = prune_strategy
        self.prune_mode = prune_mode
        self.threshold = threshold
        self.ratio = ratio
        self.gmp_config = gmp_config
        self.prune_params_type = prune_params_type
        self.local_sparsity = local_sparsity


class TrainConfig:
    def __init__(self,
                 epochs=None,
                 train_iter=None,
                 learning_rate=0.02,
                 optimizer_builder={'optimizer': {
                     'type': 'SGD'
                 }},
                 eval_iter=1000,
                 logging_iter=10,
                 origin_metric=None,
                 target_metric=None,
                 amp_config=None,
                 recompute_config=None,
                 sharding_config=None,
                 sparse_model=False):
        """
        Train Config.
        Args:
            epochs(int): The number of total epochs. Default: None.
            train_iter(int): The total number of training iterations; only one of `epochs` or `train_iter` needs to be set. Default: None.
            learning_rate(float|dict): The learning rate used in training. If a dict is given, its keys are described below:
              .. code-block:: python

                  'type'(str) # the class name of the learning rate decay, which can be found in paddle.optimizer.lr.
              ..
              The other keys of learning_rate depend on the parameters of the chosen learning rate decay class.
              For example, to use ``PiecewiseDecay``, set learning_rate like:
              {'type': 'PiecewiseDecay', 'boundaries': [4500], 'values': [0.005, 0.0005]}.
            optimizer_builder(str|dict): The optimizer used in training. If a dict is given, its keys are described below:
              .. code-block:: python

                  'optimizer'(dict) # the 'type' in optimizer needs to be a class name in paddle.optimizer,
                                      the other keys of optimizer depend on the parameters of that class.
                  'weight_decay'(float, optional) # weight decay used in training.
                  'regularizer'(dict) # the 'type' in regularizer needs to be a class name in paddle.regularizer,
                                         the other keys of regularizer depend on the parameters of that class.
                  'grad_clip'(dict) # the 'type' in grad_clip needs to be a class name in paddle.nn, such as 'ClipGradByGlobalNorm',
                                     the other keys of grad_clip depend on the parameters of that class.
              ..
              ..
            eval_iter(int): Test period in batches. Default: 1000.
            logging_iter(int): Log period in batches. Default: 10.
            origin_metric(float, optional): The metric of the model before compression; if not None, it is used to check whether the dataloader is correct. Default: None.
            target_metric(float, optional): The target metric of the model after compression; if set, training stops once the metric of the compressed model meets it. If not set, training runs for the epochs set by the user. Default: None.
            amp_config(dict, optional): The dictionary contains all the configs of AMP (automatic mixed precision). Default: None. The detailed description for non-distributed training is as below:
              .. code-block:: python
                 AMP-O1 `<https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/performance_improving/amp_cn.html#id2>`_ : 
                     {'custom_white_list': set} # The custom white_list. It's the set of ops that support
                         fp16 calculation and are considered numerically-safe and performance-critical. These ops 
                         will be converted to fp16.
                     {'custom_black_list': set} # The custom black_list. The set of ops that support fp16
                         calculation and are considered numerically-dangerous and whose effects may also be 
                         observed in downstream ops. These ops will not be converted to fp16.
                     {'custom_black_varnames': set} # Users' custom black variables' names.

                 AMP-O2 `<https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/performance_improving/amp_cn.html#id3>`_ : 
                     {'use_pure_fp16': bool} # Whether to use the pure fp16 training.
                     {'use_fp16_guard': bool} # Whether to use `fp16_guard` when constructing the program.
              ..
              If you want to use AMP-O2, set use_pure_fp16 to True and use_fp16_guard to False.
              When turning on distributed training, the keys of amp_config can reference `<https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#amp-configs>`_.

            recompute_config(dict, optional): The dictionary contains all the configs of recompute. Default: None. The recompute config can only be set when turning on distributed training, and the keys of recompute_config can reference `<https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#recompute-configs>`_.
            sharding_config(dict, optional): The dictionary contains all the configs of sharding. Default: None. The sharding config can only be set when turning on distributed training, and the keys of sharding_config can reference `<https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#sharding-configs>`_.
            sparse_model(bool, optional): Set sparse_model to ``True`` to remove mask tensor when the compress strategy is unstructure prune. Default: False.
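
        A minimal usage sketch (illustrative; the schedule boundaries and decay
        values are placeholders):

        .. code-block:: python

            train_config = TrainConfig(
                epochs=5,
                learning_rate={'type': 'PiecewiseDecay',
                               'boundaries': [4500],
                               'values': [0.005, 0.0005]},
                optimizer_builder={'optimizer': {'type': 'SGD'},
                                   'weight_decay': 4.0e-05},
                eval_iter=1000,
                logging_iter=10)
        ..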
        """
        self.epochs = epochs
        self.train_iter = train_iter
        self.learning_rate = learning_rate
        self.optimizer_builder = optimizer_builder
        self.eval_iter = eval_iter
        self.logging_iter = logging_iter
        self.origin_metric = origin_metric
        self.target_metric = target_metric
        self.amp_config = amp_config
        self.recompute_config = recompute_config
        self.sharding_config = sharding_config
        self.sparse_model = sparse_model


class MergeConfig:
    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


def merge_config(*args):
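    """
    Merge several config objects into one flat config whose attributes are the
    union of all the inputs' attributes; if two configs define the same
    attribute, the later one wins.

    A minimal usage sketch (illustrative; the configs below just use their defaults):

    .. code-block:: python

        merged = merge_config(QuantAware(), Distillation())
        print(merged.use_pact, merged.alpha)
    ..
    """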
    cfg = dict()
    for arg in args:
        cfg.update(arg.__dict__)
    return MergeConfig(**cfg)


class ProgramInfo:
    def __init__(self,
                 startup_program,
                 program,
                 feed_target_names,
                 fetch_targets,
                 optimizer=None,
                 learning_rate=None,
                 loss_dict=None):
        """
        ProgramInfo Config.
        Args:
            startup_program(paddle.static.Program): The startup program; its meaning is described in `<https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/static/default_startup_program_cn.html#default-startup-program>`_.
            program(paddle.static.Program): The main program; its meaning is described in `<https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/static/default_main_program_cn.html#default-main-program>`_.
            feed_target_names(list(str)): The name of feed tensor in the program.
            fetch_targets(list(Variable)): The fetch variable in the program.
            optimizer(Optimizer, optional): Optimizer in training. Default: None.
            learning_rate(float|paddle.optimizer.lr, optional): The learning rate used in training. Default: None.
            loss_dict(dict, optional): The components of the loss. Default: None.
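
        A minimal usage sketch (illustrative; assumes ``paddle`` is imported and a
        main program with its feed names and fetch targets has already been built):

        .. code-block:: python

            startup_program = paddle.static.Program()
            program_info = ProgramInfo(
                startup_program=startup_program,
                program=main_program,
                feed_target_names=['image'],
                fetch_targets=fetch_targets)
        ..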
        """
        self.startup_program = startup_program
        self.program = program
        self.feed_target_names = feed_target_names
        self.fetch_targets = fetch_targets
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.loss_dict = loss_dict