ascend_optimizer.py 10.1 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
H
hutuxian 已提交
16 17 18 19
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
G
gongweibao 已提交
20
from . import ascend_parser
21 22 23
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple
H
hutuxian 已提交
24

25
# One HCCL communication group to create: the group name, its number of
# ranks and the world rank ids of its members.
HcomGroupConfig = namedtuple('HcomGroupConfig', 'name nranks rank_ids')
H
hutuxian 已提交
26

G
gongweibao 已提交
27

H
hutuxian 已提交
28
class AscendIRParser(object):
    """Translate Paddle programs into Ascend Graph Engine (GE) graphs.

    Regular ops are converted through the parsers registered in
    ``ascend_parser``. Communication ops (``c_gen_nccl_id`` / ``c_comm_init``)
    are not converted; instead the HCCL groups they describe are collected in
    ``self.groups_to_create`` so the caller can create them afterwards.
    """

    def __init__(self, auto_dp=False, world_rank_size=1):
        """
        Args:
            auto_dp: when true (and ``world_rank_size > 1``),
                ``parse_program`` queues one world-wide ``hcom_group_0``.
            world_rank_size: total number of ranks.
        """
        # Value written into the ``graph_idx`` attr of the next
        # ``ascend_trigger`` op; advanced once per graph produced.
        self.graph_idx = 0
        # nccl_id variable name -> list of endpoints (local endpoint
        # inserted at index ``rank``).
        self.hcom_endpoints = {}
        # HcomGroupConfig entries collected while parsing.
        self.groups_to_create = []

        self._auto_dp = auto_dp
        self._world_rank_size = world_rank_size

    def _construct_input_map(self, input_varlist):
        """Create a GE input operator for every variable in ``input_varlist``.

        Data variables (``var.is_data``) become GE ``Data`` operators whose
        ``index`` attr is their position in ``input_varlist``. Every other
        variable (parameters, learning rate, ...) becomes a GE ``Variable``
        operator whose output ``y`` is declared as an ND float tensor of
        shape ``var.shape``.

        Returns:
            (ge_in_operator, ret_map): the ``Data`` operators in input order,
            and a dict mapping every variable name to its GE operator.
        """
        ret_map = {}
        ge_in_operator = []
        for idx, var in enumerate(input_varlist):
            if var.is_data:  # input data
                # NOTE(review): ``idx`` counts non-data vars too, so it may
                # differ from the var's position among data inputs only —
                # confirm GE does not rely on that position.
                ge_input = core.GEOperatorFactory.create_operator(
                    var.name, "Data").set_attr_int32("index", idx)
                ge_in_operator.append(ge_input)
            else:  # param, learning rate, ...
                ge_input = core.GEOperatorFactory.create_operator(var.name,
                                                                  "Variable")
                ge_input.update_output_desc("y",
                                            core.GETensorDesc(
                                                core.GEShape(var.shape),
                                                core.GEFormat.FORMAT_ND,
                                                core.GEDataType.DT_FLOAT))
            ret_map[var.name] = ge_input
        return ge_in_operator, ret_map

    def _endpoint_to_world_rank_id(self, endpoint):
        """Return the index of ``endpoint`` in ``fleet.worker_endpoints()``.

        Raises:
            AssertionError: if ``endpoint`` is not a worker endpoint.
        """
        world_endpoints = fleet.worker_endpoints()
        # Report the list that was actually searched.
        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (
            endpoint, world_endpoints)
        return world_endpoints.index(endpoint)

    def parse_op(self, op):
        """Convert one Paddle op, or record communication metadata from it.

        * ``c_gen_nccl_id``: stores the endpoint list for the op's output
          nccl_id variable in ``self.hcom_endpoints``.
        * ``c_comm_init``: maps the endpoints registered for its nccl_id to
          world rank ids and queues an ``HcomGroupConfig`` named
          ``hcom_group_<ring_id>`` in ``self.groups_to_create``.
        * op types in ``ascend_parser.registerd_op``: converted with the
          parser factory set up by ``_parse_program``.

        Any other op type fails an assertion.
        """
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # c_gen_nccl_id splits endpoints into the local endpoint and
            # other_endpoints; put the local one back at index ``rank`` to
            # rebuild the full list. Copy first so the op attr is untouched.
            endpoints = list(other_endpoints)
            endpoints.insert(rank, endpoint)
            self.hcom_endpoints[nccl_id] = endpoints

            print("nccl_id (%s) registered endpoints %s" %
                  (nccl_id, self.hcom_endpoints[nccl_id]))
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(self.hcom_endpoints[
                nccl_id]), "nranks doesn't match endpoint count"
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [
                self._endpoint_to_world_rank_id(endpoint)
                for endpoint in self.hcom_endpoints[nccl_id]
            ]
            self.groups_to_create.append(
                HcomGroupConfig(
                    name=group_name, nranks=nranks, rank_ids=global_rank_ids))
            print("append to create group: %s, with rank_ids: %s" %
                  (group_name, global_rank_ids))
        elif op.type in ascend_parser.registerd_op:
            op_parser = self.parser_factory.create_parse(
                ascend_parser.registerd_op[op.type])
            op_parser.apply(op)
        else:
            assert False, "Op[%s] has not been registered, so we have to skip it" % (
                op.type)

    def _parse_program(self,
                       graph_name,
                       program,
                       input_varlist=None,
                       fetch_list=None):
        """Convert ``program`` into a GE graph and replace its ops.

        Args:
            graph_name: name of the GE graph to build.
            program: Paddle program; only its global block is converted.
            input_varlist: variables that become GE inputs (default: none).
            fetch_list: variables or variable names whose GE operators become
                graph outputs (default: none).

        Returns:
            The GE graph. If the global block has no ops, returns ``[]``,
            leaves the program untouched and does not advance
            ``self.graph_idx``.

        Side effects:
            Resets ``self.var2geop`` and ``self.parser_factory``; removes
            every op from the global block and appends a single
            ``ascend_trigger`` op carrying the current ``graph_idx``.
        """
        # None defaults avoid sharing one mutable list across calls.
        input_varlist = [] if input_varlist is None else input_varlist
        fetch_list = [] if fetch_list is None else fetch_list

        ge_out_operator = []
        self.var2geop = {}

        block = program.global_block()
        if len(block.ops) == 0:
            print("There is no ops in program %s" % (graph_name))
            return []

        graph = core.GEGraph(graph_name)

        ge_in_operator, self.var2geop = self._construct_input_map(input_varlist)

        self.parser_factory = ascend_parser.AscendParserFactory(graph,
                                                                self.var2geop)
        # Iterate over a snapshot of the op list.
        for curop in list(block.ops):
            self.parse_op(curop)

        # Set fetch_var for GE
        for e in fetch_list:
            name = e if isinstance(e, str) else e.name
            ge_out_operator.append(self.var2geop[name])

        # (Debug) To print back-prop vars, append their GE ops to
        # ge_out_operator here, e.g.
        #     if graph_name == "main":
        #         ge_out_operator.append(self.var2geop["reduce_sum_0.tmp_0@GRAD"])

        # Add ops that may be input of a graph, such as const.
        for varname, geop in self.var2geop.items():
            if varname.startswith("geinput"):
                ge_in_operator.append(geop)

        graph.set_inputs(ge_in_operator).set_outputs(ge_out_operator)

        # Remove ops of the origin program (back to front so indices stay
        # valid while removing).
        for i in reversed(range(len(block.ops))):
            block._remove_op(i)

        feed_list = [var for var in input_varlist if var.is_data]

        block.append_op(
            type="ascend_trigger",
            inputs={"FeedList": feed_list},
            outputs={"FetchList": fetch_list},
            attrs={'graph_idx': self.graph_idx})
        self.graph_idx += 1
        return graph

    def parse_program(self, startup_program, main_program, input_varlist,
                      fetch_list):
        """Convert the startup and main programs into GE graphs.

        Both programs are rewritten in place by ``_parse_program``. With
        auto_dp enabled and more than one rank, a single world-wide group
        ``hcom_group_0`` is queued; parsing must not have registered any
        groups in that mode.

        Returns:
            (startup_graph, main_graph)
        """
        startup_graph = self._parse_program("startup", startup_program)
        main_graph = self._parse_program("main", main_program, input_varlist,
                                         fetch_list)
        if self._auto_dp and self._world_rank_size > 1:
            assert len(self.groups_to_create
                       ) == 0, "can't parse program under auto_dp mode"

            world_size = fleet.world_size()
            self.groups_to_create.append(
                HcomGroupConfig(
                    name="hcom_group_0",
                    nranks=world_size,
                    rank_ids=list(range(world_size))))

        return startup_graph, main_graph


# AscendOptimizer is a wrapper for basic optimizer now
# We will make it part of fleet meta_optimizer in the future
class AscendOptimizer(Optimizer):
    """Wrap an optimizer so that the resulting programs run through GE.

    ``minimize`` first applies the wrapped optimizer, then initializes the
    Graph Engine, converts the startup and main programs into GE graphs,
    creates the collected HCCL groups and registers both graphs with an
    ``AscendInstance``.
    """

    def __init__(self, optimizer, fetch_list=None):
        """
        Args:
            optimizer: inner optimizer whose ``minimize`` runs first; if it
                is falsy, no inner minimization is performed.
            fetch_list: variables or variable names fetched from the main
                graph (default: empty list).
        """
        # NOTE(review): Optimizer.__init__ is deliberately not called; this
        # class only wraps ``self.inner_opt``.
        self.inner_opt = optimizer
        # Fresh list per instance: a shared mutable default would leak fetch
        # targets between optimizers.
        self.fetch_list = [] if fetch_list is None else fetch_list

    def __del__(self):
        # NOTE(review): runs even if ``minimize`` (and thus ge_initialize)
        # never ran — confirm ge_finalize tolerates that.
        core.ge_finalize()

    def _can_apply(self):
        # ``user_defined_strategy`` is not set in this class; presumably it
        # is assigned by the fleet meta-optimizer machinery — confirm.
        if not self.user_defined_strategy.ascend:
            return False
        # TODO(hutuxian): other check here
        return True

    def _disable_strategy(self, dist_strategy):
        dist_strategy.ascend = False
        dist_strategy.ascend_configs = {}

    def _get_input_varlist(self, program):
        """Return the data and persistable variables of ``program``."""
        return [
            var for var in program.list_vars()
            if var.is_data or var.persistable
        ]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 auto_dp=False,
                 rank_table_file=None):
        """Minimize ``loss`` and compile the programs for Ascend.

        Args:
            loss: loss variable; its block's program is the main program.
            startup_program: startup program; defaults to the global default
                startup program.
            parameter_list: forwarded to the inner optimizer.
            no_grad_set: forwarded to the inner optimizer.
            auto_dp: run the Ascend transpiler and use one world-wide HCCL
                group when there is more than one worker.
            rank_table_file: HCCL rank table; only used with more than one
                worker.

        Returns:
            Whatever the inner optimizer's ``minimize`` returned, or None if
            there is no inner optimizer.
        """
        if startup_program is None:
            # The transpiler and the GE parser below need a concrete program
            # object; passing None through would crash them.
            startup_program = framework.default_startup_program()

        minimized = None
        if self.inner_opt:
            # Forward positionally: fluid optimizers call the third argument
            # ``parameter_list`` while paddle 2.x optimizers call it
            # ``parameters``.
            minimized = self.inner_opt.minimize(loss, startup_program,
                                                parameter_list, no_grad_set)

        self.ascend_instance = core.AscendInstance()

        if auto_dp and fleet.world_size() > 1:
            from paddle.fluid.transpiler import ascend_transpiler
            t = ascend_transpiler.AscendTranspiler(startup_program,
                                                   loss.block.program)
            t.transpile()

        # Config about Graph Engine can be found in https://support.huaweicloud.com/
        config = {
            "ge.exec.deviceId": str(fleet.local_device_ids()),
            "ge.graphRunMode": "1",
            "ge.exec.precision_mode": "must_keep_origin_dtype",
        }
        # if multi trainers
        if rank_table_file and fleet.world_size() > 1:
            config["ge.exec.rankTableFile"] = rank_table_file
            config["ge.exec.rankId"] = str(fleet.worker_index())
            config["ge.exec.isUseHcom"] = "1"
            config["ge.exec.deployMode"] = "0"
        print("ge_initialize config:", config)
        core.ge_initialize(config)

        # Init Session
        self.ascend_instance.init_global_resources()

        main_block = loss.block
        self.parser = AscendIRParser(
            auto_dp=auto_dp, world_rank_size=fleet.world_size())

        input_varlist = self._get_input_varlist(main_block.program)

        # Rewrites both programs in place and returns their GE graphs.
        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list)

        for cfg in self.parser.groups_to_create:
            print("create group (%s), nranks: %d, rank_ids: %s" %
                  (cfg.name, cfg.nranks, cfg.rank_ids))
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)

        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)

        return minimized