Unverified commit fc5fb2a1, authored by zhaoyingli and committed by GitHub

add dist_attr for dist op and var (#35585)

* add dist_attr for dist op

* add unit test

* update input name

* update function name

* add unit test

* update CMakeLists.txt for CI

* fix dist_matmul

* fix compile error

* update matmul to matmul_v2
Parent commit: 09eaa7d7
......@@ -30,13 +30,37 @@ using Tensor = framework::Tensor;
template <typename T1, typename T2, typename OutType>
class GpuAndCpuSearchSortedCompute {
public:
static HOSTDEVICE bool IsNan(float x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(double x) { return ::isnan(x); }
static HOSTDEVICE bool IsNan(float x) {
#ifdef __NVCC__
return ::isnan(x);
#else
return std::isnan(x);
#endif
}
static HOSTDEVICE bool IsNan(double x) {
#ifdef __NVCC__
return ::isnan(x);
#else
return std::isnan(x);
#endif
}
static HOSTDEVICE bool IsNan(int x) { return false; }
static HOSTDEVICE bool IsNan(int64_t x) { return false; }
static HOSTDEVICE bool IsInf(float x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(double x) { return ::isinf(x); }
static HOSTDEVICE bool IsInf(float x) {
#ifdef __NVCC__
return ::isinf(x);
#else
return std::isinf(x);
#endif
}
static HOSTDEVICE bool IsInf(double x) {
#ifdef __NVCC__
return ::isinf(x);
#else
return std::isinf(x);
#endif
}
static HOSTDEVICE bool IsInf(int x) { return false; }
static HOSTDEVICE bool IsInf(int64_t x) { return false; }
......
......@@ -114,3 +114,50 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
"""
copy src var's dist_attr to dst var
"""
import copy
auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
"""
copy src op's dist_attr to dst dist op
"""
from ..attribute import OperatorDistributedAttribute
auto_paralle_context = src_op_dist_attr.get_owner_context()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)
op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping)
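
The distributed operator implementations later in this diff call these two helpers right after creating the intermediate tensor and appending the replacement ops. A condensed, illustrative sketch of that calling pattern (it mirrors the column-parallel matmul_v2 implementation below; `dst_block`, `op_dist_attr`, `X_var`, `Weight_var`, and `Out_var` are assumed to exist as in that code, and communication attrs such as the ring id are omitted for brevity):

intermediate_var_0 = dst_block.create_var(
    name=unique_name.generate_with_ignorable_key("c_identity.tmp"),
    dtype=X_var.dtype,
    persistable=False,
    stop_gradient=X_var.stop_gradient)
# the new intermediate tensor inherits X_var's tensor dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var)

c_identity_op = dst_block.append_op(
    type='c_identity',
    inputs={'X': [X_var]},
    outputs={'Out': intermediate_var_0})
matmul_v2_op = dst_block.append_op(
    type='matmul_v2',
    inputs={'X': [intermediate_var_0], 'Y': [Weight_var]},
    outputs={'Out': Out_var},
    attrs={'trans_x': False, 'trans_y': False})
# each newly appended op inherits the serial op's process_mesh, impl_idx
# and per-tensor dims_mapping
copy_distributed_attr_for_dist_op(c_identity_op, dst_block, op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, op_dist_attr)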
......@@ -16,6 +16,8 @@ from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -173,13 +175,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)
check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')
dst_block.append_op(
c_embedding_op = dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
......@@ -187,7 +192,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
attrs={"start_index": relative_idx})
# use_model_parallel
dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
......@@ -197,6 +202,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_model_parallel': True,
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
......
......@@ -16,6 +16,8 @@ from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
......@@ -223,13 +225,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
X_var)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
c_identity_op = dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
......@@ -250,12 +255,18 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'alpha': 1,
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
matmul_op = dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_op, dst_block,
op_dist_attr)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
......@@ -369,13 +380,17 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)
matmul_op = dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
......@@ -385,6 +400,12 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'use_model_parallel': True
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
......@@ -540,15 +561,12 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 5", group)
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
......@@ -558,13 +576,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
X_var)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
c_identity_op = dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
......@@ -581,12 +602,18 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
matmul_v2_op = dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
op_dist_attr)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
......@@ -675,15 +702,12 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
from ..process import new_process_group
from ..transpiler import _get_comm_group
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.topology,
model_parallel_axis,
process_mesh.process_group, rank_id)
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 4", group)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
......@@ -699,13 +723,17 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
Out_var)
matmul_v2_op = dst_block.append_op(
type='matmul_v2',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
c_allreduce_sum_op = dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
......@@ -715,6 +743,12 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
'use_model_parallel': True
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
op_dist_attr)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
......
......@@ -582,6 +582,8 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
endif(NOT WIN32)
endif(NOT APPLE)
if(WITH_DGC)
......
......@@ -135,6 +135,127 @@ def initialization_check(mode, dist_context, dist_startup_prog,
return True
def get_input_var_dist_attr(op, main_program, dist_context):
varname = op.desc.input_arg_names()
var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var)
return dist_attr
def get_output_var_dist_attr(op, main_program, dist_context):
varname = op.desc.output_arg_names()
var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var)
return dist_attr
def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
equal = True
if serial_dist_attr.get_process_mesh() != dist_attr.get_process_mesh() or \
serial_dist_attr.is_parameter() != dist_attr.is_parameter() or \
serial_dist_attr.get_dims_mapping() != dist_attr.get_dims_mapping():
equal = False
return equal
def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
dist_op_idx):
equal = True
# get serial op's process_mesh and impl_idx
serial_op_dist_attr = dist_context.get_op_distributed_attr_for_program(
serial_op)
serial_process_mesh = serial_op_dist_attr.get_process_mesh()
serial_impl_idx = serial_op_dist_attr.get_impl_idx()
# check dist_attr between serial op and dist op
for i in dist_op_idx:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
dist_ops[i])
for in_varname in dist_ops[i].desc.input_arg_names():
in_var = dist_main_prog.global_block().var(in_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
in_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
in_var_dims_mapping = op_dist_attr.get_input_dims_mapping(
in_varname)
if tensor_dims_mapping != in_var_dims_mapping:
equal = False
for out_varname in dist_ops[i].desc.output_arg_names():
out_var = dist_main_prog.global_block().var(out_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
out_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
out_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_varname)
if tensor_dims_mapping != out_var_dims_mapping:
equal = False
dist_op_process_mesh = op_dist_attr.get_process_mesh()
dist_op_impl_idx = op_dist_attr.get_impl_idx()
if serial_op.desc.id() == dist_ops[i].desc.id() or \
serial_process_mesh != dist_op_process_mesh or \
serial_impl_idx != dist_op_impl_idx:
equal = False
return equal
def distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx):
equal = True
serial_ops = serial_main_prog.global_block().ops
dist_ops = dist_main_prog.global_block().ops
for i in range(len(serial_op_idx)):
serial_op = serial_ops[serial_op_idx[i]]
dist_op_0 = dist_ops[dist_op_idx[i][0]]
if dist_op_0.type == "c_identity":
# serial op input's dist_attr
serial_in_dist_attr = get_input_var_dist_attr(
serial_op, serial_main_prog, dist_context)
# c_identity output's(new var) dist_attr
identity_out_dist_attr = get_output_var_dist_attr(
dist_op_0, dist_main_prog, dist_context)
# check var dist_attr
equal = check_equal_var_dist_attr(serial_in_dist_attr,
identity_out_dist_attr)
else:
# serial op output's dist_attr
serial_out_dist_attr = get_output_var_dist_attr(
serial_op, serial_main_prog, dist_context)
# dist op output's(new var) dist_attr
out_dist_attr = get_output_var_dist_attr(dist_op_0, dist_main_prog,
dist_context)
# check var dist_attr
equal = check_equal_var_dist_attr(serial_out_dist_attr,
out_dist_attr)
# check op's dist_attr
equal = check_equal_dist_op_attr(dist_context, dist_main_prog,
serial_op, dist_ops, dist_op_idx[i])
return equal
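
For reference, the arguments this check receives from the MLP model-parallel test further below look like this (values copied from that test; the comments are an interpretation based on its ref_ops list):

serial_op_idx = [1, 4]           # the two matmul ops of the serial MLP program
dist_op_idx = [[1, 2], [5, 6]]   # their replacements in the dist program:
                                 #   [1, 2] -> (c_identity, matmul_v2)
                                 #   [5, 6] -> (matmul_v2, c_allreduce_sum)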
def distributed_attr_check_for_program(dist_main_prog, dist_context):
have_dist_attr = True
for block in dist_main_prog.blocks:
for tensor in block.vars.values():
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor)
if var_dist_attr is None:
have_dist_attr = False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
have_dist_attr = False
return have_dist_attr
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
......@@ -276,8 +397,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout'
'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout'
]
self.assertTrue(dist_ops == ref_ops)
......@@ -289,6 +410,17 @@ class TestMLPAutoPartitioner(unittest.TestCase):
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
# check that all vars and ops in dist_main_program have dist_attr
self.assertTrue(
distributed_attr_check_for_program(dist_main_prog, dist_context))
# check distributed attr for dist op
serial_op_idx = [1, 4]
dist_op_idx = [[1, 2], [5, 6]]
self.assertTrue(
distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx))
def test_mlp_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
......@@ -323,8 +455,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout'
'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout'
]
self.assertTrue(dist_ops == ref_ops)
......@@ -336,6 +468,17 @@ class TestMLPAutoPartitioner(unittest.TestCase):
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
# check that all vars and ops in dist_main_program have dist_attr
self.assertTrue(
distributed_attr_check_for_program(dist_main_prog, dist_context))
# check distributed attr for dist op
serial_op_idx = [1, 4]
dist_op_idx = [[1, 2], [5, 6]]
self.assertTrue(
distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx))
class AttentionLayer(nn.Layer):
def __init__(self,
......@@ -531,12 +674,12 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum',
'elementwise_add'
'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add',
'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax',
'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
'c_allreduce_sum', 'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
......@@ -547,6 +690,17 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
# check that all vars and ops in dist_main_program have dist_attr
self.assertTrue(
distributed_attr_check_for_program(dist_main_prog, dist_context))
# check distributed attr for dist op
serial_op_idx = [0, 4, 6, 18]
dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
self.assertTrue(
distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx))
def test_attn_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
......@@ -582,12 +736,12 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum',
'elementwise_add'
'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add',
'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax',
'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
'c_allreduce_sum', 'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
......@@ -598,6 +752,17 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
# check that all vars and ops in dist_main_program have dist_attr
self.assertTrue(
distributed_attr_check_for_program(dist_main_prog, dist_context))
# check distributed attr for dist op
serial_op_idx = [0, 4, 6, 18]
dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
self.assertTrue(
distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx))
class DecoderLayer(nn.Layer):
def __init__(self,
......@@ -859,15 +1024,16 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_embedding', 'c_allreduce_sum', 'lookup_table_v2',
'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul',
'elementwise_add', 'c_identity', 'matmul', 'elementwise_add',
'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul',
'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout',
'elementwise_add', 'layer_norm', 'c_identity', 'matmul',
'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum',
'elementwise_add', 'dropout', 'elementwise_add'
'elementwise_add', 'dropout', 'layer_norm', 'c_identity',
'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
'c_identity', 'matmul_v2', 'elementwise_add', 'c_identity',
'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout',
'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add',
'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout',
'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
......@@ -881,6 +1047,18 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
# check that all vars and ops in dist_main_program have dist_attr
self.assertTrue(
distributed_attr_check_for_program(dist_main_prog, dist_context))
# check distributed attr
serial_op_idx = [0, 5, 9, 11, 23, 28, 31]
dist_op_idx = [[0, 1], [6, 7], [11, 12], [14, 15], [27, 28], [33, 34],
[37, 38]]
self.assertTrue(
distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
dist_context, serial_op_idx,
dist_op_idx))
def test_decoder_noparallel(self):
global _global_parallel_strategy
_global_parallel_strategy = "None"
......@@ -923,13 +1101,13 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout',
'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'matmul', 'elementwise_add', 'matmul', 'elementwise_add',
'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul',
'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2',
'matmul', 'elementwise_add', 'dropout', 'elementwise_add',
'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul',
'elementwise_add', 'dropout', 'elementwise_add'
'layer_norm', 'matmul_v2', 'elementwise_add', 'reshape2',
'transpose2', 'matmul_v2', 'elementwise_add', 'matmul_v2',
'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
'transpose2', 'reshape2', 'matmul_v2', 'elementwise_add', 'dropout',
'elementwise_add', 'layer_norm', 'matmul_v2', 'elementwise_add',
'gelu', 'matmul_v2', 'elementwise_add', 'dropout', 'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
dist_ops = dist_startup_prog.global_block().ops
......