diff --git a/tools/static/program.py b/tools/static/program.py index db9949c89fac78e589520e84804a090d7ef24c75..cea583ea87075a2e9b1ecc5e5829f6a86de79848 100644 --- a/tools/static/program.py +++ b/tools/static/program.py @@ -288,6 +288,60 @@ def create_optimizer(config): opt = OptimizerBuilder(config, **opt_config) return opt(lr), lr +def create_strategy(config): + """ + Create build strategy and exec strategy. + + Args: + config(dict): config + + Returns: + build_strategy: build strategy + exec_strategy: exec strategy + """ + build_strategy = paddle.static.BuildStrategy() + exec_strategy = paddle.static.ExecutionStrategy() + + exec_strategy.num_threads = 1 + exec_strategy.num_iteration_per_drop_scope = 10000 if config.get( + 'use_pure_fp16', False) else 10 + + fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', + False) + fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op) + fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op) + fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op) + enable_addto = config.get('enable_addto', fuse_op) + + try: + build_strategy.fuse_bn_act_ops = fuse_bn_act_ops + except Exception as e: + logger.info( + "PaddlePaddle version 1.7.0 or higher is " + "required when you want to fuse batch_norm and activation_op.") + + try: + build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops + except Exception as e: + logger.info( + "PaddlePaddle version 1.7.0 or higher is " + "required when you want to fuse elewise_add_act and activation_op.") + + try: + build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops + except Exception as e: + logger.info( + "PaddlePaddle 2.0-rc or higher is " + "required when you want to enable fuse_bn_add_act_ops strategy.") + + try: + build_strategy.enable_addto = enable_addto + except Exception as e: + logger.info("PaddlePaddle 2.0-rc or higher is " + "required when you want to enable addto strategy.") + return build_strategy, exec_strategy + + def dist_optimizer(config, optimizer): """ @@ -300,14 +354,15 @@ def dist_optimizer(config, optimizer): Returns: optimizer: a distributed optimizer """ - exec_strategy = paddle.static.ExecutionStrategy() - exec_strategy.num_threads = 3 - exec_strategy.num_iteration_per_drop_scope = 10 + build_strategy, exec_strategy = create_strategy(config) dist_strategy = DistributedStrategy() + dist_strategy.execution_strategy = exec_strategy + dist_strategy.build_strategy = build_strategy + dist_strategy.nccl_comm_num = 1 dist_strategy.fuse_all_reduce_ops = True - dist_strategy.execution_strategy = exec_strategy + dist_strategy.fuse_grad_size_in_MB = 16 optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy) return optimizer @@ -399,46 +454,7 @@ def compile(config, program, loss_name=None, share_prog=None): Returns: compiled_program(): a compiled program """ - build_strategy = paddle.static.BuildStrategy() - exec_strategy = paddle.static.ExecutionStrategy() - - exec_strategy.num_threads = 1 - exec_strategy.num_iteration_per_drop_scope = 10000 if config.get( - 'use_pure_fp16', False) else 10 - - fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', - False) - fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op) - fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op) - fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op) - enable_addto = config.get('enable_addto', fuse_op) - - try: - build_strategy.fuse_bn_act_ops = fuse_bn_act_ops - except Exception as e: - logger.info( - "PaddlePaddle version 1.7.0 or higher is " - "required when you want to fuse batch_norm and activation_op.") - - try: - build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops - except Exception as e: - logger.info( - "PaddlePaddle version 1.7.0 or higher is " - "required when you want to fuse elewise_add_act and activation_op.") - - try: - build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops - except Exception as e: - logger.info( - "PaddlePaddle 2.0-rc or higher is " - "required when you want to enable fuse_bn_add_act_ops strategy.") - - try: - build_strategy.enable_addto = enable_addto - except Exception as e: - logger.info("PaddlePaddle 2.0-rc or higher is " - "required when you want to enable addto strategy.") + build_strategy, exec_strategy = create_strategy(config) compiled_program = paddle.static.CompiledProgram( program).with_data_parallel(