# fp16_helper.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    """Static helpers that adapt AMP (fp16) cast and loss-scaling ops to sharding.

    All methods are static and operate on a program ``block`` (its ``ops`` and
    vars); the class holds no state of its own.
    """

    def __init__(self):
        pass

    @staticmethod
    def _grad_to_param_name(grad_name):
        """Return ``grad_name`` without a trailing ``@GRAD@MERGED`` / ``@GRAD``.

        ``str.strip("@GRAD")`` must not be used for this: its argument is a
        *set of characters* removed from both ends, so a parameter such as
        ``Gate_A`` would be mangled to ``ate_``. Only an exact suffix is
        removed here; a name without either suffix is returned unchanged.
        """
        # Longest suffix first so "@GRAD@MERGED" is removed as a whole.
        for suffix in ("@GRAD@MERGED", "@GRAD"):
            if grad_name.endswith(suffix):
                return grad_name[:-len(suffix)]
        return grad_name

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        """Tell whether ``op`` casts an fp32 entry of ``params`` to fp16.

        Returns True only if ``op`` is a ``cast`` op that is not an optimizer
        op, its single input name is in ``params``, and the input/output vars
        looked up in ``block`` have dtype FP32 and FP16 respectively.
        """
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        # A cast op is expected to have exactly one input and one output.
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name = op.desc.input_arg_names()[0]
        output_name = op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        return (input_var.dtype == core.VarDesc.VarType.FP32 and
                output_var.dtype == core.VarDesc.VarType.FP16)

    @staticmethod
    def is_fp32_cast_op(block, op):
        """Tell whether ``op`` is an optimizer-role cast from fp16 to fp32.

        Returns True only if ``op`` is a ``cast`` op that is an optimizer op
        and the input/output vars looked up in ``block`` have dtype FP16 and
        FP32 respectively.
        """
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        # A cast op is expected to have exactly one input and one output.
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name = op.desc.input_arg_names()[0]
        output_name = op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        return (input_var.dtype == core.VarDesc.VarType.FP16 and
                output_var.dtype == core.VarDesc.VarType.FP32)

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        """Remove the fp32->fp16 param cast ops that fall inside ``segment``.

        Scans ``block.ops`` over the index range
        ``[offset + segment._start_idx, offset + segment._end_idx)`` and
        removes every op accepted by ``is_fp16_cast_op``.

        Returns:
            The net change in op count: ``-k`` when ``k`` ops were removed
            (0 when nothing was removed).
        """
        inserted_op_num = 0
        # Walk backwards so removing an op does not shift the indices that
        # are still to be visited.
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
        """
        1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
        2. revise amp infinite grad checking for sharding

        Args:
            block: program block whose ops/vars are modified in place.
            shard: object exposing ``global_params`` and ``has_param(name)``.
            reduced_grads_to_param: container of grad names; a cast whose
                output is in it is always kept.
            ring_id: value for the ``ring_id`` attr of the inserted
                ``c_allreduce_max`` op.

        Raises:
            ValueError: if a cast output or a grad-check input does not map
                to a name in ``shard.global_params``.
        """
        # Step 1: remove fp16->fp32 grad casts (and their output vars) whose
        # param is not owned by this shard. Iterate a reversed snapshot so
        # removals do not disturb the remaining indices.
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = FP16Utils._grad_to_param_name(output_name)
            if param_name not in shard.global_params:
                raise ValueError("Output 'X' of cast_op must be a grad of "
                                 "model param, but {} is not a grad".format(
                                     output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()

        # Step 2: restrict the grad lists of the finite-check / loss-scaling
        # ops to the grads of params owned by this shard.
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                # update_loss_scaling will read the "@sharding" flag that is
                # produced by the ops inserted in step 3.
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                kept_grads = []
                for input_name in op.desc.input('X'):
                    param_name = FP16Utils._grad_to_param_name(input_name)
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        kept_grads.append(input_name)
                op.desc.set_input('X', kept_grads)
                op.desc.set_output('Out', kept_grads)
                # NOTE: the original code also built the set of params mapped
                # to this worker in order to assert that the kept grads match
                # it exactly, but that assert was commented out; the unused
                # computation has been dropped.

        if update_loss_scaling_op_idx == -1:
            return

        # Step 3: right before update_loss_scaling insert
        #   cast(found_inf -> int32) -> c_allreduce_max -> cast(-> "@sharding")
        # so that the flag consumed by update_loss_scaling is the max over
        # the communication ring rather than the local value only.
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_int32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        # This allreduce communication should not overlap with calc; it is
        # issued with use_calc_stream=True and no separate sync ops are
        # inserted around it.
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 1,
            type='c_allreduce_max',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_int32},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_int32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()