未验证 提交 d7f7963f 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] selective recompute (#48111)

* [AutoParallel] selective recompute

* add cmakelist
上级 aafa9820
......@@ -55,6 +55,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
......
......@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
raise StopIteration
def _infer_steps(self):
if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 0:
return self.steps_per_epoch
try:
if isinstance(self.dataset, IterableDataset):
......
......@@ -610,7 +610,7 @@ class Engine:
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)
self._set_recompute_ckpts()
auto_utils.set_recompute_ckpts(self._model, self._strategy)
self._dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
......@@ -1518,35 +1518,6 @@ class Engine:
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
recompute = self._strategy.recompute
# extract ckpts by specific model
if isinstance(self._model, paddle.nn.Layer):
if hasattr(
self._model, "gpt"
) and self._model.__class__.__name__ in [
'GPTForPretraining',
'GPTForPretrainingAuto',
]:
exact_ckpts = self._model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
else:
exact_ckpts = recompute.checkpoints
# modify strategy
if recompute.enable:
recompute.checkpoints = exact_ckpts[:]
logs = {
'Model Class': self._model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts,
}
self._logger.info(logs)
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
......
......@@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
return op
_g_recompute_idx = -1
def recompute(op):
global _g_recompute_idx
_g_recompute_idx += 1
class RecomputeOperator:
def __init__(self, op):
self._op = op
......@@ -209,7 +215,9 @@ def recompute(op):
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr("is_recompute@auto_parallel", True)
op._set_attr(
'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx)
)
return output
......
......@@ -33,6 +33,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
__no_shape_var_type__ = [
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
......@@ -1181,7 +1184,6 @@ def _get_split_indices(
def set_grad_var_shape(program, dist_context):
from .operators.common import infer_shape
from paddle.distributed.fleet.meta_optimizers.common import OpRole
block = program.global_block()
vars = block.vars
......@@ -1315,10 +1317,6 @@ def set_grad_var_shape(program, dist_context):
grad_var.desc.set_shape(ref_shape)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (
......@@ -1896,6 +1894,39 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
server_socket.close()
def set_recompute_ckpts(model, strategy):
from .interface import _g_recompute_idx
if _g_recompute_idx > -1:
return
recompute = strategy.recompute
if not recompute.enable:
return
# NOTE: hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
# extract ckpts by specific model
if isinstance(model, paddle.nn.Layer):
if hasattr(model, "gpt") and model.__class__.__name__ in [
'GPTForPretraining',
'GPTForPretrainingAuto',
]:
exact_ckpts = model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
else:
exact_ckpts = recompute.checkpoints
# modify strategy
recompute.checkpoints = exact_ckpts[:]
logs = {
'Model Class': model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts,
}
logging.info(logs)
def get_input_split_info(cur_rank, var, dist_context):
# deduce how the input data is split among the cluster
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
......
......@@ -17,7 +17,6 @@ import logging
from .pass_base import PassBase, register_pass
from paddle.fluid import core, unique_name
from paddle.fluid import framework as framework
from paddle.fluid.framework import Variable
from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
from paddle.distributed.auto_parallel.dist_attribute import (
......@@ -33,12 +32,21 @@ from paddle.distributed.auto_parallel.utils import (
)
def _to_be_recomputed(op):
return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
'op_namescope'
)
class RecomputeState(ProgramStats):
def __init__(self, block, ops):
super().__init__(block=block, ops=ops)
self._block = block
self._ops = ops
# {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
self.var_op_deps = {}
# {segment_name: op_idx}
self.seg_op_deps = {}
def build_stats(self):
for i, op in enumerate(self._ops):
......@@ -58,36 +66,72 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i]
def get_recompute_segments(self, checkpoints):
"""get recompute segments from checkpoints"""
if not _to_be_recomputed(op):
continue
seg_name = op.attr('op_namescope')
if seg_name not in self.seg_op_deps:
self.seg_op_deps[seg_name] = [i]
else:
assert (
self.seg_op_deps[seg_name][-1] + 1 == i
), "The recompute segment's ops should be continuous"
self.seg_op_deps[seg_name].extend([i])
def get_recompute_segments(
self, checkpoints_list=None, no_recompute_segments=[]
):
"""get recompute segments and checkpoints"""
segments = []
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in self.var_op_deps:
start_idx += 1
checkpoints = checkpoints_list or []
if len(checkpoints) == 0:
# the segments is marked by `auto.recompute()` api
for segment_idx in self.seg_op_deps.values():
if len(segment_idx) == 1:
continue
op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
if op_idx_list:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
segments.append([segment_idx[0], segment_idx[-1] + 1])
checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names)
else:
# the segments is marked by `strategy.checkpoints` api
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in self.var_op_deps:
start_idx += 1
continue
op_idx_list = self.var_op_deps[ckpt_name][
"var_as_output_ops"
]
if op_idx_list:
segments.append([0, max(op_idx_list) + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
start_idx += 1
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
if no_recompute_segments:
for i in reversed(sorted(no_recompute_segments)):
assert i < len(
segments
), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format(
i, len(segments)
)
segments.pop(i)
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
......@@ -106,7 +150,10 @@ class RecomputeState(ProgramStats):
)
)
return segments
return segments, checkpoints
def is_recompute(self):
return any([_to_be_recomputed(op) for op in self._ops])
def modify_forward_desc_for_recompute(self, dist_context):
"""
......@@ -162,6 +209,7 @@ class RecomputeState(ProgramStats):
outputs={"Out": seed_var},
attrs={"seed": seed, "force_cpu": True},
)
seed_op._set_attr('op_namescope', cur_op.attr('op_namescope'))
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, ref_process_mesh, ref_dims_mapping, dist_context
......@@ -196,7 +244,6 @@ def _get_stop_gradients(program, no_grad_set):
no_grad_set_name = set()
for var in program.list_vars():
assert isinstance(var, Variable)
if "@GRAD" in var.name:
break
if var.stop_gradient:
......@@ -244,14 +291,13 @@ class RecomputePass(PassBase):
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("no_grad_set", None)
self.set_attr("no_recompute_segments", [])
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if self.get_attr("loss") is None:
return False
if self.get_attr("checkpoints") is None:
return False
return True
def _check_conflict(self, other_pass):
......@@ -259,25 +305,32 @@ class RecomputePass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
checkpoints = self.get_attr("checkpoints")
no_recompute_segments = self.get_attr("no_recompute_segments")
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
self._dist_context = self.get_attr("dist_context")
# 0. get op_path which is related to loss
main_block = main_program.global_block()
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
# get op_path which is related to loss
op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
# step 1: build recompute state
# 1. build recompute state
rc_state = RecomputeState(main_block, op_path)
if not rc_state.is_recompute() and not checkpoints:
return
# 2. get the segments to be recomputed
rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_stats()
checkpoints = rc_state.sort_checkpoints(checkpoints)
segments = rc_state.get_recompute_segments(checkpoints)
if segments == []:
checkpoints = rc_state.sort_checkpoints(checkpoints or [])
segments, checkpoints = rc_state.get_recompute_segments(
checkpoints, no_recompute_segments
)
if segments == [] or checkpoints == []:
return
# step 2: get vars_should_be_hold
# 3. get vars that should be hold in memory
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
......@@ -295,9 +348,9 @@ class RecomputePass(PassBase):
vars_should_be_hold = list(set(vars_should_be_hold))
vars_in_memory = vars_should_be_hold + checkpoints
# step 3: get recomputed fwd ops desc
var_name_dict = {}
ckpt_ops_dict = {}
# 4. get the fwd ops desc to be recomputed.
var_name_dict = {} # varname --> varname.subprog_XXX
ckpt_ops_dict = {} # ckpt_op_id --> segment_descs
buffer_block = main_block.program._create_block()
for i, segment in enumerate(segments[::-1]):
fwd_ops = op_path[segment[0] : segment[1]]
......@@ -362,7 +415,7 @@ class RecomputePass(PassBase):
ckpt_op = op_path[segment[1] - 1]
ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]
# step 4: insert recomputed fwd ops
# 5. insert recomputed fwd ops into backward parse
ops = main_block.ops
loss_op = get_loss_op(main_block)
loss_op_idx = _find_op_index(main_block, loss_op)
......
......@@ -72,6 +72,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_selective_recompute MODULES test_selective_recompute)
set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
......
......@@ -22,13 +22,14 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import FakeDataset, generate_model
def apply_pass(use_recompute=False):
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
......@@ -53,10 +54,10 @@ class TestRecomputePass(unittest.TestCase):
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False):
def get_engine(self, use_recompute=False, no_recompute_segments=[]):
reset_prog()
strategy = apply_pass(use_recompute)
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
......@@ -88,6 +89,18 @@ class TestRecomputePass(unittest.TestCase):
rc_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc_losses)
# mp2 selective recompute training
rc1_engine = self.get_engine(True, [0])
history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc1_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc1_losses)
def test_recompute_pass_error(self):
with self.assertRaises(AssertionError):
rc_engine = self.get_engine(True, [2])
history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import FakeDataset
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTModel,
GPTForPretraining,
GPTPretrainingCriterion,
)
def generate_model(use_new_recompute, recompute_granularity):
modeling.init_global()
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
use_new_recompute=use_new_recompute,
recompute_granularity=recompute_granularity,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePassWithRecomputeAPI(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 2
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(
self,
use_recompute=False,
use_new_recompute=False,
recompute_granularity="full",
no_recompute_segments=[],
):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model(use_new_recompute, recompute_granularity)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def recompute_vars(self, program):
return list(filter(lambda a: "subprog" in a.name, program.list_vars()))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"])
# mp2 recompute with old api
rc4_engine = self.get_engine(True, False)
history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc4_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc4_losses)
# mp2 recompute core_attn
rc1_engine = self.get_engine(True, True, "core_attn", [0])
history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc1_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc1_losses)
# mp2 recompute full_attn
rc2_engine = self.get_engine(True, True, "full_attn")
history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc2_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc2_losses)
# mp2 recompute full
rc3_engine = self.get_engine(True, True, "full")
history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc3_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc3_losses)
rc0_vars = self.recompute_vars(mp_engine.main_program)
rc1_vars = self.recompute_vars(rc1_engine.main_program)
rc2_vars = self.recompute_vars(rc2_engine.main_program)
rc3_vars = self.recompute_vars(rc3_engine.main_program)
assert rc0_vars == []
assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars)
def test_recompute_pass_error(self):
with self.assertRaises(AssertionError):
rc_engine = self.get_engine(True, True, "full", [2])
history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
if __name__ == "__main__":
unittest.main()
......@@ -57,6 +57,8 @@ class MultiHeadAttention(nn.Layer):
bias_attr=None,
fuse=False,
mesh_idx=None,
use_new_recompute=False,
recompute_granularity="full",
):
super().__init__()
self.embed_dim = embed_dim
......@@ -67,6 +69,9 @@ class MultiHeadAttention(nn.Layer):
self.need_weights = need_weights
self.fuse = fuse
self.mesh_idx = mesh_idx
self.use_new_recompute = use_new_recompute
self.recompute_granularity = recompute_granularity
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
......@@ -225,6 +230,27 @@ class MultiHeadAttention(nn.Layer):
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)
def core_attn(self, q, k, v, attn_mask):
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train",
)
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
return out, weights
def forward(
self, query, key, value, attn_mask=None, use_cache=False, cache=None
):
......@@ -244,23 +270,12 @@ class MultiHeadAttention(nn.Layer):
q, k, v, cache = self._prepare_qkv(
query, key, value, use_cache, cache
)
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train",
)
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
if self.use_new_recompute and self.recompute_granularity == "core_attn":
out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
else:
out, weights = self.core_attn(q, k, v, attn_mask)
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
......@@ -295,12 +310,22 @@ class TransformerDecoder(nn.Layer):
TransformerDecoder is a stack of N decoder layers.
"""
def __init__(self, decoder_layers, num_layers, norm=None, hidden_size=None):
def __init__(
self,
decoder_layers,
num_layers,
norm=None,
hidden_size=None,
use_new_recompute=False,
recompute_granularity="full",
):
super().__init__()
self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
self.use_new_recompute = use_new_recompute
self.recompute_granularity = recompute_granularity
if norm == "LayerNorm":
self.norm = nn.LayerNorm(hidden_size)
elif norm is not None:
......@@ -348,149 +373,36 @@ class TransformerDecoder(nn.Layer):
DPMPPP_MESH_LIST[0],
["x"] + [None for i in range(len(output.shape) - 1)],
)
for i, mod in enumerate(self.layers):
if self.use_new_recompute and self.recompute_granularity == "full":
mod = auto.recompute(mod)
if cache is None:
if use_cache:
if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op(
mod, PP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op(
mod, DPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPPP_MESH_LIST[mod.mesh_idx],
["x"]
+ [None for i in range(len(output.shape) - 1)],
)
elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op(
mod, MPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op(
mod, DPMPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPMPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
else:
output, new_cache = mod(
output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache,
)
new_caches.append(new_cache)
else:
if _global_parallel_strategy == "pp":
output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
output, memory, tgt_mask, use_cache, cache
)
auto.shard_tensor(
output,
PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_pp":
output = auto.shard_op(
mod, DPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPPP_MESH_LIST[mod.mesh_idx],
["x"]
+ [None for i in range(len(output.shape) - 1)],
)
elif _global_parallel_strategy == "mp_pp":
output = auto.shard_op(
mod, MPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_mp_pp":
output = auto.shard_op(
mod, DPMPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPMPPP_MESH_LIST[mod.mesh_idx],
["x"]
+ [None for i in range(len(output.shape) - 1)],
)
else:
output = mod(
output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache,
)
else:
if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op(
mod, PP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op(
mod, DPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPPP_MESH_LIST[mod.mesh_idx],
["x"] + [None for i in range(len(output.shape) - 1)],
)
elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op(
mod, MPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))],
)
elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op(
mod, DPMPPP_MESH_LIST[mod.mesh_idx]
)(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
DPMPPP_MESH_LIST[mod.mesh_idx],
["x"] + [None for i in range(len(output.shape) - 1)],
)
else:
output, new_cache = mod(
output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache[i],
cache=cache,
)
new_caches.append(new_cache)
else:
output = mod(output, memory, tgt_mask, use_cache, cache)
else:
output, new_cache = mod(
output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache[i],
)
new_caches.append(new_cache)
self.checkpoints.append(output.name)
if not self.use_new_recompute:
self.checkpoints.append(output.name)
if self.norm is not None:
output = self.norm(output)
return output if use_cache is False else (output, new_caches)
......@@ -528,6 +440,8 @@ class TransformerDecoderLayer(nn.Layer):
weight_attr=None,
bias_attr=None,
mesh_idx=None,
use_new_recompute=False,
recompute_granularity="full",
):
self._config = locals()
self._config.pop("self")
......@@ -537,8 +451,12 @@ class TransformerDecoderLayer(nn.Layer):
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.use_new_recompute = use_new_recompute
self.recompute_granularity = recompute_granularity
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention(
d_model,
nhead,
......@@ -546,6 +464,8 @@ class TransformerDecoderLayer(nn.Layer):
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
mesh_idx=self.mesh_idx,
use_new_recompute=self.use_new_recompute,
recompute_granularity=self.recompute_granularity,
)
self.linear1 = nn.Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]
......@@ -563,12 +483,19 @@ class TransformerDecoderLayer(nn.Layer):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if self.use_new_recompute and self.recompute_granularity == "full_attn":
self_attn = auto.recompute(self.self_attn)
else:
self_attn = self.self_attn
if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
tgt = self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(
tgt, incremental_cache = self_attn(
tgt, tgt, tgt, tgt_mask, use_cache, cache
)
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
......@@ -716,12 +643,17 @@ class GPTModel(nn.Layer):
bos_token_id=0,
eol_token_id=3,
pp_degree=None,
use_new_recompute=False,
recompute_granularity="full",
):
super().__init__()
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.use_new_recompute = use_new_recompute
self.recompute_granularity = recompute_granularity
self.layer_per_stage = None
self.pipline_mode = pp_degree is not None and pp_degree > 1
if self.pipline_mode:
......@@ -734,6 +666,7 @@ class GPTModel(nn.Layer):
type_vocab_size,
self.initializer_range,
)
decoder_layers = nn.LayerList()
for i in range(num_hidden_layers):
mesh_index = None
......@@ -756,14 +689,19 @@ class GPTModel(nn.Layer):
),
bias_attr=None,
mesh_idx=mesh_index,
use_new_recompute=self.use_new_recompute,
recompute_granularity=self.recompute_granularity,
)
)
Decoder = TransformerDecoder
self.decoder = Decoder(
decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size,
use_new_recompute=self.use_new_recompute,
recompute_granularity=self.recompute_granularity,
)
self.checkpoints = []
......@@ -817,7 +755,8 @@ class GPTModel(nn.Layer):
use_cache=use_cache,
cache=cache,
)
self.checkpoints.extend(self.decoder.checkpoints)
if not self.use_new_recompute:
self.checkpoints.extend(self.decoder.checkpoints)
return encoder_outputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册