Unverified commit d2404da7, authored by Void Main, committed by GitHub

Build parser for Hcom* operators (#30627)

Build parser for Hcom* operators
Parent f9c97dd7
......@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import paddle.fluid.framework as framework
 from paddle.fluid.optimizer import Optimizer
 import paddle.fluid.core as core
......@@ -151,7 +152,12 @@ class AscendOptimizer(Optimizer):
         config = {
             "ge.exec.deviceId": str(fleet.rank_in_node()),
             "ge.graphRunMode": "1",
-            "ge.exec.precision_mode": "must_keep_origin_dtype"
+            "ge.exec.precision_mode": "must_keep_origin_dtype",
+            # options required when running in multi-device (distributed) mode
+            "ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
+            "ge.exec.rankId": str(fleet.worker_index()),
+            "ge.exec.isUseHcom": "1",
+            "ge.exec.deployMode": "0",
         }
         print("ge_initialize config:", config)
         core.ge_initialize(config)
......
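The new keys above make GE initialization depend on the `RANK_TABLE_FILE` environment variable; `os.getenv("RANK_TABLE_FILE")` returns None if it is unset. A minimal, hypothetical launcher-side sketch (not part of this patch; the file path is a placeholder) showing how that variable would be provided before `AscendOptimizer` initializes GE:

```python
import os

# Assumed path to the HCCL rank table describing the devices/ranks on each host.
os.environ.setdefault("RANK_TABLE_FILE", "/path/to/rank_table.json")
```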
......@@ -34,7 +34,14 @@ registerd_op = {
     "relu_grad": "ReluGradParser",
     "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser",
     "truncated_gaussian_random": "TruncatedNormalParser",
-    "sgd": "SGDParser"
+    "sgd": "SGDParser",
+    "c_allgather": "AllGatherParser",
+    "c_allreduce_sum": "AllReduceSumParser",
+    "c_allreduce_max": "AllReduceMaxParser",
+    "c_broadcast": "BroadcastParser",
+    "c_reduce_scatter": "ReduceScatterParser",
+    "c_send": "SendParser",
+    "c_receive": "ReceiveParser"
 }
 global_cnt = -1
 global_input_cnt = -1
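For orientation, `registerd_op` maps a Paddle operator type to the name of the parser class (defined below) that lowers it to a GE operator. A minimal, hypothetical sketch of how such a name-based registry could be resolved; the helper `create_parser` is illustrative only and not part of this commit:

```python
import sys

def create_parser(op_type, graph, var2geop):
    # Illustrative helper: look up the registered parser class name for the
    # Paddle op type and instantiate it from this module's namespace.
    class_name = registerd_op.get(op_type)
    if class_name is None:
        raise ValueError("no Ascend parser registered for op '%s'" % op_type)
    parser_cls = getattr(sys.modules[__name__], class_name)
    return parser_cls(graph, var2geop)
```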
......@@ -522,6 +529,135 @@ class TruncatedNormalParser(AscendParserBase):
         )
         return [truncated_normal], [[0]]  #[assign]
+
+
+class AllGatherParser(AscendParserBase):
+    def __init__(self, graph, var2geop):
+        super(AllGatherParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_allgather"
+
+    def _apply(self):
+        # Map Paddle's c_allgather op onto the GE/HCCL HcomAllGather operator.
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        rank_size = self.op.attr("rank_size")
+        group = self.op.attr("group")
+        allgather = core.GEOperatorFactory.create_operator(
+            "allgather" + self._accumulated_op_id(),
+            "HcomAllGather").set_input("x", x).set_attr_int32(
+                "rank_size", rank_size).set_attr_string("group", group)
+        return [allgather], [[0]]
+
+
+class AllReduceParser(AscendParserBase):
+    def __init__(self, graph, var2geop, reduction):
+        super(AllReduceParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_allreduce_" + reduction
+        self.reduction = reduction
+
+    def _apply(self):
+        # Map Paddle's c_allreduce_<reduction> op onto HcomAllReduce.
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        reduction = self.reduction
+        group = "hccl_world_group"  #self.op.attr("group")
+        fusion = None  #self.op.attr("fusion")
+        fusion_id = None  #self.op.attr("fusion_id")
+        allreduce = core.GEOperatorFactory.create_operator(
+            "allreduce" + self._accumulated_op_id(),
+            "HcomAllReduce").set_input("x", x).set_attr_string(
+                "reduction", reduction).set_attr_string("group", group)
+        if fusion is not None:
+            allreduce.set_attr_int32("fusion", fusion)
+        if fusion_id is not None:
+            allreduce.set_attr_int32("fusion_id", fusion_id)
+        return [allreduce], [[0]]
+
+
+class AllReduceSumParser(AllReduceParser):
+    def __init__(self, graph, var2geop):
+        super(AllReduceSumParser, self).__init__(graph, var2geop, 'sum')
+
+
+class AllReduceMaxParser(AllReduceParser):
+    def __init__(self, graph, var2geop):
+        super(AllReduceMaxParser, self).__init__(graph, var2geop, 'max')
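Other reductions could presumably follow the same subclassing pattern. A hypothetical sketch, assuming the HcomAllReduce backend accepts a 'min' reduction keyword (this class is not part of the commit):

```python
class AllReduceMinParser(AllReduceParser):
    def __init__(self, graph, var2geop):
        # Hypothetical: using it would also require a
        # "c_allreduce_min": "AllReduceMinParser" entry in registerd_op.
        super(AllReduceMinParser, self).__init__(graph, var2geop, 'min')
```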
+
+
+class BroadcastParser(AscendParserBase):
+    def __init__(self, graph, var2geop):
+        super(BroadcastParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_broadcast"
+
+    def _apply(self):
+        # Map Paddle's c_broadcast op onto HcomBroadcast.
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        root_rank = self.op.attr("root_rank")
+        group = self.op.attr("group")
+        broadcast = core.GEOperatorFactory.create_operator(
+            "broadcast" + self._accumulated_op_id(),
+            "HcomBroadcast").set_input("x", x).set_attr_int32(
+                "root_rank", root_rank).set_attr_string("group", group)
+        return [broadcast], [[0]]
+
+
+class ReduceScatterParser(AscendParserBase):
+    def __init__(self, graph, var2geop):
+        super(ReduceScatterParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_reduce_scatter"
+
+    def _apply(self):
+        # Map Paddle's c_reduce_scatter op onto HcomReduceScatter.
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        reduction = self.op.attr("reduction")
+        group = self.op.attr("group")
+        rank_size = self.op.attr("rank_size")
+        reduce_scatter = core.GEOperatorFactory.create_operator(
+            "reducescatter" + self._accumulated_op_id(),
+            "HcomReduceScatter").set_input("x", x).set_attr_string(
+                "reduction", reduction).set_attr_string(
+                    "group", group).set_attr_int32("rank_size", rank_size)
+        return [reduce_scatter], [[0]]
+
+
+class SendParser(AscendParserBase):
+    def __init__(self, graph, var2geop):
+        super(SendParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_send"
+
+    def _apply(self):
+        # Map Paddle's c_send op onto HcomSend (point-to-point send).
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        sr_tag = self.op.attr("sr_tag")
+        dest_rank = self.op.attr("dest_rank")
+        group = self.op.attr("group")
+        send = core.GEOperatorFactory.create_operator(
+            "send" + self._accumulated_op_id(),
+            "HcomSend").set_input("x", x).set_attr_int32(
+                "sr_tag", sr_tag).set_attr_int32(
+                    "dest_rank", dest_rank).set_attr_string("group", group)
+        return [send], [[0]]
+
+
+class ReceiveParser(AscendParserBase):
+    def __init__(self, graph, var2geop):
+        super(ReceiveParser, self).__init__(graph, var2geop)
+        self.parser_name = "c_receive"
+
+    def _apply(self):
+        # Map Paddle's c_receive op onto HcomReceive (point-to-point receive);
+        # shape and dtype are forwarded so the receiving side can allocate the
+        # output tensor before the data arrives.
+        x = self._get_ge_input(self.op.input_arg_names[0])
+        sr_tag = self.op.attr("sr_tag")
+        src_rank = self.op.attr("src_rank")
+        group = self.op.attr("group")
+        shape = self.op.attr("shape")
+        dtype = self.op.attr("dtype")
+        receive = core.GEOperatorFactory.create_operator(
+            "receive" + self._accumulated_op_id(),
+            "HcomReceive").set_input("x", x).set_attr_int32(
+                "sr_tag", sr_tag).set_attr_int32(
+                    "src_rank", src_rank).set_attr_string(
+                        "group", group).set_attr_vec_int32(
+                            "shape", shape).set_attr_int32("dtype", dtype)
+        return [receive], [[0]]
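As a rough illustration of how the two point-to-point parsers relate: a send and its matching receive are expected to agree on `sr_tag` and `group`, while the receive side additionally carries `shape` and `dtype` for the output tensor. The values below are placeholders, not taken from the commit:

```python
# Hypothetical attribute values for a matched c_send / c_receive pair.
send_attrs = {"sr_tag": 0, "dest_rank": 1, "group": "hccl_world_group"}
recv_attrs = {"sr_tag": 0, "src_rank": 0, "group": "hccl_world_group",
              "shape": [32, 1024],  # placeholder tensor shape
              "dtype": 0}           # placeholder Paddle integer dtype code
assert send_attrs["sr_tag"] == recv_attrs["sr_tag"]
assert send_attrs["group"] == recv_attrs["group"]
```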
 class ScaleParser(AscendParserBase):
     def __init__(self, graph, var2geop):
         super(ScaleParser, self).__init__(graph, var2geop)
......
......@@ -15,6 +15,7 @@
 from . import collective
 from .. import core
 OpRole = core.op_proto_and_checker_maker.OpRole
+from paddle.distributed import fleet
 
 class AscendTranspiler(collective.Collective):
     def __init__(self, startup_program, main_program):
......@@ -49,13 +50,22 @@ class AscendTranspiler(collective.Collective):
                     ring_id = (ring_id + 1) % self.nrings
                     block._insert_op(
                         offset + 1,
-                        type='allreduce',
+                        type='c_allreduce_sum',
                         inputs={'X': grad},
                         outputs={'Out': grad},
                         attrs={
                             'ring_id': ring_id,
                             self.op_role_key: OpRole.Backward
                         })
+                    block._insert_op(
+                        offset + 2,
+                        type='scale',
+                        inputs={'X': grad},
+                        outputs={'Out': grad},
+                        attrs={
+                            'scale': 1.0 / fleet.worker_num(),
+                            self.op_role_key: OpRole.Backward
+                        })
 
         if grad is None:
             return
......
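The transpiler change above switches the inserted op to `c_allreduce_sum` and follows it with a `scale` op, so each gradient becomes the sum over workers divided by the worker count, i.e. the mean. A tiny sanity check of that arithmetic, independent of Paddle (the gradient values are made up):

```python
# Hypothetical per-worker gradients for one parameter.
grads = [0.2, 0.4, 0.6]
worker_num = len(grads)

summed = sum(grads)                     # what c_allreduce_sum produces
averaged = summed * (1.0 / worker_num)  # what the scale op then applies
assert abs(averaged - 0.4) < 1e-12      # equals the mean per-worker gradient
```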