Unverified · Commit ec1e0d5a authored by caozhou, committed by GitHub

add dist op costs (#44701)

Parent fecbc958
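This change set wires per-operator cost classes into each distributed operator implementation through a new calc_cost(op_role, dist_op, ctx, cluster) entry point. Below is a minimal sketch of how that API is driven, adapted from the unit test added at the end of this diff; the helper name estimate_dist_op_costs and the import paths are assumptions (the test's import block is collapsed in this view) rather than a stable public API.

# Sketch only: the module paths are assumed to mirror the added test's imports.
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container, is_elementwise_op)


def estimate_dist_op_costs(main_program, dist_context, device_count=2):
    """Collect the role-aware cost of every op in an annotated program."""
    cluster = Cluster()
    cluster.gen_default_config_cluster(device_count=device_count)
    costs = []
    for op in main_program.global_block().ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        op_dist_attr = dist_op.dist_attr
        # Elementwise ops share a single container; others are keyed by impl_type.
        if is_elementwise_op(op.type):
            container = get_distributed_operator_impl_container("elementwise")
        else:
            container = get_distributed_operator_impl_container(
                op_dist_attr.impl_type)
        dist_impl = container.impls[op_dist_attr.impl_idx]
        # calc_cost dispatches to calc_fwd_cost / calc_bwd_cost based on the op role.
        costs.append(
            dist_impl.calc_cost(op.attr('op_role'), dist_op, dist_context,
                                cluster))
    return costs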
@@ -12,20 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import _g_op_cost_factory
from .base_cost import Cost
from .base_cost import CommContext
from .base_cost import _g_op_cost_factory
from .base_cost import build_comm_desc
from .base_cost import build_comp_desc_from_op
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_dp_costs
from .base_cost import build_comp_desc_str_for_predict
from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_comm_desc_from_dist_op
from .base_cost import build_comm_costs_from_descs
from .base_cost import build_comp_costs_from_descs
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
from .comp_op_cost import EmbeddingOpCost
from .comp_op_cost import EmbeddingGradOpCost
from .comp_op_cost import ConcatOpCost
from .comp_op_cost import MatmulOpCost
from .comp_op_cost import MatmulGradOpCost
from .comp_op_cost import MatmulV2OpCost
from .comp_op_cost import MatmulV2GradOpCost
from .comp_op_cost import MulOpCost
from .comp_op_cost import MulGradOpCost
from .comp_op_cost import Reshape2OpCost
from .comp_op_cost import Reshape2GradOpCost
from .comp_op_cost import SliceOpCost
from .comp_op_cost import SplitOpCost
from .comp_op_cost import SoftmaxOpCost
from .comp_op_cost import SoftmaxGradOpCost
from .comp_op_cost import Transpose2OpCost
from .comp_op_cost import Transpose2GradOpCost
from .comp_op_cost import FillConstantBatchSizeLikeOpCost
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
from .comm_op_cost import SendOpCost
from .comm_op_cost import RecvOpCost
from .comm_op_cost import IdentityOpCost
......
@@ -15,6 +15,25 @@
from .base_cost import Cost, register_op_cost, CompOpCost, _g_op_cost_factory
@register_op_cost
class AdamOpCost(CompOpCost):
OP_TYPE = "adam"
def __init__(self, op=None, op_desc=None, cluster=None):
super(AdamOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
......
@@ -831,8 +831,10 @@ class DistributedContext:
if (dist_tensor
is not None) and (not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name,
dist_tensor.serial_tensor.desc.id(),
dist_tensor.serial_tensor.desc.original_id(),
dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
assert dist_op is not None, \
......
@@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost
class DistributedEmbedding(DistributedOperatorImplContainer):
@@ -53,6 +56,95 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
"""Calculate the cost by the op role."""
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
# embedding needs start_index
cost_mapping = build_comp_costs_from_descs(EmbeddingOpCost, ctx,
processes, desc_mapping,
cluster)
serial_op = dist_op.serial_op
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("W")[0])[0]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
main_block = backward_op.block
dist_attr = dist_op.dist_attr
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("W")[0])[0]
parallel_axis = embedding_row_dim_mapping
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = [backward_op.input("Out@GRAD")[0]]
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(EmbeddingGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('W@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
@@ -13,6 +13,7 @@
# limitations under the License
import copy
from .common import infer_shape from .common import infer_shape
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
@@ -35,6 +36,10 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
@@ -58,6 +63,14 @@ def _update_dims_mapping_for_matmul(dist_op):
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
trans_x = None
trans_y = None
if op_desc.type() == "matmul_v2":
trans_x = op_desc.attr('trans_x')
trans_y = op_desc.attr('trans_y')
elif op_desc.type() == "matmul":
trans_x = op_desc.attr('transpose_X')
trans_y = op_desc.attr('transpose_Y')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
@@ -67,27 +80,34 @@ def _update_dims_mapping_for_matmul(dist_op):
# Add dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
assert trans_x is False
x_dims_mapping.insert(0, -1)
out_dims_mapping.insert(out_dims_mapping_len - 1, 0)
if y_dims_mapping_len == 1:
assert trans_y is False
y_dims_mapping.insert(1, -1)
out_dims_mapping.insert(out_dims_mapping_len, 0)
new_x_dims_mapping_len = len(x_dims_mapping)
new_y_dims_mapping_len = len(y_dims_mapping)
new_out_dims_mapping_len = len(out_dims_mapping)
# Deal with dim > 2 and take care of broadcasting
if new_out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(new_out_dims_mapping_len - new_x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(new_x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(new_out_dims_mapping_len - new_y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(new_y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(new_out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
@@ -97,23 +117,30 @@ def _update_dims_mapping_for_matmul(dist_op):
if compatible_dims_mapping is None:
return False
for i in range(new_x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - new_x_dims_mapping_len)
if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
x_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(new_y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - new_y_dims_mapping_len)
if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
y_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(new_out_dims_mapping_len - 2):
if out_dims_mapping[i] != compatible_dims_mapping[i]:
out_dims_mapping[i] = compatible_dims_mapping[i]
changed = True
if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-2], x_dims_mapping[-1]
if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-2], y_dims_mapping[-1]
# The following, which uses negative indices, works both
# when len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
dim_changed = compute_compatible_and_update_dim_mapping(
@@ -131,11 +158,20 @@ def _update_dims_mapping_for_matmul(dist_op):
if dim_changed:
changed = True
if trans_x:
x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[
-2], x_dims_mapping[-1]
if trans_y:
y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[
-2], y_dims_mapping[-1]
# Remove the unnecessary dim mapping to make sure the length of dims_mapping matches its tensor
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
out_dims_mapping.pop(out_dims_mapping_len - 1)
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
out_dims_mapping.pop(out_dims_mapping_len)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
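The trans_x / trans_y handling added in this hunk swaps the last two entries of a dims_mapping before the compatibility pass and swaps them back afterwards, so a transposed operand is reasoned about in its effective (row, col) layout. A standalone toy illustration of that swap (plain Python lists; the helper name with_trans is invented for this sketch and is not part of the patch):

# Toy sketch of the trans_x / trans_y trick used above: swap the last two
# mapped dims before inferring compatibility, then swap back afterwards.
def with_trans(dims_mapping, trans):
    m = list(dims_mapping)
    if trans:
        m[-1], m[-2] = m[-2], m[-1]
    return m

x_dims_mapping = [0, -1, 1]  # batch dim sharded on mesh axis 0, last dim on axis 1
print(with_trans(x_dims_mapping, True))                    # [0, 1, -1]
print(with_trans(with_trans(x_dims_mapping, True), True))  # round-trips to [0, -1, 1]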
@@ -484,6 +520,102 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
has_x_grad = len(backward_op.output("X@GRAD")) > 0
if has_x_grad:
assert len(backward_op.output("X@GRAD")) == 1
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# calc comm op cost
if has_x_grad:
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = backward_op.output("X@GRAD")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes,
c_allreduce_sum_desc_mapping, cluster)
res.append(comm_op_cost_list)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-1]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.input("X")
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res_cost = [comm_op_cost_list, cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
@@ -710,6 +842,99 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# for now the backward function only inserts the gradient allreduce for the dist op itself
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Y")[0])
assert Y_var_dim_mapping[1] < 0
parallel_axis = Y_var_dim_mapping[0]
# calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
c_identity_desc_mapping = build_comm_desc_from_dist_op(
"c_identity",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster)
res.append(comm_op_cost_list)
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
# calc comm op cost
serial_op = dist_op.serial_op
vars = serial_op.block.vars
parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
serial_op.input("Y")[0])[-2]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
var_names = serial_op.output("Out")
c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, ctx, processes, c_allreduce_sum_desc_mapping,
cluster)
res_cost = [cost_mapping, comm_op_cost_list]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
@@ -920,6 +1145,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__(name)
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Forward):
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
elif int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
res = []
backward_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
main_block = backward_op.block
vars = main_block.vars
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
# need gradient allreduce
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0])
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[
batch_size_axis] > 1 and is_parameter_related(
backward_op.input("Y")[0], main_block):
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [backward_op.output('Y@GRAD')[0]]
build_dp_costs(res, dist_op, ctx, var_names, attrs, parallel_axis,
cluster)
return res
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(MatmulOpCost, ctx, processes,
desc_mapping, cluster)
res_cost = [cost_mapping]
return res_cost
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -28,6 +28,11 @@ from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comp_costs_from_descs
from ..cost import build_comm_costs_from_descs
from ..cost import Reshape2OpCost
from ..cost import Reshape2GradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedReshape2(DistributedOperatorImplContainer):
@@ -46,6 +51,84 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the backward op's input var dim_mapping should match the input var itself, not the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
@@ -199,6 +282,84 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
# NOTE: the backward op's input var dim_mapping should match the input var itself, not the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
@@ -355,6 +516,84 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
res = []
op = dist_op.serial_op
vars = op.block.vars
dist_attr = dist_op.dist_attr
shape_list = op.desc.attr("shape")
# get dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[
idx] = shape_list[idx] // process_mesh_shape[axis]
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_attr.process_mesh.processes
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
cost_mapping = build_comp_costs_from_descs(Reshape2OpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
return res
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Reshape2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not is_parameter_related(
varname, main_block):
# NOTE: the backward op's input var dim_mapping should match the input var itself, not the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -23,6 +24,11 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import SoftmaxOpCost, SoftmaxGradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedSoftmax(DistributedOperatorImplContainer):
@@ -41,6 +47,62 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = False
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
cost_mapping = build_comp_costs_from_descs(SoftmaxOpCost, ctx,
processes, desc_mapping,
cluster)
res_cost = [cost_mapping]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
cost_mapping = build_comp_costs_from_descs(SoftmaxGradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the backward op's input var dim_mapping should match the input var itself, not the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
......
@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
@@ -23,6 +24,10 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from paddle.distributed.fleet.meta_optimizers.common import OpRole
class DistributedTranspose2(DistributedOperatorImplContainer):
@@ -116,6 +121,63 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return changed
def calc_cost(self, op_role, dist_op, ctx, cluster):
cost = None
if int(op_role) == int(OpRole.Backward):
cost = self.calc_bwd_cost(dist_op, ctx, cluster)
else:
cost = self.calc_fwd_cost(dist_op, ctx, cluster)
assert cost is not None
return cost
def calc_fwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
processes = dist_op.dist_attr.process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Transpose2OpCost, ctx,
processes, desc_mapping,
cluster)
res_cost = [cost_mapping]
return res_cost
def calc_bwd_cost(self, dist_op, ctx, cluster):
# calc comp op cost
res = []
desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
dist_context=ctx)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(Transpose2GradOpCost, ctx,
processes, desc_mapping,
cluster)
res.append(cost_mapping)
backward_op = dist_op.serial_op
main_block = backward_op.block
need_gradient_allreduce = False
vars = main_block.vars
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: the backward op's input var dim_mapping should match the input var itself, not the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
var_names = [varname + "@GRAD"]
build_dp_costs(res, dist_op, ctx, var_names, attrs,
parallel_axis, cluster)
return res
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
......
@@ -47,7 +47,7 @@ def parallelizer(program_func, rank):
completer.complete_backward_annotation(main_program)
dist_context.block_state.parse_backward_blocks(main_program)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
# generate opt and complete opt
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)
@@ -59,7 +59,7 @@ def parallelizer(program_func, rank):
class TestDistOpCost(unittest.TestCase):
def test_dist_op_cost_part1(self):
def make_program():
main_program = paddle.static.Program()
@@ -79,7 +79,7 @@ class TestDistOpCost(unittest.TestCase):
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr()
linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr)
linear_out = linear(x)
gelu_out = paddle.nn.functional.gelu(linear_out)
# default op with dp
@@ -109,6 +109,112 @@ class TestDistOpCost(unittest.TestCase):
dist_context, cluster)
self.assertTrue(dist_op_cost)
def test_dist_op_cost_part2(self):
def make_program():
main_program = paddle.static.Program()
start_program = paddle.static.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4], dtype='float32')
x.stop_gradient = True
label = paddle.static.data(name="label",
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp)
# row parallel embedding
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
# matmul
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0]
out8 = paddle.fluid.layers.transpose(out2,
[1, 0]) # [4, 8] [0, -1]
# reshape
out9 = paddle.reshape(out8, [8, 2, 4]) # [4, 2, 4] [0, -1, -1]
tmp_reshape_out = paddle.reshape(out9, [8, 4, 2])
out10 = paddle.reshape(tmp_reshape_out,
[8, 8]) # [4, 8] [0, -1]
# softmax
softmax = paddle.nn.Softmax()
out11 = softmax(out10)
error_cost = paddle.nn.functional.square_error_cost(
out11, label)
loss = paddle.mean(error_cost)
return main_program, start_program, loss
main_program, dist_context = parallelizer(make_program, 0)
ops = main_program.global_block().ops
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise")
else:
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, cluster)
self.assertTrue(dist_op_cost)
if __name__ == "__main__":
unittest.main()