diff --git a/fluid/PaddleCV/image_classification/train.py b/fluid/PaddleCV/image_classification/train.py index d92ce5532bb015e626dd0524f7b3e9d4c10dade6..58eb5cbc800512c23a15b3cb0c38aa24dcfd515e 100644 --- a/fluid/PaddleCV/image_classification/train.py +++ b/fluid/PaddleCV/image_classification/train.py @@ -17,6 +17,7 @@ import functools import subprocess import utils from utils.learning_rate import cosine_decay +from utils.fp16_utils import create_master_params_grads, master_param_to_train_param from utility import add_arguments, print_arguments import models import models_name @@ -160,62 +161,6 @@ def net_config(image, label, model, args): return avg_cost, acc_top1, acc_top5 -def cast_fp16_to_fp32(i, o, prog): - prog.global_block().append_op( - type="cast", - inputs={"X": i}, - outputs={"Out": o}, - attrs={ - "in_dtype": fluid.core.VarDesc.VarType.FP16, - "out_dtype": fluid.core.VarDesc.VarType.FP32 - } - ) - -def cast_fp32_to_fp16(i, o, prog): - prog.global_block().append_op( - type="cast", - inputs={"X": i}, - outputs={"Out": o}, - attrs={ - "in_dtype": fluid.core.VarDesc.VarType.FP32, - "out_dtype": fluid.core.VarDesc.VarType.FP16 - } - ) - -def copy_to_master_param(p, block): - v = block.vars.get(p.name, None) - if v is None: - raise ValueError("no param name %s found!" % p.name) - new_p = fluid.framework.Parameter( - block=block, - shape=v.shape, - dtype=fluid.core.VarDesc.VarType.FP32, - type=v.type, - lod_level=v.lod_level, - stop_gradient=p.stop_gradient, - trainable=p.trainable, - optimize_attr=p.optimize_attr, - regularizer=p.regularizer, - gradient_clip_attr=p.gradient_clip_attr, - error_clip=p.error_clip, - name=v.name + ".master") - return new_p - -def update_op_role_var(params_grads, master_params_grads, main_prog): - orig_grad_name_set = set() - for _, g in params_grads: - orig_grad_name_set.add(g.name) - master_g2p_dict = dict() - for idx, master in enumerate(master_params_grads): - orig = params_grads[idx] - master_g2p_dict[orig[1].name] = [master[0].name, master[1].name] - for op in main_prog.global_block().ops: - for oname in op.output_arg_names: - if oname in orig_grad_name_set: - # rename - print("setting to ", master_g2p_dict[oname]) - op._set_attr("op_role_var", master_g2p_dict[oname]) - def build_program(is_train, main_prog, startup_prog, args): image_shape = [int(m) for m in args.image_shape.split(",")] model_name = args.model @@ -249,38 +194,11 @@ def build_program(is_train, main_prog, startup_prog, args): optimizer = optimizer_setting(params) if args.fp16: - master_params_grads = [] params_grads = optimizer.backward(avg_cost) - tmp_role = main_prog._current_role - OpRole = fluid.core.op_proto_and_checker_maker.OpRole - main_prog._current_role = OpRole.Backward - for p, g in params_grads: - master_param = copy_to_master_param(p, main_prog.global_block()) - startup_master_param = startup_prog.global_block()._clone_variable(master_param) - startup_p = startup_prog.global_block().var(p.name) - cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog) - - if g.name.startswith("batch_norm"): - if args.scale_loss > 1: - scaled_g = g / float(args.scale_loss) - else: - scaled_g = g - master_params_grads.append([p, scaled_g]) - continue - master_grad = fluid.layers.cast(g, "float32") - if args.scale_loss > 1: - master_grad = master_grad / float(args.scale_loss) - master_params_grads.append([master_param, master_grad]) - main_prog._current_role = tmp_role - + master_params_grads = create_master_params_grads( + params_grads, main_prog, startup_prog, args.scale_loss) optimizer.apply_gradients(master_params_grads) - - for idx, m_p_g in enumerate(master_params_grads): - train_p, train_g = params_grads[idx] - if train_p.name.startswith("batch_norm"): - continue - with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]): - cast_fp32_to_fp16(m_p_g[0], train_p, main_prog) + master_param_to_train_param(master_params_grads, params_grads, main_prog) else: optimizer.minimize(avg_cost) diff --git a/fluid/PaddleCV/image_classification/utils/__init__.py b/fluid/PaddleCV/image_classification/utils/__init__.py index f59e4baf93aa095f393441d2cd766ff8d3b28801..4751caceeb14f0dddc937d90b4c953a870ffc3f8 100644 --- a/fluid/PaddleCV/image_classification/utils/__init__.py +++ b/fluid/PaddleCV/image_classification/utils/__init__.py @@ -1 +1,2 @@ from .learning_rate import cosine_decay, lr_warmup +from .fp16_utils import create_master_params_grads, master_param_to_train_param diff --git a/fluid/PaddleCV/image_classification/utils/fp16_utils.py b/fluid/PaddleCV/image_classification/utils/fp16_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6e081e7f9b433e4097f190b012c533064f5cca --- /dev/null +++ b/fluid/PaddleCV/image_classification/utils/fp16_utils.py @@ -0,0 +1,78 @@ +from __future__ import print_function +import paddle +import paddle.fluid as fluid + +def cast_fp16_to_fp32(i, o, prog): + prog.global_block().append_op( + type="cast", + inputs={"X": i}, + outputs={"Out": o}, + attrs={ + "in_dtype": fluid.core.VarDesc.VarType.FP16, + "out_dtype": fluid.core.VarDesc.VarType.FP32 + } + ) + +def cast_fp32_to_fp16(i, o, prog): + prog.global_block().append_op( + type="cast", + inputs={"X": i}, + outputs={"Out": o}, + attrs={ + "in_dtype": fluid.core.VarDesc.VarType.FP32, + "out_dtype": fluid.core.VarDesc.VarType.FP16 + } + ) + +def copy_to_master_param(p, block): + v = block.vars.get(p.name, None) + if v is None: + raise ValueError("no param name %s found!" % p.name) + new_p = fluid.framework.Parameter( + block=block, + shape=v.shape, + dtype=fluid.core.VarDesc.VarType.FP32, + type=v.type, + lod_level=v.lod_level, + stop_gradient=p.stop_gradient, + trainable=p.trainable, + optimize_attr=p.optimize_attr, + regularizer=p.regularizer, + gradient_clip_attr=p.gradient_clip_attr, + error_clip=p.error_clip, + name=v.name + ".master") + return new_p + +def create_master_params_grads(params_grads, main_prog, startup_prog, scale_loss): + master_params_grads = [] + tmp_role = main_prog._current_role + OpRole = fluid.core.op_proto_and_checker_maker.OpRole + main_prog._current_role = OpRole.Backward + for p, g in params_grads: + # create master parameters + master_param = copy_to_master_param(p, main_prog.global_block()) + startup_master_param = startup_prog.global_block()._clone_variable(master_param) + startup_p = startup_prog.global_block().var(p.name) + cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog) + # cast fp16 gradients to fp32 before apply gradients + if g.name.startswith("batch_norm"): + if scale_loss > 1: + scaled_g = g / float(scale_loss) + else: + scaled_g = g + master_params_grads.append([p, scaled_g]) + continue + master_grad = fluid.layers.cast(g, "float32") + if scale_loss > 1: + master_grad = master_grad / float(scale_loss) + master_params_grads.append([master_param, master_grad]) + main_prog._current_role = tmp_role + return master_params_grads + +def master_param_to_train_param(master_params_grads, params_grads, main_prog): + for idx, m_p_g in enumerate(master_params_grads): + train_p, _ = params_grads[idx] + if train_p.name.startswith("batch_norm"): + continue + with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]): + cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)