Unverified commit 94c17a0f authored by caozhou, committed by GitHub

[Auto Parallel] Add mul dist op cost (#44973)

* add mul dist op cost

* add mul unittest
Parent 2c77b575
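
A minimal sketch of how the new cost interface is exercised, condensed from the unit tests added below. The import paths are assumptions, and parallelizer/make_program are helpers defined in the test file, not in this patch:

# Sketch only: look up the distributed impl chosen for each op and ask it for its
# cost on a default 2-device cluster, mirroring TestDistOpCost below.
from paddle.distributed.auto_parallel.cluster import Cluster  # import path assumed
from paddle.distributed.auto_parallel.operators.common import (  # import path assumed
    get_distributed_operator_impl_container, is_elementwise_op)

main_program, dist_context = parallelizer(make_program, 0)  # helpers from the test file
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
for op in main_program.global_block().ops:
    dist_op = dist_context.get_dist_op_for_program(op)
    op_dist_attr = dist_op.dist_attr
    if is_elementwise_op(op.type):
        container = get_distributed_operator_impl_container("elementwise")
    else:
        container = get_distributed_operator_impl_container(op_dist_attr.impl_type)
    dist_impl = container.impls[op_dist_attr.impl_idx]
    # forward ops are routed to calc_fwd_cost, backward ops to calc_bwd_cost
    dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, dist_context, cluster)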
...@@ -1246,6 +1246,108 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
has_x_grad = len(backward_op.output("X@GRAD")) > 0
if has_x_grad:
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# calc comm op cost
if has_x_grad:
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = backward_op.output("X@GRAD")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
res.append(comm_op_cost_list)
# need gradient allreduce
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
# TODO: trans shape if trans_x or trans_y is True
comp_desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
comp_cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes,
comp_desc_mapping,
cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res_cost = [comm_op_cost_list, comp_cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
...@@ -1468,6 +1570,100 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
# calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes, desc_mapping,
cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
...@@ -1677,6 +1873,61 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl2, self).__init__(name)
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
process_mesh = dist_attr.process_mesh
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulV2OpCost, ctx,
processes, desc_mapping,
cluster)
res_cost = [cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
...@@ -1765,6 +2016,102 @@ class DistributedMulImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
has_x_grad = len(backward_op.output("X@GRAD")) > 0
if has_x_grad:
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# calc comm op cost
if has_x_grad:
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = backward_op.output("X@GRAD")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
res.append(comm_op_cost_list)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res_cost = [comm_op_cost_list, cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
...@@ -1916,7 +2263,24 @@ class DistributedMulImpl0(DistributedOperatorImpl):
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': intermediate_var_0, 'Y': Weight_var}
inputs_ref_shape = {}
inputs_original_shape = {}
for var_name in inputs:
if var_name == "X":
var = X_var
else:
var = inputs[var_name]
inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var,
input_tensor_dist_attr,
input_var_dist_attr)
inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul',
inputs=inputs,
outputs={'Out': Out_var},
...@@ -1924,6 +2288,11 @@ class DistributedMulImpl0(DistributedOperatorImpl):
if Out_var.shape != ref_shape_out:
Out_var.desc.set_shape(ref_shape_out)
for var_name in inputs:
var = inputs[var_name]
original_shape = inputs_original_shape[var_name]
var.desc.set_shape(original_shape)
# set dist op's dist_attr with serial op's dist_attr
# c_identity
identity_op_dist_attr = OperatorDistributedAttribute()
...@@ -1988,6 +2357,100 @@ class DistributedMulImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
# calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
# print("dist_matmul.py dist_op: ", dist_op)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
...@@ -2122,13 +2585,32 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
out_var_dist_attr)
inputs_ref_shape = {}
inputs_original_shape = {}
for var_name in inputs:
var = inputs[var_name]
inputs_original_shape[var_name] = var.shape
input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
input_ref_shape = infer_shape(main_block, var,
input_tensor_dist_attr,
input_var_dist_attr)
inputs_ref_shape[var_name] = input_ref_shape
var.desc.set_shape(input_ref_shape)
mul_op = main_block.append_op(type='mul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
for var_name in inputs:
var = inputs[var_name]
original_shape = inputs_original_shape[var_name]
var.desc.set_shape(original_shape)
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
...@@ -2139,6 +2621,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
...@@ -2198,6 +2681,59 @@ class DistributedMulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMulImpl2, self).__init__(name)
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MulOpCost, ctx, processes,
desc_mapping, cluster)
res_cost = [cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
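
The new calc_fwd_cost/calc_bwd_cost methods above follow one composition pattern per sharding strategy. The digest below is an informal summary, not code from the patch; the DistributedMatmulV2Impl* classes assemble the same pieces with MatmulV2OpCost/MatmulV2GradOpCost:

# Informal digest of which cost pieces each mul impl assembles (derived from the diff above).
COST_COMPOSITION = {
    "DistributedMulImpl0": {  # column parallel: Y sharded on its last dim
        "forward": ["c_identity on X", "MulOpCost"],
        "backward": ["MulGradOpCost",
                     "c_allreduce_sum on X@GRAD (if X@GRAD exists)",
                     "dp allreduce on Y@GRAD (if the batch dim of X is sharded)"],
    },
    "DistributedMulImpl1": {  # row parallel: Y sharded on its first dim
        "forward": ["MulOpCost", "c_allreduce_sum on Out"],
        "backward": ["c_identity on Out@GRAD", "MulGradOpCost",
                     "dp allreduce on Y@GRAD (if the batch dim of X is sharded)"],
    },
    "DistributedMulImpl2": {  # replicated weight
        "forward": ["MulOpCost"],
        "backward": ["MulGradOpCost",
                     "dp allreduce on Y@GRAD (if the batch dim of X is sharded)"],
    },
}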
...@@ -215,6 +215,215 @@ class TestDistOpCost(unittest.TestCase):
dist_context, cluster)
self.assertTrue(dist_op_cost)
def test_dist_op_cost_part3(self):
def make_program():
main_program = paddle.static.Program()
start_program = paddle.static.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4], dtype='float32')
x.stop_gradient = True
label = paddle.static.data(name="label",
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp)
# row parallel embedding
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
# matmul_v2
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
tmp_out = paddle.matmul(out1, tmp_param)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
out8 = paddle.fluid.layers.transpose(out2,
[1, 0]) # [4, 8] [0, -1]
# reshape
out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1]
tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
out10 = paddle.reshape(tmp_reshape_out,
[8, 8]) # [4, 8] [0, -1]
# softmax
softmax = paddle.nn.Softmax()
out11 = softmax(out10)
error_cost = paddle.nn.functional.square_error_cost(
out11, label)
loss = paddle.mean(error_cost)
return main_program, start_program, loss
main_program, dist_context = parallelizer(make_program, 0)
ops = main_program.global_block().ops
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise")
else:
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, cluster)
self.assertTrue(dist_op_cost)
def test_dist_op_cost_part4(self):
def make_program():
main_program = paddle.static.Program()
start_program = paddle.static.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4], dtype='float32')
x.stop_gradient = True
label = paddle.static.data(name="label",
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp)
# row parallel embedding
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
# mul
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
out2 = paddle.fluid.layers.mul(tmp_out,
param2) # [8, 4] [-1, 0]
out8 = paddle.fluid.layers.transpose(out2,
[1, 0]) # [4, 8] [0, -1]
# reshape
out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1]
tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
out10 = paddle.reshape(tmp_reshape_out,
[8, 8]) # [4, 8] [0, -1]
# softmax
softmax = paddle.nn.Softmax()
out11 = softmax(out10)
error_cost = paddle.nn.functional.square_error_cost(
out11, label)
loss = paddle.mean(error_cost)
return main_program, start_program, loss
main_program, dist_context = parallelizer(make_program, 0)
ops = main_program.global_block().ops
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise")
else:
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, cluster)
self.assertTrue(dist_op_cost)
if __name__ == "__main__":
unittest.main()
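
For reference, a condensed, hypothetical variant of make_program above that keeps only the column-parallel mul case (weight sharded on its last dimension), which is the sharding DistributedMulImpl0 targets. The auto import path and the exact shapes are assumptions beyond this diff:

# Hypothetical minimal program: a single mul whose weight is column-sharded
# across a 2-process mesh.
import paddle
import paddle.distributed.auto_parallel as auto  # import path assumed

def make_col_parallel_mul_program():
    main_program = paddle.static.Program()
    start_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[8, 8], dtype='float32')
        x.stop_gradient = True
        w = paddle.fluid.layers.create_parameter([8, 8], paddle.float32)
        auto.shard_tensor(w,
                          dist_attr={
                              "process_mesh": auto.ProcessMesh([0, 1]),
                              "dims_mapping": [-1, 0]  # column parallel: last dim sharded
                          })
        out = paddle.fluid.layers.mul(x, w)
        loss = paddle.mean(out)
    return main_program, start_program, loss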
...@@ -76,6 +76,14 @@ class MLPLayer(nn.Layer):
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
param = paddle.fluid.layers.create_parameter([1024, 4096],
paddle.float32)
auto.shard_tensor(param,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, 1]
})
out = paddle.fluid.layers.mul(out, param)
return out
......
...@@ -93,6 +93,14 @@ class MLPLayer(nn.Layer):
})
w_out = self.word_embeddings(input)
out = self.linear0(w_out)
param = paddle.fluid.layers.create_parameter([4096, 4096],
paddle.float32)
auto.shard_tensor(param,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.mul(out, param)
gelu_out = F.gelu(out, approximate=True)
out = self.linear1(gelu_out)
out1 = self.linear2(gelu_out)
...@@ -228,7 +236,7 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......