ascend_optimizer.py 9.4 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
H
hutuxian 已提交
16 17 18 19
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
G
gongweibao 已提交
20
from . import ascend_parser
21 22 23
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple
H
hutuxian 已提交
24

25
# Immutable record describing one HCCL communication group that still has to
# be created: the group's name, how many ranks it contains, and the rank ids
# of its members.
HcomGroupConfig = namedtuple(
    typename='HcomGroupConfig',
    field_names=['name', 'nranks', 'rank_ids'],
)
H
hutuxian 已提交
26

G
gongweibao 已提交
27

H
hutuxian 已提交
28 29 30
class AscendIRParser(object):
    """Convert Paddle programs into Ascend Graph Engine (GE) graphs.

    Attributes:
        graph_idx: Index written into the next ``ascend_trigger`` op; it is
            incremented every time a non-empty program has been parsed.
        hcom_endpoints: Maps an nccl id variable name to the ordered list of
            endpoints registered by the matching ``c_gen_nccl_id`` op.
        groups_to_create: ``HcomGroupConfig`` entries collected from
            ``c_comm_init`` ops, to be created by the caller.
    """

    def __init__(self):
        self.graph_idx = 0
        self.hcom_endpoints = {}
        self.groups_to_create = []

    def _construct_input_map(self, input_varlist):
        """Create one GE input operator per variable in ``input_varlist``.

        Variables whose ``is_data`` flag is set become "Data" operators, all
        other variables (parameters, learning rate, ...) become "Variable"
        operators whose output "y" is described with the variable's shape.

        Args:
            input_varlist: Iterable of variables exposing ``name``,
                ``is_data`` and ``shape``.

        Returns:
            A ``(ge_in_operator, ret_map)`` tuple: the list of "Data"
            operators in input order, and a dict mapping every variable name
            to the operator created for it.
        """
        ret_map = {}
        ge_in_operator = []
        for index, var in enumerate(input_varlist):
            if var.is_data:  # input data
                # NOTE(review): "index" is the position inside the full
                # input_varlist (non-data variables included), while the feed
                # list built in _parse_program only keeps data variables --
                # confirm this is the numbering GE expects.
                ge_input = core.GEOperatorFactory.create_operator(
                    var.name, "Data").set_attr_int32("index", index)
                ge_in_operator.append(ge_input)
            else:  # param, learning ...
                ge_input = core.GEOperatorFactory.create_operator(var.name,
                                                                  "Variable")
                ge_input.update_output_desc("y",
                                            core.GETensorDesc(
                                                core.GEShape(var.shape),
                                                core.GEFormat.FORMAT_ND,
                                                core.GEDataType.DT_FLOAT))
            ret_map[var.name] = ge_input
        return ge_in_operator, ret_map

    def _endpoint_to_world_rank_id(self, endpoint):
        """Return the position of ``endpoint`` in ``fleet.worker_endpoints()``.

        Raises:
            AssertionError: If ``endpoint`` is not one of the worker
                endpoints.
        """
        world_endpoints = fleet.worker_endpoints()
        # Report the list that was actually searched so the message is
        # actionable when the lookup fails.
        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (
            endpoint, world_endpoints)
        return world_endpoints.index(endpoint)

    def parse_op(self, op):
        """Handle a single op of the program being converted.

        * ``c_gen_nccl_id``: remember the full endpoint list of the nccl id.
        * ``c_comm_init``: queue an ``HcomGroupConfig`` for the ring.
        * ops listed in ``ascend_parser.registerd_op``: delegate to the
          parser obtained from ``self.parser_factory`` (set up by
          ``_parse_program``).
        * anything else: print a message and skip the op.
        """
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # c_gen_nccl_id operator splits endpoints into local endpoint and
            # other_endpoints; combine them again (on a copy, so the op's
            # attribute is left untouched) to produce world_rank_ids.
            endpoints = list(other_endpoints)
            endpoints.insert(rank, endpoint)
            self.hcom_endpoints[nccl_id] = endpoints

            print("nccl_id (%s) registered endpoints %s" %
                  (nccl_id, self.hcom_endpoints[nccl_id]))
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(self.hcom_endpoints[
                nccl_id]), "nranks doesn't match endpoint count"
            rank = op.attr("rank")
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [
                self._endpoint_to_world_rank_id(endpoint)
                for endpoint in self.hcom_endpoints[nccl_id]
            ]
            self.groups_to_create.append(
                HcomGroupConfig(
                    name=group_name, nranks=nranks, rank_ids=global_rank_ids))
            print("append to create group: %s, with rank_ids: %s" %
                  (group_name, global_rank_ids))
        elif op.type in ascend_parser.registerd_op:
            print("Op[%s] has been registered, begin to parse it" % (op.type))
            op_parser = self.parser_factory.create_parse(
                ascend_parser.registerd_op[op.type])
            op_parser.apply(op)
        else:
            print("Op[%s] has not been registered, so we have to skip it" %
                  (op.type))

    def _parse_program(self,
                       graph_name,
                       program,
                       input_varlist=None,
                       fetch_list=None):
        """Build a GE graph from ``program`` and replace its ops in place.

        After parsing, every op of the program's global block is removed and
        a single ``ascend_trigger`` op carrying the current ``graph_idx`` is
        appended; ``graph_idx`` is then incremented.

        Args:
            graph_name: Name given to the GE graph.
            program: Program whose global block is converted.
            input_varlist: Variables to expose as graph inputs. Defaults to
                no inputs.
            fetch_list: Variables (or variable names) to use as graph
                outputs. Defaults to no outputs.

        Returns:
            The constructed graph, or an empty list when the global block
            contains no ops (the program is left untouched in that case).
        """
        # None instead of mutable [] defaults: nothing is shared between calls.
        if input_varlist is None:
            input_varlist = []
        if fetch_list is None:
            fetch_list = []

        self.var2geop = {}

        block = program.global_block()
        if len(block.ops) == 0:
            print("There is no ops in program %s" % (graph_name))
            return []

        graph = core.GEGraph(graph_name)

        ge_in_operator, self.var2geop = self._construct_input_map(input_varlist)

        self.parser_factory = ascend_parser.AscendParserFactory(graph,
                                                                self.var2geop)
        # Iterate over a snapshot of the op list.
        for curop in list(block.ops):
            self.parse_op(curop)

        # Set fetch_var for GE
        ge_out_operator = []
        for e in fetch_list:
            name = e if isinstance(e, str) else e.name
            ge_out_operator.append(self.var2geop[name])

        # (Debug) If you want to print back prop vars, append/assign the
        # varname in ge_out_operator here, such as:
        # if graph_name == "main":
        #     ge_out_operator.append(self.var2geop["reduce_sum_0.tmp_0@GRAD"])

        # Add ops that may be input of a graph, such as const.
        for varname, geop in self.var2geop.items():
            if varname.startswith("geinput"):
                ge_in_operator.append(geop)

        graph.set_inputs(ge_in_operator).set_outputs(ge_out_operator)

        # Remove ops of origin program, back to front so indices stay valid.
        for i in reversed(range(len(block.ops))):
            block._remove_op(i)

        # Only data variables are fed through the trigger op.
        input_varlist = [var for var in input_varlist if var.is_data]

        block.append_op(
            type="ascend_trigger",
            inputs={"FeedList": input_varlist},
            outputs={"FetchList": fetch_list},
            attrs={'graph_idx': self.graph_idx})
        self.graph_idx += 1
        return graph

    def parse_program(self, startup_program, main_program, input_varlist,
                      fetch_list):
        """Convert the startup program and then the main program.

        The startup program is parsed without inputs or fetch targets; the
        main program uses ``input_varlist`` and ``fetch_list``.

        Returns:
            A ``(startup_graph, main_graph)`` tuple of ``_parse_program``
            results.
        """
        startup_graph = self._parse_program("startup", startup_program)
        main_graph = self._parse_program("main", main_program, input_varlist,
                                         fetch_list)
        return startup_graph, main_graph


# AscendOptimizer is a wrapper for basic optimizer now
# We will make it part of fleet meta_optimizer in the future
class AscendOptimizer(Optimizer):
    """Run an inner optimizer, then hand the programs over to Ascend GE.

    ``minimize`` lets the wrapped optimizer build the backward/update ops,
    initializes the Graph Engine, converts the startup and main programs
    with ``AscendIRParser`` and registers the resulting graphs on an
    ``AscendInstance``.

    Note: ``Optimizer.__init__`` is not invoked by this class.
    """

    def __init__(self, optimizer, fetch_list=None):
        """Store the wrapped optimizer and the variables to fetch.

        Args:
            optimizer: Inner optimizer whose ``minimize`` is called first;
                may be falsy to skip that step.
            fetch_list: Variables (or names) used as outputs of the main
                graph. Defaults to a new empty list.
        """
        self.inner_opt = optimizer
        # A fresh list per instance; a `[]` default would be shared by every
        # optimizer constructed without an explicit fetch_list.
        self.fetch_list = [] if fetch_list is None else fetch_list

    def __del__(self):
        core.ge_finalize()

    def _can_apply(self):
        """Return True when the user strategy enables ``ascend``."""
        if not self.user_defined_strategy.ascend:
            return False
        # TODO(hutuxian): other check here
        return True

    def _disable_strategy(self, dist_strategy):
        """Switch the ascend strategy off and clear its configs."""
        dist_strategy.ascend = False
        dist_strategy.ascend_configs = {}

    def _get_input_varlist(self, program):
        """Return the variables of ``program`` that are data or persistable."""
        return [
            var for var in program.list_vars()
            if var.is_data or var.persistable
        ]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 auto_dp=False):
        """Minimize ``loss`` and build the Ascend graphs for execution.

        Args:
            loss: Loss variable; its block's program is the main program.
            startup_program: Program holding initialization ops. When None,
                ``framework.default_startup_program()`` is used for the
                Ascend conversion.
            parameter_list: Accepted for interface compatibility; not
                forwarded to the inner optimizer.
            no_grad_set: Accepted for interface compatibility; not forwarded
                to the inner optimizer.
            auto_dp: When True and more than one worker is running, the
                programs are first rewritten by ``AscendTranspiler``.

        Returns:
            Whatever the inner optimizer's ``minimize`` returned, or None if
            there is no inner optimizer.
        """
        minimized = None
        if self.inner_opt:
            minimized = self.inner_opt.minimize(
                loss, startup_program=startup_program)

        # The parser calls startup_program.global_block(), so a None startup
        # program has to be resolved to the default one before going on.
        if startup_program is None:
            startup_program = framework.default_startup_program()

        self.ascend_instance = core.AscendInstance()

        if auto_dp and fleet.worker_num() > 1:
            from paddle.fluid.transpiler import ascend_transpiler
            t = ascend_transpiler.AscendTranspiler(startup_program,
                                                   loss.block.program)
            t.transpile()
            print(loss.block.program)

        # Config about Graph Engine can be found in https://support.huaweicloud.com/
        config = {
            "ge.exec.deviceId": str(fleet.local_device_ids()),
            "ge.graphRunMode": "1",
            "ge.exec.precision_mode": "must_keep_origin_dtype",
            # if multi mode
            "ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
            "ge.exec.rankId": str(fleet.worker_index()),
            "ge.exec.isUseHcom": "1",
            "ge.exec.deployMode": "0",
        }
        print("ge_initialize config:", config)
        core.ge_initialize(config)

        # Init Session
        self.ascend_instance.init_global_resources()

        main_block = loss.block
        self.parser = AscendIRParser()

        input_varlist = self._get_input_varlist(main_block.program)

        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list)

        # Create the communication groups collected while parsing.
        for cfg in self.parser.groups_to_create:
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
            print("create group (%s), nranks: %d, rank_ids: %s" %
                  (cfg.name, cfg.nranks, cfg.rank_ids))

        # Subgraph 0 is the startup graph, subgraph 1 the main graph.
        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)

        return minimized