# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import (
    is_optimizer_op,
    OP_ROLE_KEY,
    OpRole,
)

from paddle.framework import core

__all__ = []


class FP16Utils:
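    """Utilities for handling fp16 (AMP) cast ops under sharding.

    The static helpers below identify and prune cast ops whose parameters
    belong to other shards, and keep the AMP inf/nan check consistent
    across ranks by all-reducing the FoundInfinite flag over the given
    communication rings.
    """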
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
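        """Return True if ``op`` is a forward (non-optimizer) cast of an fp32 param in ``params`` to fp16."""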
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert len(op.desc.input_arg_names()) == 1
        assert len(op.desc.output_arg_names()) == 1
        input_name, output_name = (
            op.desc.input_arg_names()[0],
            op.desc.output_arg_names()[0],
        )
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if (
            input_var.dtype != core.VarDesc.VarType.FP32
            or output_var.dtype != core.VarDesc.VarType.FP16
        ):
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
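        """Return True if ``op`` is an optimizer-side cast from fp16 back to fp32."""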
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert len(op.desc.input_arg_names()) == 1
        assert len(op.desc.output_arg_names()) == 1
        input_name, output_name = (
            op.desc.input_arg_names()[0],
            op.desc.output_arg_names()[0],
        )
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if (
            input_var.dtype != core.VarDesc.VarType.FP16
            or output_var.dtype != core.VarDesc.VarType.FP32
        ):
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
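        """Remove fp16 param-cast ops inside the given segment.

        Scans ops in [offset + segment._start_idx, offset + segment._end_idx)
        and returns the net change in op count (non-positive).
        """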
        inserted_op_num = 0
        for op_idx in reversed(
            range(offset + segment._start_idx, offset + segment._end_idx)
        ):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
        """
        1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
        2. revise amp infinite grad checking for sharding
        """
        # remove fp16->fp32 grad cast ops whose params are not owned by this shard
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            # TODO (JZ-LIANG) revise this for uniform mixed parallelism
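            # recover the param name by dropping the @GRAD / @GRAD@MERGED suffix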
            param_name = (
                output_name.strip("@GRAD@MERGED")
                if "@MERGED" in output_name
                else output_name.strip("@GRAD")
            )
            if param_name not in shard.global_params:
                raise ValueError(
                    "Output 'X' of cast_op must be a grad of "
                    "model param, but {} is not a grad".format(output_name)
                )
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
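        # restrict the inf/nan check to grads owned by this shard and record
        # where to insert the cross-rank FoundInfinite synchronization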
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                reversed_x_paramname = []
                for input_name in op.desc.input('X'):
                    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
                    if "@MERGED" in input_name:
                        param_name = input_name.strip("@GRAD@MERGED")
                    else:
                        param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name)
                        )
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                        reversed_x_paramname.append(param_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)

                # the grad checking should cover all and only the params in the current shard
                to_check_param = set(reversed_x_paramname)
                should_check_param = set(shard.global_params).intersection(
                    set(
                        [
                            param
                            for param, worker_idx in shard.global_param2device.items()
                            if worker_idx == shard.worker_idx
                        ]
                    )
                )
                assert to_check_param == should_check_param, (
                    "amp check_finite_and_unscale checking missed [{}] "
                    "and got unexpected [{}]".format(
                        should_check_param - to_check_param,
                        to_check_param - should_check_param,
                    )
                )

        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
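        # NOTE: FoundInfinite is cast to int32 for the max-allreduce across
        # rings (presumably because its original dtype is not supported by
        # the collective), then cast back below.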
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32,
        )

        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_int32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize,
            },
        )
        update_loss_scaling_op_idx += 1

        # allreduce(mp)->allreduce(sharding)->allreduce(pp)
        for ring_id in ring_ids:
            if ring_id == -1:
                continue
            # this allreduce communication should not overlap with calc
            block._insert_op_without_sync(
                update_loss_scaling_op_idx,
                type='c_allreduce_max',
                inputs={'X': inf_var_int32},
                outputs={'Out': inf_var_int32},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Optimize,
                },
            )
            update_loss_scaling_op_idx += 1

        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var},
            attrs={
                "in_dtype": inf_var_int32.dtype,
                "out_dtype": inf_var.dtype,
                OP_ROLE_KEY: OpRole.Optimize,
            },
        )
        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()

    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    @staticmethod
    def sync_amp_check_nan_inf(block, ring_ids):
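        """Insert cast -> c_allreduce_max -> cast before update_loss_scaling.

        Ensures every rank in ``ring_ids`` observes the same FoundInfinite
        flag, so loss scaling is updated consistently across ranks.
        """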
        update_loss_scaling_op_idx = -1

        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                break

        # amp is not used in this program
        if update_loss_scaling_op_idx == -1:
            return
        # 0. inf_var_int32 = cast(inf_var)
        # 1. inf_var_int32 = allreduce_max(inf_var_int32)
        # 2. inf_var = cast(inf_var_int32)
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32,
        )
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_int32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize,
            },
        )
        update_loss_scaling_op_idx += 1

        # allreduce(mp)->allreduce(pp)
        for ring_id in ring_ids:
            if ring_id == -1:
                continue
            block._insert_op_without_sync(
                update_loss_scaling_op_idx,
                type='c_allreduce_max',
                inputs={'X': inf_var_int32},
                outputs={'Out': inf_var_int32},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Optimize,
                },
            )
            update_loss_scaling_op_idx += 1

        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var},
            attrs={
                "in_dtype": inf_var_int32.dtype,
                "out_dtype": inf_var.dtype,
                OP_ROLE_KEY: OpRole.Optimize,
            },
        )
        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()