Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
979af475
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
979af475
编写于
10月 20, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
10月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] fix fp16 for subblock (#47189)
* [AutoParallel] fix fp16 for subblock * fix engine * fix comment
上级
68e27f35
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
5 addition
and
4 deletion
+5
-4
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+3
-3
python/paddle/distributed/passes/auto_parallel_fp16.py
python/paddle/distributed/passes/auto_parallel_fp16.py
+2
-1
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
979af475
...
...
@@ -240,8 +240,8 @@ class Engine:
else
:
specs
.
append
(
spec
.
batch
(
batch_size
))
elif
isinstance
(
item
,
(
Variable
,
core
.
VarBase
,
core
.
eager
.
Tensor
)):
_adjust_item_spec
(
num_shards
,
spec
)
spec
=
InputSpec
.
from_tensor
(
item
,
name
)
_adjust_item_spec
(
num_shards
,
spec
)
if
batch_size
is
None
:
specs
.
append
(
spec
)
else
:
...
...
@@ -1508,10 +1508,10 @@ class Engine:
strict (bool, optional): Whether to skip the loading of mismatch
parameter or raise an error when mismatch happens (not found
the parameter in file storing model states of or receives a
mismatch shape). Default:
Fals
e.
mismatch shape). Default:
Tru
e.
load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is initialized
from scratch. Default:
Fals
e.
from scratch. Default:
Tru
e.
Returns:
None
...
...
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
979af475
...
...
@@ -181,7 +181,8 @@ class FP16State(object):
try
:
var
=
block
.
var
(
var_name
)
except
ValueError
as
e
:
var
=
self
.
program
.
global_block
().
var
(
var_name
)
var
=
block
.
_var_recursive
(
var_name
)
# var = self.program.global_block().var(var_name)
# NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
# a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录