未验证 提交 c9cd47d9 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Adapt Partitioner & DistOp for ERNIE3.0 Inference and cache (#39895)

* adapot dist op

* add dist_fill_constant_batch_size_like

* remvoe print

* update compitable

* add unitest
上级 6af2729e
......@@ -27,3 +27,4 @@ from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling
from . import dist_split
from . import dist_fill_constant_batch_size_like
......@@ -155,7 +155,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['Out'])
Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
Weight_var = main_block._var_recursive(kwargs['W'][0])
Out_var = main_block.var(kwargs['Out'][0])
# got dist attribute info
......@@ -277,7 +277,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# param initialization sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
if Weight_var.name in dist_op_context.already_init_sync_vars:
return
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0
class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedFillConstantBatchSizeLike, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedFillConstantBatchSizeLike("fill_constant_batch_size_like"))
class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
shape_list = op_desc.attr("shape")
if len(shape_list) != len(out_dims_mapping):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
in_name = op_desc.input('Input')[0]
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
# the dim_mapping of batch dimension should be the same
return out_dims_mapping[0] == in_dims_mapping[0]
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# only the batch size dimemsion of input and output are relative.
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [0, 0])
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
"""
kwargs: inputname_mapping & outputname_mapping
"""
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
dist_op_context = ctx.dist_op_context
src_op = dist_op_context.cur_src_op
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
main_block = dist_op_context.work_block
op = main_block.ops[-1]
assert op.type == "fill_constant_batch_size_like"
# modify shape attr according to how output are partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.topology
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
op._set_attr("shape", shape_list)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"fill_constant_batch_size_like",
DistributedFillConstantBatchSizeLikeImpl0("fill_by_shape"))
......@@ -433,8 +433,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
assert Weight_var.name not in dist_op_context.already_init_sync_vars, "{} is in {}.".format(
Weight_var.name, dist_op_context.already_init_sync_vars)
if Weight_var.name in dist_op_context.already_init_sync_vars:
return
assert startup_block.has_var(Weight_var.name)
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
......@@ -819,6 +819,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
......@@ -1323,6 +1325,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
out_var_dist_attr)
intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
......
......@@ -285,6 +285,9 @@ def _get_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
if mapping == []:
return var_shape
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
......
......@@ -174,6 +174,7 @@ def get_program():
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
......@@ -194,6 +195,17 @@ def get_program():
"dims_mapping": [-1, -1, -1]
})
# fill constant bsz like
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0)
auto.shard_tensor(
tmp,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0, -1, -1]
})
# model
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -395,6 +407,9 @@ def completion(train_program, start_program, dist_context):
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
......@@ -428,6 +443,12 @@ class TestMLP(unittest.TestCase):
dist_main_prog, dist_startup_prog = partition(
train_program, start_program, dist_context)
global_block_ops = dist_main_prog.blocks[0].ops
fill_op = None
for op in global_block_ops:
if op.type == "fill_constant_batch_size_like":
fill_op = op
global_block_ops = [op.type for op in global_block_ops]
sub_block_ops = dist_main_prog.blocks[1].ops
sub_block_ops = [op.type for op in sub_block_ops]
......@@ -435,6 +456,13 @@ class TestMLP(unittest.TestCase):
self.assertTrue("c_allreduce_sum" in global_block_ops)
self.assertTrue("c_allreduce_sum" in sub_block_ops)
# test fill_constant_batch_size_like
self.assertTrue(fill_op is not None)
ref_shape = [-1, 8, 0, 48]
shape = fill_op.attr("shape")
self.assertTrue(ref_shape == shape)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册