Unverified commit 3108ba11 (Author: caozhou, Committer: GitHub)

[Auto Parallel]Add parallel tuner (#46189)

* add parallel tuner

* add unittest

* fix unittest

* set timeout of unittest

* set unittest timeout

* fix auto_mode setting

* update unittest

* sync from develop and update unittest

* remove unused import

* update unittest

* update cmakelist

* add unittests
Parent commit: 9cdf30dc
@@ -1305,6 +1305,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         process_mesh = dist_attr.process_mesh
         processes = process_mesh.processes
         # col parallel: matmul + allreduce
+        if backward_op.attr("trans_y"):
+            Y_var_dim_mapping.reverse()
         assert Y_var_dim_mapping[0] < 0
         parallel_axis = Y_var_dim_mapping[1]
...
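The guard added above normalizes the weight's dims_mapping before the column-parallel axis is read: when the backward matmul carries trans_y, the two weight axes are swapped back so that index 1 always refers to the column dimension. A dependency-free sketch of that normalization, not part of the diff (the dims_mapping values and trans_y flags below are hypothetical; -1 means "not sharded", a non-negative entry names a process-mesh axis):

    # Illustrative sketch only; not the code in the diff.
    def normalized_parallel_axis(y_dims_mapping, trans_y):
        y_dims_mapping = list(y_dims_mapping)
        if trans_y:
            # weight is stored transposed, so restore the logical axis order
            y_dims_mapping.reverse()
        assert y_dims_mapping[0] < 0   # row dimension must stay replicated
        return y_dims_mapping[1]       # mesh axis that shards the column dimension

    print(normalized_parallel_axis([-1, 1], trans_y=False))  # -> 1
    print(normalized_parallel_axis([1, -1], trans_y=True))   # -> 1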
@@ -14,8 +14,7 @@
 from .completion import Completer
 from .dist_context import get_default_distributed_context
-# from .tuner.parallel_tuner import ParallelTuner
+from .tuner.parallel_tuner import ParallelTuner


 class Planner:
@@ -38,20 +37,20 @@ class Planner:
         self._completer = Completer(self._dist_context)

         self._strategy = dist_context.strategy
-        # if self._strategy.auto_search:
-        #     self._parallel_tuner = ParallelTuner(
-        #         self._dist_context, mode=self._mode)
+        # set parallel tuner for auto search
+        if self._strategy.auto_mode == "full":
+            self._parallel_tuner = ParallelTuner(self._dist_context,
+                                                 mode=self._mode)

     @property
     def completer(self):
         return self._completer

     def plan(self):
-        self._completer.complete_forward_annotation()
-        # if self._strategy.auto_search:
-        #     self._parallel_tuner.tune()
-        # else:
-        #     self._completer.complete_forward_annotation()
+        if self._strategy.auto_mode == "full":
+            self._parallel_tuner.tune()
+        else:
+            self._completer.complete_forward_annotation()
         # parse forward sub block
         self._dist_context.block_state.parse_forward_blocks(
             self._dist_context.serial_main_program)
...
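Net effect: when the strategy requests a full automatic search (auto_mode == "full"), the Planner now constructs a ParallelTuner and plan() delegates to it, while the existing Completer path is kept for every other mode. The dependency-free sketch below mirrors that dispatch; the classes are hypothetical stand-ins, not the real Paddle implementations:

    # Simplified stand-ins for Completer / ParallelTuner / Planner.
    class FakeCompleter:
        def complete_forward_annotation(self):
            print("completing user annotations (semi-auto path)")

    class FakeTuner:
        def tune(self):
            print("searching parallel strategies (auto_mode == 'full')")

    class PlannerSketch:
        def __init__(self, auto_mode):
            self._auto_mode = auto_mode
            self._completer = FakeCompleter()
            # the tuner is only built when a full automatic search is requested
            self._parallel_tuner = FakeTuner() if auto_mode == "full" else None

        def plan(self):
            if self._auto_mode == "full":
                self._parallel_tuner.tune()
            else:
                self._completer.complete_forward_annotation()

    PlannerSketch("full").plan()   # -> searching parallel strategies ...
    PlannerSketch("semi").plan()   # -> completing user annotations ...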
@@ -37,10 +37,18 @@ class TunableSpace(object):
     def variables(self):
         return self._variables

+    @variables.setter
+    def variables(self, variables):
+        self._variables = variables
+
     @property
     def values(self):
         return self._values

+    @values.setter
+    def values(self, values):
+        self._values = values
+
     def get_value(self, name):
         if name in self.values:
             return self.values[name]
...
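The new setters let the variables and values dictionaries of a space be replaced wholesale, for example when a space is rebuilt from a saved state. A minimal stand-alone sketch of the pattern (SpaceSketch is a hypothetical stand-in for TunableSpace):

    class SpaceSketch:
        def __init__(self):
            self._values = {}

        @property
        def values(self):
            return self._values

        @values.setter
        def values(self, values):
            # bulk replacement, as enabled by the new setter
            self._values = values

    restored = SpaceSketch()
    restored.values = {"num_layers": 2, "hidden_size": 1024}
    print(restored.values["hidden_size"])  # -> 1024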
@@ -90,6 +90,7 @@ class Choice(TunableVariable):
             raise TypeError(
                 "Choice can contain only one type of value, but found values: {} with types: {}."
                 .format(str(values), str(types)))
+        self._is_unknown_type = False
         if isinstance(values[0], str):
             values = [str(v) for v in values]
@@ -108,9 +109,8 @@ class Choice(TunableVariable):
             if default is not None:
                 default = bool(default)
         else:
-            raise TypeError(
-                "Choice can only contain str, int, float, or boll, but found: {} "
-                .format(str(values)))
+            self._is_unknown_type = True
+            self._indices = [i for i in range(len(values))]
         self.values = values

         if default is not None and default not in values:
@@ -129,6 +129,10 @@ class Choice(TunableVariable):
     def random(self, seed=None):
         rng = np.random.default_rng(seed)
-        return rng.choice(self.values)
+        if self._is_unknown_type:
+            indice = rng.choice(self._indices)
+            return self.values[indice]
+        else:
+            return rng.choice(self.values)

     def get_state(self):
...
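With the TypeError removed, a Choice may now hold values of an arbitrary type; random() then samples an index and returns the original Python object rather than letting NumPy coerce or reject the values. A self-contained sketch of this path (ChoiceSketch and the tuple values are invented for illustration):

    import numpy as np

    class ChoiceSketch:
        def __init__(self, values):
            self.values = values
            # anything that is not str/int/float/bool counts as an "unknown" type
            self._is_unknown_type = not isinstance(values[0], (str, int, float, bool))
            self._indices = list(range(len(values)))

        def random(self, seed=None):
            rng = np.random.default_rng(seed)
            if self._is_unknown_type:
                # sample an index so the original object comes back unchanged
                indice = rng.choice(self._indices)
                return self.values[indice]
            return rng.choice(self.values)

    # e.g. candidate process-mesh shapes, which are tuples, not primitives
    print(ChoiceSketch([(2, 4), (4, 2), (8, 1)]).random(seed=0))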
@@ -99,8 +99,20 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_interface MODULES test_interface)
   py_test_modules(test_strategy MODULES test_strategy)
   py_test_modules(test_pass_quantization MODULES test_pass_quantization)
   py_test_modules(test_dist_shape MODULES test_dist_shape)
   py_test_modules(test_dist_assign MODULES test_dist_assign)
   py_test_modules(test_conditional_block_reshard MODULES
                   test_conditional_block_reshard)
+  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
+                  ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
+                  ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_predict MODULES
+                  test_parallel_tuner_predict ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
 endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.static as static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
import sys
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [
ProcessMesh([0, 1], dim_names=["x"]),
ProcessMesh([2, 3], dim_names=["x"])
]
def get_program_v3():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
place = paddle.set_device("gpu")
gpus = [0, 1]
batch_size = 8
sequence_len = 512
vocab_size = 1000
train_program = static.Program()
start_program = static.Program()
modeling.init_global()
modeling._global_parallel_strategy = None
# modeling.DPMPPP_MESH_LIST = [
# ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
# ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"])
# ]
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.data(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.data(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.data(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
gpt = GPTModel(vocab_size=1000,
hidden_size=1024,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=4 * 1024,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=1)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
feed_vars = {
"inputs": [tokens, position_ids, attention_mask, loss_mask],
"labels": [labels]
}
fetch_vars = {"loss": [loss]}
return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars
class TestParallelTunerTrain(unittest.TestCase):
def test_tune_with_train(self):
flag = False
set_default_distributed_context(DistributedContext())
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3(
)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars, cluster)
dist_context.initialize()
parallel_tuner = ParallelTuner(dist_context, max_trials=3, mode="train")
parallel_tuner.tune()
parallel_tuner._store_best_parallel_strategy()
flag = True
self.assertTrue(flag)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.static as static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.planner_v2 import Planner
from paddle.distributed.auto_parallel.strategy import Strategy
import sys
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [
ProcessMesh([0, 1], dim_names=["x"]),
ProcessMesh([2, 3], dim_names=["x"])
]
def get_program_v3():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
place = paddle.set_device("gpu")
gpus = [0, 1]
batch_size = 8
sequence_len = 512
vocab_size = 1000
train_program = static.Program()
start_program = static.Program()
modeling.init_global()
modeling._global_parallel_strategy = "dp_mp_pp"
modeling.DPMPPP_MESH_LIST = [
ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"])
]
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.data(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.data(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.data(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
gpt = GPTModel(vocab_size=1000,
hidden_size=1024,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=4 * 1024,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=len(modeling.DPMPPP_MESH_LIST))
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
feed_vars = {
"inputs": [tokens, position_ids, attention_mask, loss_mask],
"labels": [labels]
}
fetch_vars = {"loss": [loss]}
return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars
class TestParallelTunerFull(unittest.TestCase):
def test_tune_with_planner(self):
flag = False
set_default_distributed_context(DistributedContext())
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3(
)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
strategy = Strategy()
strategy.auto_mode = "full"
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars, cluster, strategy)
dist_context.initialize()
planner = Planner("train", dist_context)
planner._parallel_tuner = ParallelTuner(planner._dist_context,
mode=planner._mode,
max_trials=3)
planner.plan()
flag = True
self.assertTrue(flag)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.static as static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
import sys
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [
ProcessMesh([0, 1], dim_names=["x"]),
ProcessMesh([2, 3], dim_names=["x"])
]
def get_program_v3():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
place = paddle.set_device("gpu")
gpus = [0, 1]
batch_size = 8
sequence_len = 512
vocab_size = 1000
train_program = static.Program()
start_program = static.Program()
modeling.init_global()
modeling._global_parallel_strategy = "dp_mp_pp"
modeling.DPMPPP_MESH_LIST = [
ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"])
]
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.data(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.data(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.data(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
gpt = GPTModel(vocab_size=1000,
hidden_size=1024,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=4 * 1024,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=len(modeling.DPMPPP_MESH_LIST))
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
feed_vars = {
"inputs": [tokens, position_ids, attention_mask, loss_mask],
"labels": [labels]
}
fetch_vars = {"loss": [loss]}
return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars
class TestParallelTunerPredict(unittest.TestCase):
def test_tune_predict(self):
flag = False
set_default_distributed_context(DistributedContext())
train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3(
)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(train_program, start_program,
optimizer, loss, feed_vars,
fetch_vars, cluster)
dist_context.initialize()
parallel_tuner = ParallelTuner(dist_context,
max_trials=3,
mode="predict")
parallel_tuner.tune()
flag = True
self.assertTrue(flag)
if __name__ == "__main__":
unittest.main()
@@ -136,6 +136,16 @@ class TestTunableSpace(unittest.TestCase):
         self.assertEqual(new_space.variables["int_range"].step, 1)
         self.assertEqual(new_space.variables["int_range"].endpoint, False)

+    def test_expection(self):
+        space = ts.TunableSpace()
+        flag = True
+        try:
+            val = space.get_value("test")
+            flag = False
+        except:
+            pass
+        self.assertTrue(flag)
+

 if __name__ == "__main__":
     unittest.main()
@@ -298,14 +298,14 @@ class TransformerDecoder(nn.Layer):
             auto.shard_tensor(output, PP_MESH_LIST[0],
                               [None for i in range(len(output.shape))])
         if _global_parallel_strategy == "dp_pp":
-            auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"].extends(
-                [None for i in range(len(output.shape) - 1)]))
+            auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"] +
+                              [None for i in range(len(output.shape) - 1)])
         if _global_parallel_strategy == "mp_pp":
             auto.shard_tensor(output, MPPP_MESH_LIST[0],
                               [None for i in range(len(output.shape))])
         if _global_parallel_strategy == "dp_mp_pp":
-            auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"].extends(
-                [None for i in range(len(output.shape) - 1)]))
+            auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"] +
+                              [None for i in range(len(output.shape) - 1)])
         for i, mod in enumerate(self.layers):
             if cache is None:
                 if use_cache:
@@ -323,8 +323,8 @@ class TransformerDecoder(nn.Layer):
                                                                tgt_mask,
                                                                use_cache, cache)
                         auto.shard_tensor(
-                            output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends(
-                                [None for i in range(len(output.shape) - 1)]))
+                            output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
+                            [None for i in range(len(output.shape) - 1)])
                     elif _global_parallel_strategy == "mp_pp":
                         output, new_cache = auto.shard_op(
                             mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
@@ -362,8 +362,8 @@ class TransformerDecoder(nn.Layer):
                                                                tgt_mask,
                                                                use_cache, cache)
                         auto.shard_tensor(
-                            output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends(
-                                [None for i in range(len(output.shape) - 1)]))
+                            output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
+                            [None for i in range(len(output.shape) - 1)])
                     elif _global_parallel_strategy == "mp_pp":
                         output = auto.shard_op(
                             mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
@@ -378,9 +378,8 @@ class TransformerDecoder(nn.Layer):
                             output, memory, tgt_mask,
                             use_cache, cache)
                         auto.shard_tensor(
-                            output, DPMPPP_MESH_LIST[mod.mesh_idx],
-                            ["x"].extends(
-                                [None for i in range(len(output.shape) - 1)]))
+                            output, DPMPPP_MESH_LIST[mod.mesh_idx], ["x"] +
+                            [None for i in range(len(output.shape) - 1)])
                     else:
                         output = mod(output,
                                      memory,
@@ -400,9 +399,9 @@ class TransformerDecoder(nn.Layer):
                         mod,
                         DPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
                                                       use_cache, cache)
-                    auto.shard_tensor(output, DPPP_MESH_LIST[mod.mesh_idx], [
-                        "x"
-                    ].extends([None for i in range(len(output.shape) - 1)]))
+                    auto.shard_tensor(
+                        output, DPPP_MESH_LIST[mod.mesh_idx],
+                        ["x"] + [None for i in range(len(output.shape) - 1)])
                 elif _global_parallel_strategy == "mp_pp":
                     output, new_cache = auto.shard_op(
                         mod,
@@ -415,9 +414,9 @@ class TransformerDecoder(nn.Layer):
                         mod, DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory,
                                                              tgt_mask,
                                                              use_cache, cache)
-                    auto.shard_tensor(output, DPMPPP_MESH_LIST[mod.mesh_idx], [
-                        "x"
-                    ].extends([None for i in range(len(output.shape) - 1)]))
+                    auto.shard_tensor(
+                        output, DPMPPP_MESH_LIST[mod.mesh_idx],
+                        ["x"] + [None for i in range(len(output.shape) - 1)])
                 else:
                     output, new_cache = mod(output,
                                             memory,
@@ -682,11 +681,11 @@ class GPTModel(nn.Layer):
             auto.shard_tensor(input_ids, PP_MESH_LIST[0],
                               [None for i in range(len(input_ids.shape))])
         if _global_parallel_strategy == "dp_pp":
-            auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"].extends(
-                [None for i in range(len(input_ids.shape) - 1)]))
+            auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"] +
+                              [None for i in range(len(input_ids.shape) - 1)])
         if _global_parallel_strategy == "dp_mp_pp":
-            auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"].extends(
-                [None for i in range(len(input_ids.shape) - 1)]))
+            auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"] +
+                              [None for i in range(len(input_ids.shape) - 1)])
         encoder_outputs = self.decoder(embedding_output,
                                        memory=None,
                                        tgt_mask=attention_mask,
...
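All of the hunks above replace the same broken pattern: Python lists have no method named "extends", so the old expression raises AttributeError as soon as it is evaluated (and even list.extend mutates in place and returns None, never the intended dims_mapping). The replacement builds the dims_mapping by concatenation. A tiny stand-alone illustration (output_rank = 3 is a made-up example rank; "x" is the data-parallel mesh axis used by the model):

    output_rank = 3

    # Old pattern: fails at runtime, list has no "extends" method.
    #   ["x"].extends([None for i in range(output_rank - 1)])

    # New pattern: concatenation yields the intended dims_mapping, sharding only
    # the leading (batch) dimension along mesh axis "x".
    dims_mapping = ["x"] + [None for i in range(output_rank - 1)]
    print(dims_mapping)  # -> ['x', None, None]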