未验证 提交 5676f5ec 编写于 作者: J Jeff Rasley 提交者: GitHub

[inference] check for unsupported model generate args (#2627)

上级 df985fac
......@@ -49,6 +49,10 @@ class InferenceEngine(Module):
self._get_model_config_generate(config) # keep for weird backward compatibility
# patch model generate with ours if model uses it
if hasattr(self.module, "generate"):
self.generate = self._generate
if hasattr(self.module, "config"):
DSPolicy.hf_model_config = self.module.config
......@@ -148,8 +152,6 @@ class InferenceEngine(Module):
self.config = getattr(self.module,
'config',
None) if config.config is None else config.config
# todo: clarify with Reza if this gets used anywhere
self.generate = getattr(self.module, 'generate', None)
def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
......@@ -518,3 +520,19 @@ class InferenceEngine(Module):
self._model_times.append(duration)
return outputs
def _generate(self, *inputs, **kwargs):
    """Forward ``generate`` to the wrapped module after rejecting beam search.

    The requested beam count is read from the ``num_beams`` keyword if it is
    given, otherwise from the ``num_beams`` attribute of the
    ``generation_config`` keyword (if any). A value of ``None`` in either
    place is treated as "not specified" rather than compared against an
    integer, so callers that pass ``num_beams=None`` explicitly no longer
    hit a ``TypeError``.

    Args:
        *inputs: Positional arguments passed through to
            ``self.module.generate`` unchanged.
        **kwargs: Keyword arguments passed through to
            ``self.module.generate`` unchanged. Only ``num_beams`` and
            ``generation_config`` are inspected here.

    Returns:
        Whatever ``self.module.generate(*inputs, **kwargs)`` returns.

    Raises:
        NotImplementedError: If the effective ``num_beams`` is greater than 1.
    """
    # An explicit ``num_beams`` keyword takes precedence over the value
    # carried by ``generation_config``.
    num_beams = kwargs.get("num_beams")
    if num_beams is None:
        # Missing or ``None`` keyword: fall back to the generation config.
        # ``getattr`` on a missing/``None`` config simply yields ``None``.
        num_beams = getattr(kwargs.get("generation_config"), "num_beams", None)

    # ``None`` here means nobody asked for a specific beam count, which is
    # equivalent to the single-beam default and therefore supported.
    if num_beams is not None and num_beams > 1:
        raise NotImplementedError(
            "DeepSpeed does not support `num_beams` > 1, if this is important to you please "
            "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
        )

    return self.module.generate(*inputs, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册