Unverified commit 315ef265 authored by zhaoyingli, committed by GitHub

[AutoParallel] fix engine _build and cost method (#47263)

* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent 26c419ca
@@ -34,6 +34,25 @@ class AdamOpCost(CompOpCost):
return 0
@register_op_cost
class ArgsortOpCost(CompOpCost):
OP_TYPE = "argsort"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ArgsortOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
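Each new class in these hunks is a placeholder: it registers an op type with the cost model and defers the real calc_flops/calc_time formulas. For readers unfamiliar with the pattern, here is a self-contained toy sketch of what register_op_cost does — a hypothetical dict registry, not Paddle's actual implementation:

# Toy re-implementation of the decorator-registry pattern used above.
# _op_cost_registry and this register_op_cost are illustrative only.
_op_cost_registry = {}

def register_op_cost(cls):
    # Key each cost class by the operator type it models.
    _op_cost_registry[cls.OP_TYPE] = cls
    return cls

class CompOpCost:
    def __init__(self, op=None, op_desc=None, cluster=None):
        self.op, self.op_desc, self.cluster = op, op_desc, cluster

@register_op_cost
class ArgsortOpCost(CompOpCost):
    OP_TYPE = "argsort"

    def calc_flops(self):
        return 0  # placeholder, exactly as in the diff

# The estimator can now resolve a cost class from an op type:
assert _op_cost_registry["argsort"] is ArgsortOpCost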
@@ -338,6 +357,25 @@ class ElementwiseSubGradOpCost(CompOpCost):
return 0
@register_op_cost
class EqualOpCost(CompOpCost):
OP_TYPE = "equal"
def __init__(self, op=None, op_desc=None, cluster=None):
super(EqualOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class EmbeddingOpCost(CompOpCost):
OP_TYPE = "c_embedding"
......
@@ -545,11 +545,12 @@ class CostEstimator:

def get_cost_from_engine(engine, mode):
    from ..utils import to_list
+    import copy

    # Construct cost estimator by original main program
    serial_main_prog = (
-        engine._serial_main_progs[mode].clone()
-        if mode in engine._serial_main_progs
+        engine._fwd_main_progs[mode].clone()
+        if mode in engine._fwd_main_progs
        else engine._orig_main_prog.clone()
    )
@@ -566,29 +567,29 @@ def get_cost_from_engine(engine, mode):
        )
        else engine._losses
    )
-    if mode in engine._dist_contexts:
-        dist_context = engine._dist_contexts[mode]
-        completer = engine._planners[mode].completer
+    serial_optimizer = copy.deepcopy(engine._orig_optimizer)
+    if mode in engine._fwd_dist_contexts:
+        dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
    else:
-        from ..completion import Completer
        from ..dist_context import DistributedContext

        dist_context = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
-            engine._optimizer,
+            serial_optimizer,
            losses,
-            {},
+            {"loss": losses},
            engine._cluster,
            engine._strategy,
        )
-        completer = Completer(dist_context)
-        completer.complete_forward_annotation()
-        dist_context.block_state.parse_forward_blocks(
-            dist_context.serial_main_program
-        )
+    from ..completion import Completer

+    completer = Completer(dist_context)
+    completer.complete_forward_annotation()
+    dist_context.block_state.parse_forward_blocks(
+        dist_context.serial_main_program
+    )
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
@@ -596,7 +597,6 @@ def get_cost_from_engine(engine, mode):
        from ..parallelizer_v2 import Parallelizer

        # Get serial main program with backward
-        serial_optimizer = engine._optimizer
        parallelizer = Parallelizer(mode, completer, dist_context)
        # Generate backward
        loss_name = dist_context.serial_loss.name
......
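Taken together, these changes let cost estimation run on the engine's forward-only program (_fwd_main_progs) with a deep-copied optimizer, so Engine.cost() no longer depends on a fully parallelized state. A usage sketch mirroring the get_cost() test added below; the input/label/loss_var/optimizer setup is assumed to match that test:

engine = auto.Engine(
    loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
)
engine.prepare(
    main_program=main_program,
    startup_program=startup_program,
    inputs=[input],
    labels=[label],
    mode="train",
)
engine.cost()  # estimates cost for the prepared "train" mode
# Without a prior prepare(), pass the mode explicitly:
# engine.cost(mode="train")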
@@ -1876,3 +1876,34 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
break
process_group.instantiate()
server_socket.close()
def get_input_split_info(cur_rank, var, dist_context):
# Deduce how the input data is split across the cluster
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank)
else:
rank_id = cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
batch_size_axis,
rank_id,
)
return len(group_ranks), group_ranks.index(rank_id)
return 1, 0
def validate_opt(optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
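get_input_split_info returns the (nranks, rank_index) pair the distributed dataloader uses to shard each batch along dim 0, and validate_opt clears the optimizer's parameter references, presumably so the optimizer can be deep-copied as get_cost_from_engine now does. A pure-Python illustration of the split rule — a hypothetical standalone helper that simplifies away the rank-to-group lookup done by _get_comm_group:

def split_info(dims_mapping, topology, index_in_group):
    # Mirrors get_input_split_info's core rule: if the batch dim (dim 0)
    # is mapped to a mesh axis with more than one process, split the
    # batch into that many shards and take this rank's shard.
    batch_size_axis = dims_mapping[0]
    if batch_size_axis > -1 and topology[batch_size_axis] > 1:
        return topology[batch_size_axis], index_in_group
    return 1, 0  # batch dim replicated: every rank reads the full batch

# A [2, 4] process mesh with the batch dim mapped to mesh axis 0:
print(split_info([0, -1], [2, 4], 1))   # (2, 1): two shards, take shard 1
print(split_info([-1, -1], [2, 4], 1))  # (1, 0): input not split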
@@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
  set_tests_properties(test_engine_callbacks
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
+                  ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
+                  ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_predict MODULES
+                  test_parallel_tuner_predict ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
  py_test_modules(test_while_op_completion MODULES test_while_op_completion
                  ENVS ${dist_ENVS})
@@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
@@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_interface MODULES test_interface)
  py_test_modules(test_strategy MODULES test_strategy)
  py_test_modules(test_pass_quantization MODULES test_pass_quantization)
  py_test_modules(test_dist_shape MODULES test_dist_shape)
  py_test_modules(test_dist_assign MODULES test_dist_assign)
  py_test_modules(test_conditional_block_reshard MODULES
                  test_conditional_block_reshard)
-  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
-                  ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
-  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
-                  ENVS ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
-  py_test_modules(test_parallel_tuner_predict MODULES
-                  test_parallel_tuner_predict ENVS ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
+  py_test_modules(test_engine_api_error MODULES test_engine_api_error)
endif()
@@ -374,6 +374,57 @@ def train_non_builtin_data_vars():
def get_cost():
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(
main_program, startup_program
), utils.unique_name.guard():
input = static.data(
name="input", shape=[batch_size, image_size], dtype='float32'
)
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(
feed_list=[input, label], capacity=4 * batch_size, iterable=False
)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(
loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
)
engine.prepare(
main_program=main_program,
startup_program=startup_program,
inputs=[input],
labels=[label],
mode="train",
)
engine.cost()
def get_cost_by_default_program():
main_program = static.default_main_program()
startup_program = static.default_startup_program()
with static.program_guard(
@@ -414,7 +465,7 @@ def get_cost():
    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )
-    engine.cost()
+    engine.cost(mode="train")
def get_cost_by_spec():
@@ -451,4 +502,5 @@ if __name__ == "__main__":
    train_builtin_data_vars()
    train_non_builtin_data_vars()
    get_cost()
+    get_cost_by_default_program()
    get_cost_by_spec()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static as static
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.distributed.fleet import auto
paddle.enable_static()
epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10
is_fetch = True
is_feed = True
my_feed_vars = []
class TrainDataset(Dataset):
def __init__(self, num_samples):
super(TrainDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
def __len__(self):
return self.num_samples
class TestDataset(Dataset):
def __init__(self, num_samples):
super(TestDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
if is_feed:
my_feed_vars.append((out, out.shape))
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
if is_feed:
my_feed_vars.append((out, out.shape))
if is_fetch:
auto.fetch(out, "my_fetch", logging=True)
return out
class TestEngineErrorRaise(unittest.TestCase):
def setUp(self):
class NoSupportData1:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
class NoSupportData2(TrainDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
label = [np.random.randint(0, class_num - 1, dtype="int64")]
return input, label
class NoSupportData3:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
class NoSupportData4(TestDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
return input
self.no_support_data_1 = NoSupportData1()
self.no_support_data_2 = NoSupportData2(10)
self.no_support_data_3 = NoSupportData3()
self.no_support_data_4 = NoSupportData4(10)
def test_Engine(self):
with self.assertRaises(TypeError):
auto.Engine(model=paddle.static.Program())
with self.assertRaises(TypeError):
auto.Engine(loss="CrossEntropyLoss")
with self.assertRaises(TypeError):
auto.Engine(optimizer="adam")
with self.assertRaises(TypeError):
auto.Engine(metrics=["acc"])
with self.assertRaises(TypeError):
auto.Engine(cluster="cluster")
with self.assertRaises(TypeError):
auto.Engine(strategy="strategy")
def test_fit(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_1)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_2)
def test_evaluate(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(valid_data=self.no_support_data_3)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(
valid_data=self.no_support_data_4, valid_sample_split=1
)
def test_predict(self):
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_3, test_sample_split=1
)
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_4, test_sample_split=1
)
def build_program(self):
main_prog = static.Program()
startup_prog = static.Program()
with static.program_guard(main_prog, startup_prog):
input = static.data(
name="input",
shape=[batch_size // 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size // 2, 1], dtype='int64'
)
mlp = MLPLayer()
loss = paddle.nn.CrossEntropyLoss()
predict = mlp(input)
loss_var = loss(predict, label)
return main_prog, startup_prog, input, label, loss_var
def test_prepare(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.prepare()
with self.assertRaises(AssertionError):
engine = auto.Engine(model=MLPLayer())
engine.prepare(mode="train")
with self.assertRaises(TypeError):
input = static.data(
name="input",
shape=[batch_size / 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size / 2, 1], dtype='int64'
)
engine = auto.Engine(model=MLPLayer())
engine.prepare(inputs_spec=input, labels_spec=label, mode="eval")
input_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
label_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
(
main_prog,
startup_prog,
input_var,
label_var,
loss_var,
) = self.build_program()
with self.assertRaises(TypeError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs=input_spec,
labels=label_spec,
main_program=main_prog,
startup_program=startup_prog,
mode="eval",
)
with self.assertRaises(AssertionError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs_spec=[input_spec, input_spec],
labels_spec=[label_spec, label_spec],
inputs=input_var,
labels=label_var,
main_program=main_prog,
startup_program=startup_prog,
mode="predict",
)
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
class TestEngineDynamicErrorRaise(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
if __name__ == "__main__":
unittest.main()