Unverified commit ec1e0d5a authored by caozhou, committed by GitHub

add dist op costs (#44701)

Parent fecbc958
......@@ -12,20 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import _g_op_cost_factory
from .base_cost import Cost
from .base_cost import CommContext
from .base_cost import build_comm_desc
from .base_cost import build_comp_desc_from_op
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_dp_costs
from .base_cost import build_comp_desc_str_for_predict
from .base_cost import build_comm_desc_from_dist_op
from .base_cost import build_comm_costs_from_descs
from .base_cost import build_comp_costs_from_descs
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
from .comp_op_cost import EmbeddingOpCost
from .comp_op_cost import EmbeddingGradOpCost
from .comp_op_cost import ConcatOpCost
from .comp_op_cost import MatmulOpCost
from .comp_op_cost import MatmulGradOpCost
from .comp_op_cost import MatmulV2OpCost
from .comp_op_cost import MatmulV2GradOpCost
from .comp_op_cost import MulOpCost
from .comp_op_cost import MulGradOpCost
from .comp_op_cost import Reshape2OpCost
from .comp_op_cost import Reshape2GradOpCost
from .comp_op_cost import SliceOpCost
from .comp_op_cost import SplitOpCost
from .comp_op_cost import SoftmaxOpCost
from .comp_op_cost import SoftmaxGradOpCost
from .comp_op_cost import Transpose2OpCost
from .comp_op_cost import Transpose2GradOpCost
from .comp_op_cost import FillConstantBatchSizeLikeOpCost
from .comm_op_cost import SendOpCost
from .comm_op_cost import RecvOpCost
from .comm_op_cost import IdentityOpCost
......
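For orientation, the newly exported builders compose the way the operator implementations below use them: a comp-op desc mapping is built from a dist op and turned into per-process computation costs, and a comm desc built from the same dist op yields the matching communication costs. The following is a minimal sketch only; it assumes `dist_op`, `ctx` (a DistributedContext) and `cluster` already exist, and the `parallel_axis=0` argument is a placeholder rather than a derived value.

from paddle.distributed.auto_parallel.cost import (
    AllreduceSumOpCost, EmbeddingOpCost, build_comm_costs_from_descs,
    build_comm_desc_from_dist_op, build_comp_costs_from_descs,
    build_comp_desc_from_dist_op)

# Computation part: one desc per process in the mesh, turned into comp costs.
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
comp_costs = build_comp_costs_from_descs(EmbeddingOpCost, ctx, processes,
                                         desc_mapping, cluster)

# Communication part: describe the c_allreduce_sum implied by the sharding and
# turn it into comm costs. The real implementations derive parallel_axis from
# the sharded dims mapping; 0 here is only a placeholder.
attrs = {"use_calc_stream": True, "use_model_parallel": True}
comm_descs = build_comm_desc_from_dist_op(
    "c_allreduce_sum", dist_op, ctx, dist_op.serial_op.output("Out"),
    attrs=attrs, parallel_axis=0)
comm_costs = build_comm_costs_from_descs(AllreduceSumOpCost, ctx, processes,
                                         comm_descs, cluster)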
......@@ -15,6 +15,25 @@
from .base_cost import Cost, register_op_cost, CompOpCost, _g_op_cost_factory
@register_op_cost
class AdamOpCost(CompOpCost):
OP_TYPE = "adam"
def __init__(self, op=None, op_desc=None, cluster=None):
super(AdamOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
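Since the committed formulas are placeholders (both methods return 0 for now), a hypothetical sketch of what a filled-in subclass could look like may help. The op type, element count and throughput below are invented, the import path assumes the package layout of this commit, and it is assumed (as the decorator name suggests) that register_op_cost simply records the class in the cost factory.

from paddle.distributed.auto_parallel.cost.base_cost import (CompOpCost,
                                                             register_op_cost)

@register_op_cost
class MyToyOpCost(CompOpCost):
    # "my_toy_op" is a made-up op type used only for this sketch.
    OP_TYPE = "my_toy_op"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(MyToyOpCost, self).__init__(op=op,
                                          op_desc=op_desc,
                                          cluster=cluster)

    def calc_flops(self):
        # Hypothetical formula: one multiply-add per output element of a
        # fixed-size output; a real override would read shapes from op/op_desc.
        num_elements = 8 * 1024
        return 2 * num_elements

    def calc_time(self):
        # Hypothetical model: flops divided by an assumed device throughput;
        # the time unit is whatever the cost estimator expects.
        assumed_flops_per_time_unit = 1e6
        return self.calc_flops() / assumed_flops_per_time_unit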
@register_op_cost
class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
......
......@@ -831,8 +831,10 @@ class DistributedContext:
if (dist_tensor
is not None) and (not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.desc.id(),
dist_tensor.desc.original_id(), dist_tensor.dist_attr)
dist_tensor.serial_tensor.name,
dist_tensor.serial_tensor.desc.id(),
dist_tensor.serial_tensor.desc.original_id(),
dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
......
......@@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost
class DistributedEmbedding(DistributedOperatorImplContainer):
......@@ -53,6 +56,95 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
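This entry point is exercised end to end by the unit test added at the bottom of this commit; below is a condensed sketch of that driver loop. It assumes main_program and dist_context come from the test's parallelizer helper, and the two import paths are assumptions about the package layout rather than documented APIs.

from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container)

cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)

for op in main_program.global_block().ops:
    dist_op = dist_context.get_dist_op_for_program(op)
    op_dist_attr = dist_op.dist_attr
    container = get_distributed_operator_impl_container(op_dist_attr.impl_type)
    dist_impl = container.impls[op_dist_attr.impl_idx]
    # Dispatches to calc_fwd_cost / calc_bwd_cost based on the op role.
    cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, dist_context,
                               cluster)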
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
# embedding needs start_index
cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx,
processes, desc_mapping,
cluster)
serial_op = dist_op.serial_op
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("W")[0])[0]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
main_block = backward_op.block
dist_attr = dist_op.dist_attr
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("W")[0])[0]
parallel_axis = embedding_row_dim_mapping
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = [backward_op.input("Out@GRAD")[0]]
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
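# i.e. the input's batch dim is sharded over a mesh axis with more than one process, so the data-parallel allreduce of W@GRAD must also be costed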
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('W@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
......@@ -13,6 +13,7 @@
# limitations under the License
import copy
from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
......@@ -35,6 +36,10 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
......@@ -58,6 +63,14 @@ def _update_dims_mapping_for_matmul(dist_op):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
trans_x = None
trans_y = None
if op_desc.type() == "matmul_v2":
trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y')
elif op_desc.type() == "matmul":
trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......@@ -67,27 +80,34 @@ def _update_dims_mapping_for_matmul(dist_op):
# Add dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
assert trans_x is False
x_dims_mapping.insert(0, -1)
out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
if y_dims_mapping_len == 1:
assert trans_y is False
y_dims_mapping.insert(1, -1)
out_dims_mapping.insert(out_dims_mapping_len, 0)
new_x_dims_mapping_len = len(x_dims_mapping)
new_y_dims_mapping_len = len(y_dims_mapping)
new_out_dims_mapping_len = len(out_dims_mapping)
# Deal with dim > 2 and take care of broadcasting
if out_dims_mapping_len > 2:
if new_out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(out_dims_mapping_len - x_dims_mapping_len):
for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(x_dims_mapping_len - 2):
for i in range(new_x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(out_dims_mapping_len - y_dims_mapping_len):
for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(y_dims_mapping_len - 2):
for i in range(new_y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(out_dims_mapping_len - 2):
for i in range(new_out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
......@@ -97,23 +117,30 @@ def _update_dims_mapping_for_matmul(dist_op):
if compatible_dims_mapping is None:
return False
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
for i in range(new_x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
x_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - y_dims_mapping_len)
for i in range(new_y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
y_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(out_dims_mapping_len - 2):
for i in range(new_out_dims_mapping_len - 2):
if out_dims_mapping[i] != compatible_dims_mapping[i]:
out_dims_mapping[i] = compatible_dims_mapping[i]
changed = True
if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-2], x_dims_mapping[-1]
if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-2], y_dims_mapping[-1]
# The following, which uses negative indices, works both when
# len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
dim_changed = compute_compatible_and_update_dim_mapping(
......@@ -131,11 +158,20 @@ def _update_dims_mapping_for_matmul(dist_op):
if dim_changed:
changed = True
if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-2], x_dims_mapping[-1]
if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-2], y_dims_mapping[-1]
# Remove the unnecessary dim mapping to make sure the length of dims_mapping matches its tensor
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
out_dims_mapping.pop(out_dims_mapping_len - 1)
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
out_dims_mapping.pop(out_dims_mapping_len)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
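To make the broadcasting and trans_x/trans_y handling in this function concrete, here is a small standalone illustration; the mappings are invented, and the snippet only mirrors the padding and swapping logic above (plain Python, no Paddle required).

# Invented mappings for illustration only:
#   x:   [batch, m, k]  dims_mapping [0, -1, -1]
#   y:   [k, n]         dims_mapping [-1, 1]
#   out: [batch, m, n]  dims_mapping [0, -1, -1]
x_dims_mapping = [0, -1, -1]
y_dims_mapping = [-1, 1]
out_dims_mapping = [0, -1, -1]
out_len, x_len, y_len = 3, 3, 2

# Align the leading (broadcast) dims exactly as the loops above do:
broadcast_x = ([out_dims_mapping[i] for i in range(out_len - x_len)]
               + [x_dims_mapping[i] for i in range(x_len - 2)])    # -> [0]
broadcast_y = ([out_dims_mapping[i] for i in range(out_len - y_len)]
               + [y_dims_mapping[i] for i in range(y_len - 2)])    # -> [0]
broadcast_out = [out_dims_mapping[i] for i in range(out_len - 2)]  # -> [0]

# With trans_y=True the last two entries of y's mapping are swapped before
# (and restored after) the compatibility computation:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[-2], y_dims_mapping[-1]
# y_dims_mapping is now [1, -1]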
......@@ -484,6 +520,102 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
has_x_grad = len(backward_op.output("X@GRAD")) > 0
if has_x_grad:
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# calc comm op cost
if has_x_grad:
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = backward_op.output("X@GRAD")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
res.append(comm_op_cost_list)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res_cost = [comm_op_cost_list, cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......@@ -710,6 +842,99 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
# calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......@@ -920,6 +1145,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__(name)
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
res_cost = [cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
......@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -28,6 +28,11 @@ from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comp_costs_from_descs
from ..cost import build_comm_costs_from_descs
from ..cost import Reshape2OpCost
from ..cost import Reshape2GradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedReshape2(DistributedOperatorImplContainer):
......@@ -46,6 +51,84 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
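The target-shape adjustment above is an element-wise division along the sharded axes; a standalone sketch with invented numbers:

# Invented numbers: a serial reshape target of [8, 2, 4] whose first axis is
# sharded over a 2-process mesh becomes a per-process target of [4, 2, 4].
shape_list = [8, 2, 4]
dim_mapping = [0, -1, -1]      # output dims mapping
process_mesh_shape = [2]       # 1-D mesh with 2 processes

for idx, axis in enumerate(dim_mapping):
    if axis >= 0 and len(shape_list) > idx:
        shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]

print(shape_list)              # [4, 2, 4]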
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the dim_mapping should be taken from the backward op's own input var, not from the corresponding var of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......@@ -199,6 +282,84 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
# NOTE: the dim_mapping should be taken from the backward op's own input var, not from the corresponding var of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......@@ -355,6 +516,84 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
# NOTE: the dim_mapping should be taken from the backward op's own input var, not from the corresponding var of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
......@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -23,6 +24,11 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import SoftmaxOpCost, SoftmaxGradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedSoftmax(DistributedOperatorImplContainer):
......@@ -41,6 +47,62 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(SoftmaxOpCost, ctx,
processes, desc_mapping,
cluster)
res_cost = [cost_mapping]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(SoftmaxGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the dim_mapping should be taken from the backward op's own input var, not from the corresponding var of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
......@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -23,6 +24,10 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedTranspose2(DistributedOperatorImplContainer):
......@@ -116,6 +121,63 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return changed
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Transpose2OpCost, ctx,
processes, desc_mapping,
cluster)
res_cost = [cost_mapping]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Transpose2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the dim_mapping should be taken from the backward op's own input var, not from the corresponding var of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
......
......@@ -47,7 +47,7 @@ def parallelizer(program_func, rank):
completer.complete_backward_annotation(main_program)
dist_context.block_state.parse_backward_blocks(main_program)
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
# generate opt and complete opt
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)
......@@ -59,7 +59,7 @@ def parallelizer(program_func, rank):
class TestDistOpCost(unittest.TestCase):
def test_dist_fill_constatnt_batch_size_like_op_cost(self):
def test_dist_op_cost_part1(self):
def make_program():
main_program = paddle.static.Program()
......@@ -79,7 +79,7 @@ class TestDistOpCost(unittest.TestCase):
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr()
linear = paddle.nn.Linear(8, 8, weight_attr=weight_attr)
linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr)
linear_out = linear(x)
gelu_out = paddle.nn.functional.gelu(linear_out)
# default op with dp
......@@ -109,6 +109,112 @@ class TestDistOpCost(unittest.TestCase):
dist_context, cluster)
self.assertTrue(dist_op_cost)
def test_dist_op_cost_part2(self):
def make_program():
main_program = paddle.static.Program()
start_program = paddle.static.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4], dtype='float32')
x.stop_gradient = True
label = paddle.static.data(name="label",
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp)
# row parallel embedding
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
# matmul
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0]
out8 = paddle.fluid.layers.transpose(out2,
[1, 0]) # [4, 8] [0, -1]
# reshape
out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1]
tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
out10 = paddle.reshape(tmp_reshape_out,
[8, 8]) # [4, 8] [0, -1]
# softmax
softmax = paddle.nn.Softmax()
out11 = softmax(out10)
error_cost = paddle.nn.functional.square_error_cost(
out11, label)
loss = paddle.mean(error_cost)
return main_program, start_program, loss
main_program, dist_context = parallelizer(make_program, 0)
ops = main_program.global_block().ops
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise")
else:
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, cluster)
self.assertTrue(dist_op_cost)
if __name__ == "__main__":
unittest.main()