Unverified commit 31a437b1 authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt for gpt-gen (#46771)

* for gpt-gen

* fix reshard

* adapt assign and shape op

* add dist_assign & unittest

* add conditional block unittest

* rename unittest
Parent eee6b3a7
......@@ -17,7 +17,7 @@ import time
from paddle.fluid import core
from .utils import is_gradient_clip_op
from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import _node_id
from .dist_attribute import TensorDistributedAttribute
......@@ -491,14 +491,14 @@ class Completer:
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
......
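A note on the completer change above: the exclusion check widens from READER only to __not_shape_var_type__ (READER and STEP_SCOPES, defined in utils later in this diff), so a node stays on this path only if every input and output tensor is 1-D and is not one of those bookkeeping variable types. A minimal standalone sketch of that filter, with dict-based tensor descriptions standing in for the real graph-node API (all_one_dim_shape_vars is a hypothetical helper, not Paddle code):

NOT_SHAPE_VAR_TYPES = {"READER", "STEP_SCOPES"}

def all_one_dim_shape_vars(tensor_nodes):
    # Keep the node only if every tensor is 1-D and is not a
    # reader/step-scopes bookkeeping variable.
    for t in tensor_nodes:
        if t is None:
            continue
        if t["type"] in NOT_SHAPE_VAR_TYPES or len(t["shape"]) != 1:
            return False
    return True

# A STEP_SCOPES var (introduced by control flow) now disqualifies the node.
print(all_one_dim_shape_vars([{"type": "LOD_TENSOR", "shape": [4]}]))   # True
print(all_one_dim_shape_vars([{"type": "STEP_SCOPES", "shape": []}]))   # False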
......@@ -1139,8 +1139,10 @@ class Engine:
self.to_mode(mode)
if inputs or labels:
self._skip_build = True
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._inputs, self._labels = self._prepare_data_tensor(
inputs_spec, labels_spec, inputs, labels)
self._inputs_spec, self._labels_spec, inputs, labels)
self._orig_main_prog = main_program
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
......@@ -1152,9 +1154,11 @@ class Engine:
else:
self._switch_mode(self._mode)
elif inputs_spec or labels_spec:
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._outside_dataloader = True
self._inputs, self._labels = self._prepare_data_tensor(
inputs_spec, labels_spec)
self._inputs_spec, self._labels_spec)
self._orig_main_prog = main_program
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
......
......@@ -33,3 +33,5 @@ from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedAssign(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedAssign, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedAssign("assign"))
class DistributedAssignImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedAssignImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if x_dims_mapping != out_dims_mapping:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("assign", DistributedAssignImpl("assign"))
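For reference, update_dims_mapping above reconciles the X and Out dims mappings of assign index by index via compute_compatible_and_update_dim_mapping. A simplified, self-contained sketch of that reconciliation (compatible_dim and update_assign_dims_mapping below are stand-ins for the real helper in ..utils, assuming -1 means replicated and a non-negative value names a mesh axis):

def compatible_dim(d1, d2):
    # -1 means replicated; a sharded mapping wins over a replicated one,
    # and two different shard mappings are incompatible (None).
    if d1 == d2:
        return d1
    if d1 == -1:
        return d2
    if d2 == -1:
        return d1
    return None

def update_assign_dims_mapping(x_dims_mapping, out_dims_mapping):
    # Propagate dims mapping between assign's X and Out; return True if changed.
    changed = False
    for i in range(len(x_dims_mapping)):
        merged = compatible_dim(x_dims_mapping[i], out_dims_mapping[i])
        if merged is None:
            continue  # incompatible; leave for other rules or another pass
        if x_dims_mapping[i] != merged or out_dims_mapping[i] != merged:
            x_dims_mapping[i] = out_dims_mapping[i] = merged
            changed = True
    return changed

# Example: X sharded on dim 0, Out fully replicated -> Out picks up the shard.
x, out = [0, -1, -1], [-1, -1, -1]
print(update_assign_dims_mapping(x, out), x, out)  # True [0, -1, -1] [0, -1, -1]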
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import is_dim_shard
class DistributedShape(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedShape, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedShape("shape"))
class DistributedShapeImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedShapeImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
return True
def update_dims_mapping(self, dist_op):
return False
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("shape", DistributedShapeImpl("shape"))
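The output-compatibility rule above encodes that the shape op's Out is a 1-D tensor holding the global shape values, so it must stay replicated; sharding it would leave each rank with only part of the shape. A small standalone sketch of the check (is_dim_shard is reimplemented here under its usual meaning: -1 is replicated, a non-negative value is sharded on that mesh axis):

def is_dim_shard(mapping):
    # -1 means replicated; any non-negative value names a mesh axis.
    return mapping != -1

def shape_output_compatible(out_dims_mapping):
    # Out of the shape op is a 1-D tensor of shape values; it must stay replicated.
    assert len(out_dims_mapping) == 1
    return not is_dim_shard(out_dims_mapping[0])

print(shape_output_compatible([-1]))  # True: replicated output is allowed
print(shape_output_compatible([0]))   # False: sharding the shape tensor is rejected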
......@@ -34,6 +34,7 @@ _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
_g_gradient_clip_ops = [
"sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
]
_g_subblock_ops = ["while", "conditional_block"]
def get_var_with_recursion(var_name, block, program):
......@@ -42,11 +43,11 @@ def get_var_with_recursion(var_name, block, program):
if var_name in block.vars:
var = block.vars[var_name]
else:
parent_block = program.blocks[block.parent_idx]
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None, \
"{} is not found".format(var.name)
var = block._var_recursive(var_name)
# parent_block = program.blocks[block.parent_idx]
# if var_name in parent_block.vars:
# var = parent_block.vars[var_name]
assert var is not None, "{} is not found".format(var_name)
return var
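get_var_with_recursion now defers to block._var_recursive, which searches the current block and then every ancestor block rather than only the immediate parent. A dict-based sketch of that lookup over a hypothetical Block class (not Paddle's real Block API):

class Block:
    # Hypothetical stand-in: a block with local vars and an optional parent.
    def __init__(self, vars, parent=None):
        self.vars = vars
        self.parent = parent

    def var_recursive(self, name):
        # Search this block, then every ancestor, mirroring _var_recursive.
        block = self
        while block is not None:
            if name in block.vars:
                return block.vars[name]
            block = block.parent
        raise ValueError("{} is not found".format(name))

main = Block({"x": "tensor_x"})
sub = Block({"cond": "tensor_cond"}, parent=main)
print(sub.var_recursive("x"))  # found in the parent (main) block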
......@@ -1075,7 +1076,9 @@ class Resharder:
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
if output_name.find(var_name) != -1 and (
len(var_name) == len(output_name)
or "@RESHARD" in output_name):
if output_name not in new_Out:
new_Out.append(output_name)
assert new_Out
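The new_Out filter above becomes stricter: a sub-block output is kept for a while-op Out variable only on an exact name match, or when the output is a resharded copy carrying the "@RESHARD" marker; a bare substring hit no longer counts. A tiny sketch of the predicate with made-up variable names:

def matches_while_out(var_name, output_name):
    # Exact match, or a resharded copy of the variable (e.g. "tmp_0@RESHARD_0").
    return output_name.find(var_name) != -1 and (
        len(var_name) == len(output_name) or "@RESHARD" in output_name)

print(matches_while_out("tmp_0", "tmp_0"))            # True: exact match
print(matches_while_out("tmp_0", "tmp_0@RESHARD_0"))  # True: resharded copy
print(matches_while_out("tmp_0", "tmp_01"))           # False: substring-only hit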
......@@ -1104,13 +1107,15 @@ class Resharder:
return False
def is_condition_replicative(self, op):
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op.type == "while":
input_cond = op.input("Condition")
elif op.type == "conditional_block":
input_cond = op.input("Cond")
# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
for var_name in input_cond:
var = get_var_with_recursion(var_name, sub_block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
......@@ -1660,9 +1665,9 @@ class Resharder:
op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append)
def _get_while_op_input_attrs(self, op, var_name):
def _get_subblock_input_attrs(self, op, var_name):
# NOTE: Multiple while loops are not supported
assert op.type == "while"
assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
input_attrs = []
......@@ -1713,8 +1718,8 @@ class Resharder:
def get_op_input_attrs(self, op, var_name):
op_input_attrs = []
if op.type == "while":
op_input_attrs = self._get_while_op_input_attrs(op, var_name)
if op.type in _g_subblock_ops:
op_input_attrs = self._get_subblock_input_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
......@@ -1818,7 +1823,7 @@ class Resharder:
if dist_op is not None:
op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while":
if op.type in _g_subblock_ops:
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
......@@ -1832,6 +1837,8 @@ class Resharder:
if op.type == "while":
# the condition var's process mesh is the same as the op's and its dims_mapping is replicative, so it does not need reshard
input_var_names = op.input("X")
elif op.type == "conditional_block":
input_var_names = op.input("Input")
else:
input_var_names = op.input_arg_names
# to avoid while op X order different
......@@ -1984,11 +1991,12 @@ class Resharder:
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"create_py_reader", "create_double_buffer_reader", "read",
"write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
skip_ops += _g_subblock_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
......
......@@ -27,6 +27,10 @@ from paddle.distributed.auto_parallel.process_group import get_all_process_group
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
......
......@@ -96,5 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.static import InputSpec
from paddle.distributed.fleet import auto
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=64,
intermediate_size=4 * 64,
initializer_range=0.02):
super(MLPLayer, self).__init__()
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
self.linear0 = nn.Linear(
hidden_size,
intermediate_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
self.linear1 = nn.Linear(
intermediate_size,
hidden_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.ProcessMesh([0, 1], "x"),
[None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.ProcessMesh([0, 1], "x"),
["x", None])
out = self.linear1(out)
if paddle.mean(out) < 2:
out = self.norm(out)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
else:
out = self.norm(out)
out = self.linear0(out)
out = self.linear1(out)
return out
def loss_fn(predict, label):
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss
class TestSubblock(unittest.TestCase):
def test_subblock(self):
mlp = MLPLayer()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(model=mlp, loss=loss_fn, strategy=strategy)
input_spec = InputSpec([4, 64], 'float32', 'input')
label_spec = InputSpec([4, 1], 'float32', 'label')
engine.prepare(inputs_spec=[input_spec],
labels_spec=[label_spec],
mode="predict")
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.distributed.fleet import auto
paddle.enable_static()
def make_program():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
y = paddle.static.data(name='y', shape=[4, 4, 8], dtype='float32')
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["d"]),
[None, "d", None])
z = paddle.add(x, y)
paddle.assign(x, output=z)
return main_program, start_program
def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext
main_program, start_program = program_func()
dist_context = DistributedContext()
completer = Completer(dist_context)
completer.complete_forward_annotation(main_program)
dist_context.block_state.parse_forward_blocks(main_program)
partitioner = Partitioner(dist_context, rank)
dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
[])
return dist_main_prog, dist_context
class TestDistAssign(unittest.TestCase):
def test_dist_assign(self):
dist_main_prog, dist_context = parallelizer(make_program, 0)
ops = dist_main_prog.global_block().ops
for op in ops:
if op.type == "assign":
dist_op = dist_context.get_dist_op_for_program(op)
assert dist_op.dist_attr.impl_type == "assign"
assert dist_op.dist_attr.impl_idx == 0
x_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
out_var = dist_main_prog.global_block().vars[out_name]
dist_out = dist_context.get_dist_tensor_for_program(out_var)
x_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
x_name)
out_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
out_name)
assert x_dims_mapping == out_dims_mapping
assert out_dims_mapping == dist_out.dist_attr.dims_mapping
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.distributed.fleet import auto
paddle.enable_static()
def make_program():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
shape = paddle.shape(x)
return main_program, start_program
def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext
main_program, start_program = program_func()
dist_context = DistributedContext()
completer = Completer(dist_context)
completer.complete_forward_annotation(main_program)
dist_context.block_state.parse_forward_blocks(main_program)
partitioner = Partitioner(dist_context, rank)
dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
[])
return dist_main_prog, dist_context
class TestDistShape(unittest.TestCase):
def test_dist_shape(self):
dist_main_prog, dist_context = parallelizer(make_program, 0)
ops = dist_main_prog.global_block().ops
shape_op = ops[0]
dist_op = dist_context.get_dist_op_for_program(shape_op)
assert dist_op.dist_attr.impl_type == "shape"
assert dist_op.dist_attr.impl_idx == 0
in_name = shape_op.input_arg_names[0]
out_name = shape_op.output_arg_names[0]
in_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(in_name)
out_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(out_name)
assert in_dims_mapping == [0, -1, -1]
assert out_dims_mapping == [-1]
if __name__ == "__main__":
unittest.main()