From d2404da76868935cc0696cc3d57856022f8615df Mon Sep 17 00:00:00 2001 From: Void Main Date: Thu, 21 Jan 2021 14:14:11 +0800 Subject: [PATCH] Build praser for Hcom* operators (#30627) Build praser for Hcom* operators --- .../ascend/ascend_optimizer.py | 8 +- .../meta_optimizers/ascend/ascend_parser.py | 138 +++++++++++++++++- .../fluid/transpiler/ascend_transpiler.py | 12 +- 3 files changed, 155 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index d99ee9c1e9b..d6ad1b2f2d0 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import paddle.fluid.framework as framework from paddle.fluid.optimizer import Optimizer import paddle.fluid.core as core @@ -151,7 +152,12 @@ class AscendOptimizer(Optimizer): config = { "ge.exec.deviceId": str(fleet.rank_in_node()), "ge.graphRunMode": "1", - "ge.exec.precision_mode": "must_keep_origin_dtype" + "ge.exec.precision_mode": "must_keep_origin_dtype", + # if multi mode + "ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"), + "ge.exec.rankId": str(fleet.worker_index()), + "ge.exec.isUseHcom": "1", + "ge.exec.deployMode": "0", } print("ge_initialize config:", config) core.ge_initialize(config) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index 7921b3d2216..36fa1575fcd 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -34,7 +34,14 @@ registerd_op = { "relu_grad": "ReluGradParser", "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser", "truncated_gaussian_random": "TruncatedNormalParser", - "sgd": "SGDParser" + "sgd": "SGDParser", + "c_allgather": "AllGatherParser", + "c_allreduce_sum": "AllReduceSumParser", + "c_allreduce_max": "AllReduceMaxParser", + "c_broadcast": "BroadcastParser", + "c_reduce_scatter": "ReduceScatterParser", + "c_send": "SendParser", + "c_receive": "ReceiveParser" } global_cnt = -1 global_input_cnt = -1 @@ -522,6 +529,135 @@ class TruncatedNormalParser(AscendParserBase): ) return [truncated_normal], [[0]] #[assign] + +class AllGatherParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(AllGatherParser, self).__init__(graph, var2geop) + self.parser_name = "c_allgather" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + rank_size = self.op.attr("rank_size") + group = self.op.attr("group") + + allgather = core.GEOperatorFactory.create_operator( + "allgather" + self._accumulated_op_id(), "HcomAllGather").set_input( + "x", x).set_attr_int32( + "rank_size", rank_size).set_attr_string("group", group) + return [allgather], [[0]] + +class AllReduceParser(AscendParserBase): + def __init__(self, graph, var2geop, reduction): + super(AllReduceParser, self).__init__(graph, var2geop) + self.parser_name = "c_allreduce_" + reduction + self.reduction = reduction + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + reduction = self.reduction + group = "hccl_world_group" #self.op.attr("group") + fusion = None #self.op.attr("fusion") + fusion_id = None #self.op.attr("fusion_id") + + allreduce = core.GEOperatorFactory.create_operator( + "allreduce" + self._accumulated_op_id(), "HcomAllReduce").set_input( + "x", x).set_attr_string( + "reduction", reduction).set_attr_string("group", group) + if fusion is not None: + allreduce.set_attr_int32("fusion", fusion) + + if fusion_id is not None: + allreduce.set_attr_int32("fusion_id", fusion_id) + return [allreduce], [[0]] + + +class AllReduceSumParser(AllReduceParser): + def __init__(self, graph, var2geop): + super(AllReduceSumParser, self).__init__(graph, var2geop, 'sum') + + +class AllReduceMaxParser(AllReduceParser): + def __init__(self, graph, var2geop): + super(AllReduceMaxParser, self).__init__(graph, var2geop, 'max') + + +class BroadcastParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(BroadcastParser, self).__init__(graph, var2geop) + self.parser_name = "c_broadcast" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + root_rank = self.op.attr("root_rank") + group = self.op.attr("group") + + broadcast = core.GEOperatorFactory.create_operator( + "broadcast" + self._accumulated_op_id(), "HcomBroadcast").set_input( + "x", x).set_attr_int32( + "root_rank", root_rank).set_attr_string("group", group) + return [broadcast], [[0]] + + +class ReduceScatterParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReduceScatterParser, self).__init__(graph, var2geop) + self.parser_name = "c_reduce_scatter" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + reduction = self.op.attr("reduction") + group = self.op.attr("group") + rank_size = self.op.attr("rank_size") + + reduce_scatter = core.GEOperatorFactory.create_operator( + "reducescatter" + self._accumulated_op_id(), "HcomReduceScatter").set_input( + "x", x).set_attr_string( + "reduction", reduction).set_attr_string( + "group", group).set_attr_int32("rank_size", rank_size) + return [reduce_scatter], [[0]] + + +class SendParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SendParser, self).__init__(graph, var2geop) + self.parser_name = "c_send" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + sr_tag = self.op.attr("sr_tag") + dest_rank = self.op.attr("dest_rank") + group = self.op.attr("group") + + send = core.GEOperatorFactory.create_operator( + "send" + self._accumulated_op_id(), "HcomSend").set_input( + "x", x).set_attr_int32( + "sr_tag", sr_tag).set_attr_int32( + "dest_rank", dest_rank).set_attr_string("group", group) + return [send], [[0]] + + +class ReceiveParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReceiveParser, self).__init__(graph, var2geop) + self.parser_name = "c_receive" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + sr_tag = self.op.attr("sr_tag") + src_rank = self.op.attr("src_rank") + group = self.op.attr("group") + shape = self.op.attr("shape") + dtype = self.op.attr("dtype") + + receive = core.GEOperatorFactory.create_operator( + "receive" + self._accumulated_op_id(), "HcomReceive").set_input( + "x", x).set_attr_int32( + "sr_tag", sr_tag).set_attr_int32( + "src_rank", src_rank).set_attr_string( + "group", group).set_attr_vec_int32( + "shape", shape).set_attr_int32("dtype", dtype) + return [receive], [[0]] + class ScaleParser(AscendParserBase): def __init__(self, graph, var2geop): super(ScaleParser, self).__init__(graph, var2geop) diff --git a/python/paddle/fluid/transpiler/ascend_transpiler.py b/python/paddle/fluid/transpiler/ascend_transpiler.py index ff814161050..61064e9d9a8 100644 --- a/python/paddle/fluid/transpiler/ascend_transpiler.py +++ b/python/paddle/fluid/transpiler/ascend_transpiler.py @@ -15,6 +15,7 @@ from . import collective from .. import core OpRole = core.op_proto_and_checker_maker.OpRole +from paddle.distributed import fleet class AscendTranspiler(collective.Collective): def __init__(self, startup_program, main_program): @@ -49,13 +50,22 @@ class AscendTranspiler(collective.Collective): ring_id = (ring_id + 1) % self.nrings block._insert_op( offset + 1, - type='allreduce', + type='c_allreduce_sum', inputs={'X': grad}, outputs={'Out': grad}, attrs={ 'ring_id': ring_id, self.op_role_key: OpRole.Backward }) + block._insert_op( + offset + 2, + type='scale', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'scale': 1.0 / fleet.worker_num(), + self.op_role_key: OpRole.Backward + }) if grad is None: return -- GitLab