Commit 61e6d069 authored by Logan Adams

Merge branch 'master' into loadams/low-cpu-mem-ut

......@@ -26,7 +26,7 @@ from ..module_inject.policy import TransformerPolicy
 from ..module_inject.auto_tp import AutoTP

 from ..module_inject.replace_policy import generic_policies
-from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor
+from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor

 DS_INFERENCE_ENABLED = False
 from torch import nn
......@@ -187,6 +187,9 @@ class InferenceEngine(Module):
         if hasattr(self.module, 'transformer'):
             if hasattr(self.module.transformer, 'build_alibi_tensor'):
                 self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
+            if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
+                self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
+                self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor

     def build_attn_bias(self):
         if hasattr(self.module, 'transformer'):
......
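The three added lines above swap MPT's alibi builder for the tensor-parallel aware version while keeping a handle to the original. Below is a minimal standalone sketch of that rebinding pattern with toy names (not DeepSpeed's API): the original method is stashed on the instance, and the replacement is assigned on the class so it is still called as a bound method with `self`.

# Illustrative sketch only -- toy class and toy return values.
class Transformer:
    def build_mpt_alibi_tensor(self, num_heads):
        return list(range(num_heads))               # stand-in for the full alibi tensor

def tp_build_mpt_alibi_tensor(self, num_heads):
    full = self.build_mpt_alibi_tensor_orig(num_heads)
    return full[:num_heads // 2]                    # pretend this rank keeps half the heads

transformer = Transformer()
transformer.build_mpt_alibi_tensor_orig = transformer.build_mpt_alibi_tensor
Transformer.build_mpt_alibi_tensor = tp_build_mpt_alibi_tensor
print(transformer.build_mpt_alibi_tensor(8))        # [0, 1, 2, 3]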
......@@ -296,21 +296,6 @@ class AutoTP():
         if getattr(child, "replaced", False) == True:
             return
         weight_shape = child.weight.shape
-        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
-            # MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
-            # instead of [num_head,3,head_dim,hidden_size]
-            new_weight = torch.empty((
-                weight_shape[0] // self.mp_size,
-                weight_shape[1],
-            ),
-                                     device=child.weight.device,
-                                     dtype=child.weight.dtype)
-            reversed_dim = True
-            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
-            # todo: can we remove new tensor allocation if we use strided copy?
-            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
-            setattr(child, "replaced", True)
-            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         if name in self.all_reduce_linears:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
......
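The removed branch existed because MPT fuses q/k/v with a [3, num_heads, head_dim, hidden_size] layout rather than the per-head [num_heads, 3, head_dim, hidden_size] interleaving used by most other models; the commit now routes MPT through the shared fused-qkv path instead (see the fusedqkv changes below). A hedged sketch of the layout difference, with illustrative shapes only:

import torch

# Illustrative shapes only -- not DeepSpeed code. MPT's fused Wqkv groups all q
# rows, then all k rows, then all v rows, so it reshapes to
# [3, num_heads, head_dim, hidden]; many other models interleave q/k/v per head,
# i.e. [num_heads, 3, head_dim, hidden].
num_heads, head_dim, hidden, mp_size = 4, 8, 32, 2
w = torch.randn(3 * num_heads * head_dim, hidden)

mpt_view = w.view(3, num_heads, head_dim, hidden)
per_head_view = w.view(num_heads, 3, head_dim, hidden)

# A tensor-parallel split must therefore slice the head dimension in the right
# place for each layout:
mpt_shard = mpt_view[:, :num_heads // mp_size]          # torch.Size([3, 2, 8, 32])
per_head_shard = per_head_view[:num_heads // mp_size]   # torch.Size([2, 3, 8, 32])
print(mpt_shard.shape, per_head_shard.shape)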
......@@ -76,3 +76,18 @@ def build_mpt_atten_bias_tensor(self,
         offset = dist.get_rank() * num_heads_per_rank
         attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
     return attn_bias, attention_mask
+
+
+def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
+    r"""
+    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
+    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
+    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
+    """
+    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
+    return alibi
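The new helper calls the original MPT alibi builder and then keeps only this rank's slice of attention heads. A self-contained sketch of that slicing arithmetic follows; the rank, world size, and alibi shape are hard-coded stand-ins, whereas the real code queries deepspeed.comm and the stashed original builder.

import torch

# Stand-in values for illustration; the real code asks deepspeed.comm for
# rank/world size and gets the full tensor from build_mpt_alibi_tensor_orig.
num_heads, seq_len, world_size, rank = 32, 128, 4, 1
alibi = torch.randn(num_heads, 1, seq_len)          # assumed shape: one bias row per head

num_heads_per_rank = int(num_heads / world_size)    # 8 heads per rank
offset = rank * num_heads_per_rank                  # rank 1 owns heads 8..15
alibi_local = alibi[offset:num_heads_per_rank + offset, :, :]
print(alibi_local.shape)                            # torch.Size([8, 1, 128])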
......@@ -16,7 +16,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']
     if mp_size == 1:
         return False
......@@ -33,6 +33,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         'CodeGenBlock': 'codegentype',
         'BloomBlock': 'bloomtype',
         'GLMBlock': 'glmtype',
+        "MPTBlock": 'glmtype',
+        "MptBlock": 'glmtype',
     }

     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
......
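Mapping "MPTBlock"/"MptBlock" to 'glmtype' (together with adding 'attn.Wqkv' above) sends MPT's fused qkv weight through the generic fused-qkv splitter. The sketch below illustrates the general idea behind such a split, cut the fused weight into q/k/v blocks and keep this rank's rows of each, and is not DeepSpeed's prepare_tp_fused_qkvw itself.

import torch

# Hedged illustration of a fused-qkv tensor-parallel split.
def shard_fused_qkvw(fused_w: torch.Tensor, mp_size: int, rank: int) -> torch.Tensor:
    q, k, v = torch.chunk(fused_w, 3, dim=0)                 # rows ordered q | k | v
    take = lambda t: torch.chunk(t, mp_size, dim=0)[rank]    # this rank's slice of a block
    return torch.cat([take(q), take(k), take(v)], dim=0)     # re-fuse the local shards

hidden = 16
fused_w = torch.randn(3 * hidden, hidden)
print(shard_fused_qkvw(fused_w, mp_size=2, rank=0).shape)    # torch.Size([24, 16])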
......@@ -24,6 +24,7 @@ import subprocess
 from setuptools import setup, find_packages
 from setuptools.command import egg_info
 import time
+import typing

 torch_available = True
 try:
......@@ -56,6 +57,22 @@ def fetch_requirements(path):
         return [r.strip() for r in fd.readlines()]


+def is_env_set(key):
+    """
+    Checks if an environment variable is set and not "".
+    """
+    return bool(os.environ.get(key, None))
+
+
+def get_env_if_set(key, default: typing.Any = ""):
+    """
+    Returns an environment variable if it is set and not "",
+    otherwise returns a default value. In contrast, the fallback
+    parameter of os.environ.get() is skipped if the variable is set to "".
+    """
+    return os.environ.get(key, None) or default
+
+
 install_requires = fetch_requirements('requirements/requirements.txt')
 extras_require = {
     '1bit': [],  # add cupy based on cuda/rocm version
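A short usage sketch of the two helpers added above, showing how they treat a variable that is set but empty; the plain os.environ.get() fallback, by contrast, is skipped in that case.

import os

def is_env_set(key):                          # same logic as the helper above
    return bool(os.environ.get(key, None))

def get_env_if_set(key, default=""):          # same logic as the helper above
    return os.environ.get(key, None) or default

os.environ["DS_BUILD_OPS"] = ""               # set, but empty

print(os.environ.get("DS_BUILD_OPS", "0"))    # "" -- the stdlib default is not used
print(is_env_set("DS_BUILD_OPS"))             # False
print(get_env_if_set("DS_BUILD_OPS", "0"))    # "0" -- empty value falls back to the default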
......@@ -116,14 +133,14 @@ if torch_available and not torch.cuda.is_available():
print("[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only "
"you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
"(compute capabilities 6.0, 6.1, 6.2)")
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
if not is_env_set("TORCH_CUDA_ARCH_LIST"):
os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capabilities()
ext_modules = []
# Default to pre-install kernels to false so we rely on JIT on Linux, opposite on Windows.
BUILD_OP_PLATFORM = 1 if sys.platform == "win32" else 0
BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', BUILD_OP_PLATFORM))
BUILD_OP_DEFAULT = int(get_env_if_set('DS_BUILD_OPS', BUILD_OP_PLATFORM))
print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")
if BUILD_OP_DEFAULT:
......@@ -147,7 +164,7 @@ def op_envvar(op_name):
 def op_enabled(op_name):
     env_var = op_envvar(op_name)
-    return int(os.environ.get(env_var, BUILD_OP_DEFAULT))
+    return int(get_env_if_set(env_var, BUILD_OP_DEFAULT))


 compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
......@@ -160,7 +177,7 @@ for op_name, builder in ALL_OPS.items():
     # If op is requested but not available, throw an error.
     if op_enabled(op_name) and not op_compatible:
         env_var = op_envvar(op_name)
-        if env_var not in os.environ:
+        if not is_env_set(env_var):
             builder.warning(f"One can disable {op_name} with {env_var}=0")
         abort(f"Unable to pre-compile {op_name}")
......@@ -179,7 +196,7 @@ print(f'Install Ops={install_ops}')
 # Write out version/git info.
 git_hash_cmd = "git rev-parse --short HEAD"
 git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
-if command_exists('git') and 'DS_BUILD_STRING' not in os.environ:
+if command_exists('git') and not is_env_set('DS_BUILD_STRING'):
     try:
         result = subprocess.check_output(git_hash_cmd, shell=True)
         git_hash = result.decode('utf-8').strip()
......@@ -216,11 +233,11 @@ version_str = open('version.txt', 'r').read().strip()
 # Example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel.

 # Building wheel for distribution, update version file.
-if 'DS_BUILD_STRING' in os.environ:
+if is_env_set('DS_BUILD_STRING'):
     # Build string env specified, probably building for distribution.
     with open('build.txt', 'w') as fd:
-        fd.write(os.environ.get('DS_BUILD_STRING'))
-    version_str += os.environ.get('DS_BUILD_STRING')
+        fd.write(os.environ['DS_BUILD_STRING'])
+    version_str += os.environ['DS_BUILD_STRING']
 elif os.path.isfile('build.txt'):
     # build.txt exists, probably installing from distribution.
     with open('build.txt', 'r') as fd:
......
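For completeness, a small standalone sketch of the version-string logic in the hunk above: the base comes from version.txt, and DS_BUILD_STRING (or a previously written build.txt) appends a suffix. The version value here is a stand-in, not the project's real version.

import os

version_str = "0.0.0"                               # stand-in for the contents of version.txt
os.environ["DS_BUILD_STRING"] = ".dev20201022"      # example from the comment above

if bool(os.environ.get("DS_BUILD_STRING", None)):   # mirrors is_env_set above
    version_str += os.environ["DS_BUILD_STRING"]

print(version_str)                                  # 0.0.0.dev20201022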