Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
a0b87f1b
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a0b87f1b
编写于
12月 31, 2019
作者:
J
Juncheng
提交者:
Li Xinqi
12月 31, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Dev total loss instance num time shape (#2531)
* loss instance num time shape * refine
上级
a77d1ebe
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
42 additions
and
7 deletions
+42
-7
oneflow/core/job_completer/autograd.cpp
oneflow/core/job_completer/autograd.cpp
+42
-7
未找到文件。
oneflow/core/job_completer/autograd.cpp
浏览文件 @
a0b87f1b
...
...
@@ -493,18 +493,53 @@ void AddTotalLossInstanceNumOpConf(
const
auto
&
lbi
=
GenLogicalBlobId
(
loss_lbn
);
CHECK
(
loss_lbi2op_node
.
emplace
(
lbi
,
LossOpNode4OpName
(
lbi
.
op_name
())).
second
);
}
const
BlobDesc
*
blob_desc
=
nullptr
;
const
Shape
src_time_shape
(
{
GlobalJobDesc
().
TotalBatchNum
(),
GlobalJobDesc
().
NumOfPiecesInBatch
()});
const
int64_t
source_time_shape_elem_cnt
=
src_time_shape
.
elem_cnt
();
bool
all_loss_time_shape_eq_src
=
true
;
for
(
const
auto
&
pair
:
loss_lbi2op_node
)
{
const
BlobDesc
*
cur_blob_desc
=
&
pair
.
second
->
LogicalBlobDesc4Lbi
(
pair
.
first
);
if
(
blob_desc
!=
nullptr
)
{
CHECK
(
*
blob_desc
==
*
cur_blob_desc
);
}
blob_desc
=
cur_blob_desc
;
const
Shape
*
time_shape
=
pair
.
second
->
out_blob_time_shape
();
const
int64_t
time_shape_elem_cnt
=
time_shape
->
elem_cnt
();
if
(
time_shape_elem_cnt
!=
source_time_shape_elem_cnt
)
{
CHECK_EQ
(
time_shape_elem_cnt
%
source_time_shape_elem_cnt
,
0
);
all_loss_time_shape_eq_src
=
false
;
}
}
HashMap
<
ParallelDesc
,
int32_t
>
parallel_desc2optimizer_node_cnt
;
CalcParallelDesc2OptimizerNodeCnt
(
op_graph
,
lbi2diff_lbi
,
&
parallel_desc2optimizer_node_cnt
);
if
(
blob_desc
->
is_dynamic
())
{
AddTotalLossInstanceNumOpConfForDynamicDim0
(
parallel_desc2optimizer_node_cnt
,
loss_lbi2op_node
,
job_builder
,
LossInstanceNum4ParallelDesc
);
if
(
all_loss_time_shape_eq_src
)
{
const
BlobDesc
*
blob_desc
=
nullptr
;
for
(
const
auto
&
pair
:
loss_lbi2op_node
)
{
const
BlobDesc
*
cur_blob_desc
=
&
pair
.
second
->
LogicalBlobDesc4Lbi
(
pair
.
first
);
if
(
blob_desc
!=
nullptr
)
{
CHECK
(
*
blob_desc
==
*
cur_blob_desc
);
}
blob_desc
=
cur_blob_desc
;
}
if
(
blob_desc
->
is_dynamic
())
{
AddTotalLossInstanceNumOpConfForDynamicDim0
(
parallel_desc2optimizer_node_cnt
,
loss_lbi2op_node
,
job_builder
,
LossInstanceNum4ParallelDesc
);
}
else
{
BuildConstantOpAsTotalLossInstanceNum
(
parallel_desc2optimizer_node_cnt
,
*
blob_desc
,
job_builder
,
LossInstanceNum4ParallelDesc
);
}
}
else
{
std
::
unique_ptr
<
BlobDesc
>
blob_desc
;
for
(
const
auto
&
pair
:
loss_lbi2op_node
)
{
const
BlobDesc
*
cur_blob_desc
=
&
pair
.
second
->
LogicalBlobDesc4Lbi
(
pair
.
first
);
// TODO: support dynamic
CHECK
(
!
cur_blob_desc
->
is_dynamic
());
const
DataType
loss_data_type
=
cur_blob_desc
->
data_type
();
const
int64_t
time_shape_elem_cnt
=
pair
.
second
->
out_blob_time_shape
()
->
elem_cnt
();
// TODO: consider batch_axis or sbp
const
int64_t
loss_elem_cnt
=
cur_blob_desc
->
shape
().
elem_cnt
()
*
time_shape_elem_cnt
/
source_time_shape_elem_cnt
;
if
(
blob_desc
)
{
CHECK_EQ
(
blob_desc
->
data_type
(),
loss_data_type
);
CHECK_EQ
(
blob_desc
->
shape
().
elem_cnt
(),
loss_elem_cnt
);
}
else
{
blob_desc
.
reset
(
new
BlobDesc
(
Shape
({
loss_elem_cnt
}),
loss_data_type
));
}
}
BuildConstantOpAsTotalLossInstanceNum
(
parallel_desc2optimizer_node_cnt
,
*
blob_desc
,
job_builder
,
LossInstanceNum4ParallelDesc
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录