未验证 提交 979af475 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment
上级 68e27f35
......@@ -240,8 +240,8 @@ class Engine:
else:
specs.append(spec.batch(batch_size))
elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
_adjust_item_spec(num_shards, spec)
spec = InputSpec.from_tensor(item, name)
_adjust_item_spec(num_shards, spec)
if batch_size is None:
specs.append(spec)
else:
......@@ -1508,10 +1508,10 @@ class Engine:
strict (bool, optional): Whether to skip the loading of mismatch
parameter or raise an error when mismatch happens (not found
the parameter in file storing model states of or receives a
mismatch shape). Default: False.
mismatch shape). Default: True.
load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is initialized
from scratch. Default: False.
from scratch. Default: True.
Returns:
None
......
......@@ -181,7 +181,8 @@ class FP16State(object):
try:
var = block.var(var_name)
except ValueError as e:
var = self.program.global_block().var(var_name)
var = block._var_recursive(var_name)
# var = self.program.global_block().var(var_name)
# NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
# a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册