Unverified commit 979af475, authored by zhaoyingli, committed by GitHub

[AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment
Parent 68e27f35
@@ -240,8 +240,8 @@ class Engine:
                 else:
                     specs.append(spec.batch(batch_size))
             elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
-                _adjust_item_spec(num_shards, spec)
                 spec = InputSpec.from_tensor(item, name)
+                _adjust_item_spec(num_shards, spec)
                 if batch_size is None:
                     specs.append(spec)
                 else:
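The reorder matters because _adjust_item_spec mutates the spec built from the current item: before this change it ran one line too early, operating on whatever spec was still bound from a previous branch rather than the spec just created by InputSpec.from_tensor. Below is a minimal, self-contained sketch of that ordering dependency; the stand-in Spec class and the assumed behavior of _adjust_item_spec (rescaling the leading batch dimension by num_shards) are illustrative only, not the engine's actual implementation.

class Spec:
    # Stand-in for paddle.static.InputSpec, for illustration only.
    def __init__(self, shape, name=None):
        self.shape = list(shape)
        self.name = name

def _adjust_item_spec(num_shards, spec):
    # Assumed behavior: rescale the leading (batch) dimension so the spec
    # describes the full batch across num_shards data shards.
    if num_shards > 1 and len(spec.shape) > 1:
        spec.shape[0] = spec.shape[0] * num_shards

def build_spec(item_shape, name, num_shards):
    spec = Spec(item_shape, name)        # the spec must exist first ...
    _adjust_item_spec(num_shards, spec)  # ... before it can be adjusted
    return spec

print(build_spec([8, 16], "input0", num_shards=4).shape)  # [32, 16]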
@@ -1508,10 +1508,10 @@ class Engine:
             strict (bool, optional): Whether to skip the loading of mismatch
                 parameter or raise an error when mismatch happens (not found
                 the parameter in file storing model states of or receives a
-                mismatch shape). Default: False.
+                mismatch shape). Default: True.
             load_optimizer (bool, optional): If True, the stored optimizer
                 states is restored. Otherwise, the optimizer states is initialized
-                from scratch. Default: False.
+                from scratch. Default: True.

         Returns:
             None
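This hunk only corrects the documented defaults of Engine.load: mismatched parameters raise an error (strict=True) and optimizer states are restored (load_optimizer=True) unless the caller opts out. A hedged usage sketch follows, not a complete training script; the model, loss, optimizer, and checkpoint path ./ckpt/model are placeholders, assuming a checkpoint was previously written by engine.save.

import paddle
from paddle.distributed.fleet import auto

# Placeholder model/loss/optimizer; any setup accepted by auto.Engine works.
model = paddle.nn.Linear(64, 10)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

engine = auto.Engine(model, loss, optimizer)

# With the documented defaults this is equivalent to engine.load("./ckpt/model"):
# mismatched shapes raise an error and optimizer states are restored as well.
engine.load("./ckpt/model", strict=True, load_optimizer=True)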
@@ -181,7 +181,8 @@ class FP16State(object):
             try:
                 var = block.var(var_name)
             except ValueError as e:
-                var = self.program.global_block().var(var_name)
+                var = block._var_recursive(var_name)
+                # var = self.program.global_block().var(var_name)

             # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
             # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
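This is the actual sub-block fix: when a name is not defined in the current block, _var_recursive walks up the parent-block chain (for example, from a while/cond sub-block to the block that encloses it) instead of looking only in the global block. A rough sketch of that lookup, using a hypothetical SimpleBlock stand-in rather than the framework's real Block class:

class SimpleBlock:
    # Hypothetical stand-in for a framework block, for illustration only.
    def __init__(self, parent=None):
        self.vars = {}
        self.parent = parent

    def var(self, name):
        # Strict lookup: only the current block.
        if name not in self.vars:
            raise ValueError(f"{name} not found in this block")
        return self.vars[name]

    def _var_recursive(self, name):
        # Recursive lookup: current block first, then each ancestor in turn.
        block = self
        while block is not None:
            if name in block.vars:
                return block.vars[name]
            block = block.parent
        raise ValueError(f"{name} not found in this block or any parent")

main = SimpleBlock()
main.vars["x"] = "fp32 var declared in the main block"
sub = SimpleBlock(parent=main)          # e.g. a while_op sub-block

try:
    sub.var("x")                        # old path: fails inside the sub-block
except ValueError:
    print(sub._var_recursive("x"))      # new path: found via the parent chain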