未验证 提交 d2404da7 编写于 作者: V Void Main 提交者: GitHub

Build parser for Hcom* operators (#30627)

Build parser for Hcom* operators
上级 f9c97dd7
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -151,7 +152,12 @@ class AscendOptimizer(Optimizer): ...@@ -151,7 +152,12 @@ class AscendOptimizer(Optimizer):
config = { config = {
"ge.exec.deviceId": str(fleet.rank_in_node()), "ge.exec.deviceId": str(fleet.rank_in_node()),
"ge.graphRunMode": "1", "ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype" "ge.exec.precision_mode": "must_keep_origin_dtype",
# if multi mode
"ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
"ge.exec.rankId": str(fleet.worker_index()),
"ge.exec.isUseHcom": "1",
"ge.exec.deployMode": "0",
} }
print("ge_initialize config:", config) print("ge_initialize config:", config)
core.ge_initialize(config) core.ge_initialize(config)
......
...@@ -34,7 +34,14 @@ registerd_op = { ...@@ -34,7 +34,14 @@ registerd_op = {
"relu_grad": "ReluGradParser", "relu_grad": "ReluGradParser",
"softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser", "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser",
"truncated_gaussian_random": "TruncatedNormalParser", "truncated_gaussian_random": "TruncatedNormalParser",
"sgd": "SGDParser" "sgd": "SGDParser",
"c_allgather": "AllGatherParser",
"c_allreduce_sum": "AllReduceSumParser",
"c_allreduce_max": "AllReduceMaxParser",
"c_broadcast": "BroadcastParser",
"c_reduce_scatter": "ReduceScatterParser",
"c_send": "SendParser",
"c_receive": "ReceiveParser"
} }
global_cnt = -1 global_cnt = -1
global_input_cnt = -1 global_input_cnt = -1
...@@ -522,6 +529,135 @@ class TruncatedNormalParser(AscendParserBase): ...@@ -522,6 +529,135 @@ class TruncatedNormalParser(AscendParserBase):
) )
return [truncated_normal], [[0]] #[assign] return [truncated_normal], [[0]] #[assign]
class AllGatherParser(AscendParserBase):
    """Parser that maps a ``c_allgather`` op onto a GE ``HcomAllGather`` operator."""

    def __init__(self, graph, var2geop):
        super(AllGatherParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allgather"

    def _apply(self):
        """Build the ``HcomAllGather`` operator for the op being parsed.

        The first input of the op is wired to ``x``; the op's ``rank_size``
        and ``group`` attributes are copied onto the GE operator.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (presumably an output mapping --
            TODO confirm against AscendParserBase).
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        num_ranks = self.op.attr("rank_size")
        comm_group = self.op.attr("group")

        op_name = "allgather" + self._accumulated_op_id()
        gather_op = core.GEOperatorFactory.create_operator(op_name,
                                                           "HcomAllGather")
        gather_op = gather_op.set_input("x", source)
        gather_op = gather_op.set_attr_int32("rank_size", num_ranks)
        gather_op = gather_op.set_attr_string("group", comm_group)
        return [gather_op], [[0]]
class AllReduceParser(AscendParserBase):
    """Shared parser for ``c_allreduce_<reduction>`` ops (GE ``HcomAllReduce``).

    Args:
        graph: forwarded unchanged to the base parser.
        var2geop: forwarded unchanged to the base parser.
        reduction: reduction name; used both as the suffix of
            ``parser_name`` and as the ``reduction`` attribute of the
            generated operator.
    """

    def __init__(self, graph, var2geop, reduction):
        super(AllReduceParser, self).__init__(graph, var2geop)
        self.reduction = reduction
        self.parser_name = "c_allreduce_" + reduction

    def _apply(self):
        """Build the ``HcomAllReduce`` operator for the op being parsed.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (TODO confirm).
        """
        source = self._get_ge_input(self.op.input_arg_names[0])

        # NOTE: the op's own "group", "fusion" and "fusion_id" attributes are
        # deliberately not read here. The group is fixed and both fusion
        # settings are left unset, so the two optional branches below never
        # run until these placeholders are wired to real values.
        comm_group = "hccl_world_group"
        fusion_mode = None
        fusion_tag = None

        op_name = "allreduce" + self._accumulated_op_id()
        reduce_op = core.GEOperatorFactory.create_operator(op_name,
                                                           "HcomAllReduce")
        reduce_op = reduce_op.set_input("x", source)
        reduce_op = reduce_op.set_attr_string("reduction", self.reduction)
        reduce_op = reduce_op.set_attr_string("group", comm_group)

        if fusion_mode is not None:
            reduce_op.set_attr_int32("fusion", fusion_mode)
        if fusion_tag is not None:
            reduce_op.set_attr_int32("fusion_id", fusion_tag)

        return [reduce_op], [[0]]
class AllReduceSumParser(AllReduceParser):
    """All-reduce parser specialised with the ``'sum'`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceSumParser, self).__init__(
            graph, var2geop, reduction='sum')
class AllReduceMaxParser(AllReduceParser):
    """All-reduce parser specialised with the ``'max'`` reduction."""

    def __init__(self, graph, var2geop):
        super(AllReduceMaxParser, self).__init__(
            graph, var2geop, reduction='max')
class BroadcastParser(AscendParserBase):
    """Parser that maps a ``c_broadcast`` op onto a GE ``HcomBroadcast`` operator."""

    def __init__(self, graph, var2geop):
        super(BroadcastParser, self).__init__(graph, var2geop)
        self.parser_name = "c_broadcast"

    def _apply(self):
        """Build the ``HcomBroadcast`` operator for the op being parsed.

        The first input of the op is wired to ``x``; the op's ``root_rank``
        and ``group`` attributes are copied onto the GE operator.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (TODO confirm).
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        root = self.op.attr("root_rank")
        comm_group = self.op.attr("group")

        op_name = "broadcast" + self._accumulated_op_id()
        bcast_op = core.GEOperatorFactory.create_operator(op_name,
                                                          "HcomBroadcast")
        bcast_op = bcast_op.set_input("x", source)
        bcast_op = bcast_op.set_attr_int32("root_rank", root)
        bcast_op = bcast_op.set_attr_string("group", comm_group)
        return [bcast_op], [[0]]
class ReduceScatterParser(AscendParserBase):
    """Parser that maps a ``c_reduce_scatter`` op onto GE ``HcomReduceScatter``."""

    def __init__(self, graph, var2geop):
        super(ReduceScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "c_reduce_scatter"

    def _apply(self):
        """Build the ``HcomReduceScatter`` operator for the op being parsed.

        The first input of the op is wired to ``x``; the op's ``reduction``,
        ``group`` and ``rank_size`` attributes are copied onto the GE
        operator.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (TODO confirm).
        """
        source = self._get_ge_input(self.op.input_arg_names[0])
        reduce_kind = self.op.attr("reduction")
        comm_group = self.op.attr("group")
        num_ranks = self.op.attr("rank_size")

        op_name = "reducescatter" + self._accumulated_op_id()
        scatter_op = core.GEOperatorFactory.create_operator(
            op_name, "HcomReduceScatter")
        scatter_op = scatter_op.set_input("x", source)
        scatter_op = scatter_op.set_attr_string("reduction", reduce_kind)
        scatter_op = scatter_op.set_attr_string("group", comm_group)
        scatter_op = scatter_op.set_attr_int32("rank_size", num_ranks)
        return [scatter_op], [[0]]
class SendParser(AscendParserBase):
    """Parser that maps a ``c_send`` op onto a GE ``HcomSend`` operator."""

    def __init__(self, graph, var2geop):
        super(SendParser, self).__init__(graph, var2geop)
        self.parser_name = "c_send"

    def _apply(self):
        """Build the ``HcomSend`` operator for the op being parsed.

        The first input of the op is wired to ``x``; the op's ``sr_tag``,
        ``dest_rank`` and ``group`` attributes are copied onto the GE
        operator.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (TODO confirm).
        """
        payload = self._get_ge_input(self.op.input_arg_names[0])
        tag = self.op.attr("sr_tag")
        target_rank = self.op.attr("dest_rank")
        comm_group = self.op.attr("group")

        op_name = "send" + self._accumulated_op_id()
        send_op = core.GEOperatorFactory.create_operator(op_name, "HcomSend")
        send_op = send_op.set_input("x", payload)
        send_op = send_op.set_attr_int32("sr_tag", tag)
        send_op = send_op.set_attr_int32("dest_rank", target_rank)
        send_op = send_op.set_attr_string("group", comm_group)
        return [send_op], [[0]]
class ReceiveParser(AscendParserBase):
    """Parser that maps a ``c_receive`` op onto a GE ``HcomReceive`` operator."""

    def __init__(self, graph, var2geop):
        super(ReceiveParser, self).__init__(graph, var2geop)
        self.parser_name = "c_receive"

    def _apply(self):
        """Build the ``HcomReceive`` operator for the op being parsed.

        The op's ``sr_tag``, ``src_rank``, ``group``, ``shape`` and ``dtype``
        attributes are copied onto the GE operator.

        Returns:
            A pair ``([operator], [[0]])``. The meaning of the index list is
            defined by the base parser (TODO confirm).
        """
        # NOTE(review): this indexes the op's first input, so the op must
        # declare at least one -- confirm that c_receive actually does.
        source = self._get_ge_input(self.op.input_arg_names[0])
        tag = self.op.attr("sr_tag")
        origin_rank = self.op.attr("src_rank")
        comm_group = self.op.attr("group")
        out_shape = self.op.attr("shape")
        out_dtype = self.op.attr("dtype")

        op_name = "receive" + self._accumulated_op_id()
        recv_op = core.GEOperatorFactory.create_operator(op_name,
                                                         "HcomReceive")
        recv_op = recv_op.set_input("x", source)
        recv_op = recv_op.set_attr_int32("sr_tag", tag)
        recv_op = recv_op.set_attr_int32("src_rank", origin_rank)
        recv_op = recv_op.set_attr_string("group", comm_group)
        recv_op = recv_op.set_attr_vec_int32("shape", out_shape)
        recv_op = recv_op.set_attr_int32("dtype", out_dtype)
        return [recv_op], [[0]]
class ScaleParser(AscendParserBase): class ScaleParser(AscendParserBase):
def __init__(self, graph, var2geop): def __init__(self, graph, var2geop):
super(ScaleParser, self).__init__(graph, var2geop) super(ScaleParser, self).__init__(graph, var2geop)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from . import collective from . import collective
from .. import core from .. import core
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
from paddle.distributed import fleet
class AscendTranspiler(collective.Collective): class AscendTranspiler(collective.Collective):
def __init__(self, startup_program, main_program): def __init__(self, startup_program, main_program):
...@@ -49,13 +50,22 @@ class AscendTranspiler(collective.Collective): ...@@ -49,13 +50,22 @@ class AscendTranspiler(collective.Collective):
ring_id = (ring_id + 1) % self.nrings ring_id = (ring_id + 1) % self.nrings
block._insert_op( block._insert_op(
offset + 1, offset + 1,
type='allreduce', type='c_allreduce_sum',
inputs={'X': grad}, inputs={'X': grad},
outputs={'Out': grad}, outputs={'Out': grad},
attrs={ attrs={
'ring_id': ring_id, 'ring_id': ring_id,
self.op_role_key: OpRole.Backward self.op_role_key: OpRole.Backward
}) })
block._insert_op(
offset + 2,
type='scale',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'scale': 1.0 / fleet.worker_num(),
self.op_role_key: OpRole.Backward
})
if grad is None: if grad is None:
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册