# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
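
# This module lowers Paddle static-graph programs onto Huawei Ascend NPUs via
# the GE (Graph Engine) bindings exposed through paddle.framework.core.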

from collections import namedtuple

import hccl.manage.api as hccl

import paddle.framework.core as core
from paddle.distributed import fleet
from paddle.optimizer import Optimizer

from . import ascend_parser

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
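# e.g. HcomGroupConfig(name="hcom_group_0", nranks=2, rank_ids=[0, 1])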

__all__ = []


class AscendIRParser:
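    """Parse Paddle static Programs into GE (Ascend Graph Engine) graphs.

    Communication-setup ops (c_gen_nccl_id / c_comm_init) are not lowered to
    GE; they are recorded as HcomGroupConfig entries so the required HCCL
    groups can be created before the graphs run.
    """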
    def __init__(self, auto_dp=False, world_rank_size=1):
        self.graph_idx = 0
        self.hcom_endpoints = {}
        self.groups_to_create = []
        self._auto_dp = auto_dp
        self._world_rank_size = world_rank_size

    def _construct_input_map(self, input_varlist):
        """Create a GE input operator for every input variable.

        Feed variables become "Data" nodes indexed by position; all other
        variables (parameters, learning rate, ...) become "Variable" nodes
        with an ND float tensor descriptor.
        """
        ret_map = {}
        ge_in_operator = []
        for index, var in enumerate(input_varlist):
            if var.is_data:  # feed data
                ge_input = core.GEOperatorFactory.create_operator(
                    var.name, "Data"
                ).set_attr_int32("index", index)
                ret_map[var.name] = ge_input
                ge_in_operator.append(ge_input)
            else:  # parameter, learning rate, ...
                ge_input = core.GEOperatorFactory.create_operator(
                    var.name, "Variable"
                )
                ge_input.update_output_desc(
                    "y",
                    core.GETensorDesc(
                        core.GEShape(var.shape),
                        core.GEFormat.FORMAT_ND,
                        core.GEDataType.DT_FLOAT,
                    ),
                )
                ret_map[var.name] = ge_input
        return ge_in_operator, ret_map

    def _endpoint_to_world_rank_id(self, endpoint):
        # An endpoint's position in fleet.worker_endpoints() is its world rank.
        world_endpoints = fleet.worker_endpoints()
        assert (
            endpoint in world_endpoints
        ), "endpoint (%s) not in worker_endpoints (%s)" % (
            endpoint,
            world_endpoints,
        )
        return world_endpoints.index(endpoint)

    def parse_op(self, op):
        """Lower a single Paddle op.

        Communicator-setup ops are recorded as HCCL group configs; every
        other op is dispatched to its registered ascend_parser rule.
        """
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # The c_gen_nccl_id operator splits the endpoints into the local
            # endpoint and other_endpoints; recombine them here so world rank
            # ids can be derived later.
            self.hcom_endpoints[nccl_id] = other_endpoints[:]
            self.hcom_endpoints[nccl_id].insert(rank, endpoint)

            print(
                "nccl_id (%s) registered endpoints %s"
                % (nccl_id, self.hcom_endpoints[nccl_id])
            )
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(
                self.hcom_endpoints[nccl_id]
            ), "nranks doesn't match endpoint count"
            rank = op.attr("rank")
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [
                self._endpoint_to_world_rank_id(endpoint)
                for endpoint in self.hcom_endpoints[nccl_id]
            ]
            self.groups_to_create.append(
                HcomGroupConfig(
                    name=group_name, nranks=nranks, rank_ids=global_rank_ids
                )
            )
            print(
                "append to create group: %s, with rank_ids: %s"
                % (group_name, global_rank_ids)
            )
        elif op.type in ascend_parser.registerd_op:
            op_parser = self.parser_factory.create_parse(
                ascend_parser.registerd_op[op.type]
            )
            op_parser.apply(op)
        else:
            raise AssertionError(
                "Op[%s] has not been registered and cannot be lowered to GE"
                % op.type
            )

    def _parse_program(
        self, graph_name, program, input_varlist=None, fetch_list=None
    ):
        """Lower one Program's global block into a GEGraph.

        The block's ops are translated one by one and then replaced with a
        single ascend_trigger op that launches the compiled graph at runtime.
        """
        if input_varlist is None:
            input_varlist = []
        if fetch_list is None:
            fetch_list = []

        ge_out_operator = []
        self.var2geop = {}

        block = program.global_block()
        if len(block.ops) == 0:
            print("There are no ops in program %s" % (graph_name))
            return []

        graph = core.GEGraph(graph_name)

        ge_in_operator, self.var2geop = self._construct_input_map(input_varlist)

        self.parser_factory = ascend_parser.AscendParserFactory(
            graph, self.var2geop
        )
        for curop in list(block.ops):
            self.parse_op(curop)

        # Set the fetch vars as GE graph outputs
        for e in fetch_list:
            name = e if isinstance(e, str) else e.name
            ge_out_operator.append(self.var2geop[name])

        # (Debug) To print back-prop vars, append the varname to
        # ge_out_operator here, e.g.:
        # if graph_name == "main":
        #     ge_out_operator.append(self.var2geop["reduce_sum_0.tmp_0@GRAD"])

        # Add ops that may be inputs of the graph, such as const.
        for varname, geop in self.var2geop.items():
            if varname.startswith("geinput"):
                ge_in_operator.append(geop)

        graph.set_inputs(ge_in_operator).set_outputs(ge_out_operator)

        # Remove the ops of the original program
        op_num = len(block.ops)
        for i in range(op_num - 1, -1, -1):
            block._remove_op(i)

        input_varlist = [var for var in input_varlist if var.is_data]

        block.append_op(
            type="ascend_trigger",
            inputs={"FeedList": input_varlist},
            outputs={"FetchList": fetch_list},
            attrs={'graph_idx': self.graph_idx},
        )
        self.graph_idx += 1
        return graph

    def parse_program(
        self, startup_program, main_program, input_varlist, fetch_list
    ):
        """Parse the startup and main programs into two GE graphs."""
        startup_graph = self._parse_program("startup", startup_program)
        main_graph = self._parse_program(
            "main", main_program, input_varlist, fetch_list
        )
        if self._auto_dp and self._world_rank_size > 1:
            assert (
                len(self.groups_to_create) == 0
            ), "auto_dp expects a program without pre-built communication groups"
            # Under auto_dp, create one world-wide group covering every rank.
            self.groups_to_create.append(
                HcomGroupConfig(
                    name="hcom_group_0",
                    nranks=fleet.world_size(),
                    rank_ids=list(range(fleet.world_size())),
                )
            )

        return startup_graph, main_graph


# AscendOptimizer is currently a wrapper around a basic optimizer.
# We will make it part of fleet meta_optimizer in the future.
class AscendOptimizer(Optimizer):
    def __init__(self, optimizer, fetch_list=None):
        self.inner_opt = optimizer
        self.fetch_list = fetch_list if fetch_list is not None else []
        self.ascend_instance = None

    def __del__(self):
        print("begin AscendOptimizer del")
        if self.ascend_instance is not None:
            self.ascend_instance.destroy_global_resources()
        core.ge_finalize()
        print("end AscendOptimizer del")

    def _can_apply(self):
        if not self.user_defined_strategy.ascend:
            return False
        # TODO(hutuxian): other check here
        return True

    def _disable_strategy(self, dist_strategy):
        dist_strategy.ascend = False
        dist_strategy.ascend_configs = {}

    # Feed variables and persistables (parameters, optimizer state, ...)
    # become the inputs of the GE graph.
    def _get_input_varlist(self, program):
        ret_list = []
        for var in program.list_vars():
            if var.is_data or var.persistable:
                ret_list.append(var)
        return ret_list

    def minimize(
        self,
        loss,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
        auto_dp=False,
        rank_table_file=None,
        precision_mode="must_keep_origin_dtype",
    ):
        """Apply the inner optimizer, then lower the resulting programs to GE
        graphs and register them for execution on Ascend devices."""
        minimized = None
        if self.inner_opt:
            minimized = self.inner_opt.minimize(
                loss, startup_program=startup_program
            )

        self.ascend_instance = core.AscendInstance()

        if auto_dp and fleet.world_size() > 1:
            from paddle.fluid.transpiler import ascend_transpiler

            t = ascend_transpiler.AscendTranspiler(
                startup_program, loss.block.program
            )
            t.transpile()
            # print(loss.block.program)
        # Graph Engine (GE) configuration options are documented at
        # https://support.huaweicloud.com/
        config = {
            "ge.exec.deviceId": str(fleet.local_device_ids()),
            "ge.graphRunMode": "1",
            "ge.exec.precision_mode": precision_mode,
        }
        # Extra options when running with multiple trainers
        if rank_table_file and fleet.world_size() > 1:
            config["ge.exec.rankTableFile"] = rank_table_file
            config["ge.exec.rankId"] = str(fleet.worker_index())
            config["ge.exec.isUseHcom"] = "1"
            config["ge.exec.deployMode"] = "0"
        print("ge_initialize config:", config)
        core.ge_initialize(config)

        # Init Session
        self.ascend_instance.init_global_resources()

        main_block = loss.block
        self.parser = AscendIRParser(
            auto_dp=auto_dp, world_rank_size=fleet.world_size()
        )

        input_varlist = self._get_input_varlist(main_block.program)

        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list
        )

        # Create the HCCL groups collected during parsing, then register both
        # graphs with the Ascend instance.
        for cfg in self.parser.groups_to_create:
            print(
                "create group (%s), nranks: %d, rank_ids: %s"
                % (cfg.name, cfg.nranks, cfg.rank_ids)
            )
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)

        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)

        return minimized
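

# Illustrative usage sketch (`build_model` is a hypothetical helper; assumes
# an Ascend-enabled Paddle build with GE and HCCL available):
#
#     import paddle
#
#     paddle.enable_static()
#     loss = build_model()  # any static-graph model producing a loss var
#     opt = AscendOptimizer(
#         paddle.optimizer.SGD(learning_rate=0.01), fetch_list=[loss]
#     )
#     opt.minimize(
#         loss, startup_program=paddle.static.default_startup_program()
#     )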