Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
32022d4d
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
32022d4d
编写于
8月 18, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
blob desc
上级
0e9c3917
变更
71
隐藏空白更改
内联
并排
Showing
71 changed file
with
296 addition
and
261 deletion
+296
-261
cmake/oneflow.cmake
cmake/oneflow.cmake
+5
-3
oneflow/core/graph/boxing_task_node.cpp
oneflow/core/graph/boxing_task_node.cpp
+7
-7
oneflow/core/graph/boxing_task_node.h
oneflow/core/graph/boxing_task_node.h
+3
-3
oneflow/core/graph/copy_task_node.cpp
oneflow/core/graph/copy_task_node.cpp
+2
-2
oneflow/core/graph/copy_task_node.h
oneflow/core/graph/copy_task_node.h
+1
-1
oneflow/core/graph/data_comp_task_node.cpp
oneflow/core/graph/data_comp_task_node.cpp
+9
-10
oneflow/core/graph/data_comp_task_node.h
oneflow/core/graph/data_comp_task_node.h
+3
-3
oneflow/core/graph/exec_graph.cpp
oneflow/core/graph/exec_graph.cpp
+3
-3
oneflow/core/graph/exec_graph.h
oneflow/core/graph/exec_graph.h
+1
-1
oneflow/core/graph/loss_accumulate_comp_task_node.cpp
oneflow/core/graph/loss_accumulate_comp_task_node.cpp
+2
-2
oneflow/core/graph/loss_accumulate_comp_task_node.h
oneflow/core/graph/loss_accumulate_comp_task_node.h
+1
-1
oneflow/core/graph/loss_record_comp_task_node.cpp
oneflow/core/graph/loss_record_comp_task_node.cpp
+1
-2
oneflow/core/graph/loss_record_comp_task_node.h
oneflow/core/graph/loss_record_comp_task_node.h
+1
-1
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
+2
-2
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
+1
-1
oneflow/core/graph/model_save_comp_task_node.cpp
oneflow/core/graph/model_save_comp_task_node.cpp
+1
-1
oneflow/core/graph/model_save_comp_task_node.h
oneflow/core/graph/model_save_comp_task_node.h
+1
-1
oneflow/core/graph/model_update_comp_task_node.cpp
oneflow/core/graph/model_update_comp_task_node.cpp
+6
-6
oneflow/core/graph/model_update_comp_task_node.h
oneflow/core/graph/model_update_comp_task_node.h
+1
-1
oneflow/core/graph/task_graph.cpp
oneflow/core/graph/task_graph.cpp
+3
-4
oneflow/core/graph/task_graph.h
oneflow/core/graph/task_graph.h
+1
-1
oneflow/core/graph/task_node.h
oneflow/core/graph/task_node.h
+1
-1
oneflow/core/job/compiler.cpp
oneflow/core/job/compiler.cpp
+1
-1
oneflow/core/operator/boxing_op.cpp
oneflow/core/operator/boxing_op.cpp
+10
-7
oneflow/core/operator/boxing_op.h
oneflow/core/operator/boxing_op.h
+2
-2
oneflow/core/operator/boxing_op_test.cpp
oneflow/core/operator/boxing_op_test.cpp
+3
-3
oneflow/core/operator/clone_op.cpp
oneflow/core/operator/clone_op.cpp
+4
-4
oneflow/core/operator/clone_op.h
oneflow/core/operator/clone_op.h
+2
-2
oneflow/core/operator/clone_op_test.cpp
oneflow/core/operator/clone_op_test.cpp
+1
-1
oneflow/core/operator/concat_op.cpp
oneflow/core/operator/concat_op.cpp
+13
-10
oneflow/core/operator/concat_op.h
oneflow/core/operator/concat_op.h
+2
-2
oneflow/core/operator/concat_op_test.cpp
oneflow/core/operator/concat_op_test.cpp
+1
-1
oneflow/core/operator/convolution_op.cpp
oneflow/core/operator/convolution_op.cpp
+19
-24
oneflow/core/operator/convolution_op.h
oneflow/core/operator/convolution_op.h
+2
-2
oneflow/core/operator/convolution_op_test.cpp
oneflow/core/operator/convolution_op_test.cpp
+2
-2
oneflow/core/operator/data_loader_op.cpp
oneflow/core/operator/data_loader_op.cpp
+4
-4
oneflow/core/operator/data_loader_op.h
oneflow/core/operator/data_loader_op.h
+2
-2
oneflow/core/operator/innerproduct_op.cpp
oneflow/core/operator/innerproduct_op.cpp
+9
-11
oneflow/core/operator/innerproduct_op.h
oneflow/core/operator/innerproduct_op.h
+2
-2
oneflow/core/operator/innerproduct_op_test.cpp
oneflow/core/operator/innerproduct_op_test.cpp
+2
-2
oneflow/core/operator/model_update_op.h
oneflow/core/operator/model_update_op.h
+2
-2
oneflow/core/operator/momentum_model_update_op.cpp
oneflow/core/operator/momentum_model_update_op.cpp
+3
-4
oneflow/core/operator/momentum_model_update_op.h
oneflow/core/operator/momentum_model_update_op.h
+2
-2
oneflow/core/operator/multinomial_logistic_loss_op.cpp
oneflow/core/operator/multinomial_logistic_loss_op.cpp
+4
-4
oneflow/core/operator/multinomial_logistic_loss_op.h
oneflow/core/operator/multinomial_logistic_loss_op.h
+2
-2
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
+1
-1
oneflow/core/operator/operator.h
oneflow/core/operator/operator.h
+5
-5
oneflow/core/operator/pooling_op.cpp
oneflow/core/operator/pooling_op.cpp
+12
-14
oneflow/core/operator/pooling_op.h
oneflow/core/operator/pooling_op.h
+2
-2
oneflow/core/operator/pooling_op_test.cpp
oneflow/core/operator/pooling_op_test.cpp
+1
-1
oneflow/core/operator/record_op.h
oneflow/core/operator/record_op.h
+2
-2
oneflow/core/operator/relu_op.cpp
oneflow/core/operator/relu_op.cpp
+3
-5
oneflow/core/operator/relu_op.h
oneflow/core/operator/relu_op.h
+2
-2
oneflow/core/operator/relu_op_test.cpp
oneflow/core/operator/relu_op_test.cpp
+1
-1
oneflow/core/operator/rmsprop_model_update_op.cpp
oneflow/core/operator/rmsprop_model_update_op.cpp
+3
-4
oneflow/core/operator/rmsprop_model_update_op.h
oneflow/core/operator/rmsprop_model_update_op.h
+2
-2
oneflow/core/operator/softmax_loss_op.cpp
oneflow/core/operator/softmax_loss_op.cpp
+7
-7
oneflow/core/operator/softmax_loss_op.h
oneflow/core/operator/softmax_loss_op.h
+2
-2
oneflow/core/operator/softmax_loss_op_test.cpp
oneflow/core/operator/softmax_loss_op_test.cpp
+1
-1
oneflow/core/operator/softmax_op.cpp
oneflow/core/operator/softmax_op.cpp
+6
-5
oneflow/core/operator/softmax_op.h
oneflow/core/operator/softmax_op.h
+2
-2
oneflow/core/operator/softmax_op_test.cpp
oneflow/core/operator/softmax_op_test.cpp
+1
-1
oneflow/core/register/blob.h
oneflow/core/register/blob.h
+1
-1
oneflow/core/register/blob_desc.h
oneflow/core/register/blob_desc.h
+30
-0
oneflow/core/register/blob_desc.proto
oneflow/core/register/blob_desc.proto
+8
-0
oneflow/core/register/register_desc.cpp
oneflow/core/register/register_desc.cpp
+29
-24
oneflow/core/register/register_desc.h
oneflow/core/register/register_desc.h
+9
-12
oneflow/core/register/register_desc.proto
oneflow/core/register/register_desc.proto
+2
-2
oneflow/core/register/register_manager.cpp
oneflow/core/register/register_manager.cpp
+7
-6
oneflow/core/register/runtime_register_desc.cpp
oneflow/core/register/runtime_register_desc.cpp
+3
-2
oneflow/core/register/runtime_register_desc.h
oneflow/core/register/runtime_register_desc.h
+4
-4
未找到文件。
cmake/oneflow.cmake
浏览文件 @
32022d4d
...
...
@@ -108,6 +108,8 @@ foreach(cc ${of_main_cc})
endforeach
()
# build test
cuda_add_executable
(
oneflow_testexe
${
of_all_test_cc
}
)
target_link_libraries
(
oneflow_testexe
${
of_libs
}
${
oneflow_third_party_libs
}
)
add_test
(
NAME oneflow_test COMMAND oneflow_testexe
)
if
(
BUILD_TESTING
)
cuda_add_executable
(
oneflow_testexe
${
of_all_test_cc
}
)
target_link_libraries
(
oneflow_testexe
${
of_libs
}
${
oneflow_third_party_libs
}
)
add_test
(
NAME oneflow_test COMMAND oneflow_testexe
)
endif
()
oneflow/core/graph/boxing_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -173,11 +173,11 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
}
}
void
BoxingTaskNode
::
FwInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
{
void
BoxingTaskNode
::
FwInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
{
exec_gph
().
ConstForEachNode
([
this
](
const
ExecNode
*
exec_node
)
{
exec_node
->
op
()
->
Infer
Shape4FwBlobs
(
exec_node
->
GetMutShapePtr4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
0
,
0
);
exec_node
->
op
()
->
Infer
BlobDesc4FwBlobs
(
exec_node
->
GetBlobDesc4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
0
,
0
);
});
}
...
...
@@ -232,16 +232,16 @@ void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
BoxingTaskNode
::
BpInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
{
void
BoxingTaskNode
::
BpInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
{
for
(
TaskEdge
*
fw_in_edge
:
GetFwNode
()
->
in_edges
())
{
auto
in_regst
=
GetRelatedRegst
(
fw_in_edge
);
if
(
auto
in_diff_regst
=
GetBpRegstFromFwRegst
(
in_regst
))
{
in_diff_regst
->
Copy
Shape
From
(
in_regst
.
get
());
in_diff_regst
->
Copy
BlobDesc
From
(
in_regst
.
get
());
}
}
auto
fw_middle_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"middle"
);
auto
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
bp_middle_regst
->
Copy
Shape
From
(
fw_middle_regst
.
get
());
bp_middle_regst
->
Copy
BlobDesc
From
(
fw_middle_regst
.
get
());
}
}
// namespace oneflow
oneflow/core/graph/boxing_task_node.h
浏览文件 @
32022d4d
...
...
@@ -41,12 +41,12 @@ class BoxingTaskNode : public TaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC
(
BuildExecAndEnrollLbn2Regsts
);
OVERRIDE_IF_FW_BP_FOR_FUNC
(
Infer
ShapeOfBlobs
InProducedRegsts
);
OVERRIDE_IF_FW_BP_FOR_FUNC
(
Infer
BlobDesc
InProducedRegsts
);
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
FwInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
);
void
FwInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
);
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
BpInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
);
void
BpInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
);
void
EnrollAllRegstAndBindRelatedEdge
();
TaskType
task_type
()
const
override
{
return
kBoxingTask
;
}
...
...
oneflow/core/graph/copy_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -25,10 +25,10 @@ void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
CopyTaskNode
::
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
{
void
CopyTaskNode
::
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
{
std
::
shared_ptr
<
RegstDesc
>
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
std
::
shared_ptr
<
RegstDesc
>
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
out_regst
->
Copy
Shape
From
(
in_regst
.
get
());
out_regst
->
Copy
BlobDesc
From
(
in_regst
.
get
());
}
void
CopyHDTaskNode
::
SetFwInCopy
()
{
...
...
oneflow/core/graph/copy_task_node.h
浏览文件 @
32022d4d
...
...
@@ -16,7 +16,7 @@ class CopyTaskNode : public TaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
override
;
};
class
CopyHDTaskNode
final
:
public
CopyTaskNode
{
...
...
oneflow/core/graph/data_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -24,17 +24,16 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
FwEnrollLbn2ModelAndTmpRegsts
();
// model model_tmp data_tmp
}
void
DataCompTaskNode
::
FwInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
{
void
DataCompTaskNode
::
FwInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
{
exec_gph
().
ConstTopoForEachNode
([
this
](
const
ExecNode
*
node
)
{
node
->
op
()
->
InferShape4FwBlobs
(
node
->
GetMutShapePtr4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
(),
chain_node
()
->
parallel_desc
()
->
parallel_num
());
node
->
op
()
->
InferBlobDesc4FwBlobs
(
node
->
GetBlobDesc4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
(),
chain_node
()
->
parallel_desc
()
->
parallel_num
());
});
if
(
IsLossNode
())
{
auto
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
auto
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
out_regst
->
Copy
Shape
From
(
in_regst
.
get
());
out_regst
->
Copy
BlobDesc
From
(
in_regst
.
get
());
}
}
...
...
@@ -173,20 +172,20 @@ void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
BpEnrollLbn2ProducedRegst
();
}
void
DataCompTaskNode
::
BpInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
{
void
DataCompTaskNode
::
BpInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
{
// in_diff_regst
auto
in_diff_regst
=
GetProducedRegstDesc
(
"in_diff"
);
auto
in_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleInEdge
());
in_diff_regst
->
Copy
Shape
From
(
in_regst
.
get
());
in_diff_regst
->
Copy
BlobDesc
From
(
in_regst
.
get
());
// model_diff_regst
if
(
auto
md_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
))
{
md_diff_regst
->
Copy
Shape
From
(
md_diff_regst
->
Copy
BlobDesc
From
(
GetFwNode
()
->
GetConsumedRegstDesc
(
"model"
).
get
());
}
// activation_diff_regst
if
(
auto
acti_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
))
{
auto
acti_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
);
acti_diff_regst
->
Copy
Shape
From
(
acti_regst
.
get
());
acti_diff_regst
->
Copy
BlobDesc
From
(
acti_regst
.
get
());
}
}
...
...
oneflow/core/graph/data_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -41,12 +41,12 @@ class DataCompTaskNode final : public CompTaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC
(
BuildExecAndEnrollLbn2Regsts
);
OVERRIDE_IF_FW_BP_FOR_FUNC
(
Infer
ShapeOfBlobs
InProducedRegsts
);
OVERRIDE_IF_FW_BP_FOR_FUNC
(
Infer
BlobDesc
InProducedRegsts
);
using
Lbn2NodeBnMap
=
HashMap
<
std
::
string
,
std
::
pair
<
ExecNode
*
,
std
::
string
>>
;
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
);
void
FwInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
);
void
FwInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
);
void
FwBuildFromUserOps
(
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
);
void
FwSetExecNodeFromInRegst
(
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
);
...
...
@@ -56,7 +56,7 @@ class DataCompTaskNode final : public CompTaskNode {
void
FwEnrollLbn2ActivationRegst
();
void
FwEnrollLbn2ModelAndTmpRegsts
();
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
BpInfer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
);
void
BpInfer
BlobDesc
InProducedRegsts
(
TaskGraph
*
);
void
BpBuildExecGraph
();
void
BpEnrollLbn2ProducedRegst
();
void
BpEnrollLbn2ActivationDiffRegst
();
...
...
oneflow/core/graph/exec_graph.cpp
浏览文件 @
32022d4d
...
...
@@ -4,14 +4,14 @@ namespace oneflow {
void
ExecEdge
::
set_lbn
(
const
std
::
string
&
lbn
)
{
lbn_
=
lbn
;
}
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
ExecNode
::
GetMutShapePtr
4BnInOpFunc
()
std
::
function
<
BlobDesc
*
(
const
std
::
string
&
)
>
ExecNode
::
GetBlobDesc
4BnInOpFunc
()
const
{
return
[
this
](
const
std
::
string
&
bn_in_op
)
->
Shape
*
{
return
[
this
](
const
std
::
string
&
bn_in_op
)
->
BlobDesc
*
{
auto
it
=
this
->
bn_in_op2regst_
.
find
(
bn_in_op
);
if
(
it
==
this
->
bn_in_op2regst_
.
end
())
{
return
nullptr
;
}
std
::
shared_ptr
<
RegstDesc
>
regst
=
it
->
second
.
lock
();
const
std
::
string
&
lbn
=
this
->
op
()
->
Lbn4BnInOp
(
bn_in_op
);
return
regst
->
GetMut
ShapePtr
(
lbn
);
return
regst
->
GetMut
BlobDesc
(
lbn
);
};
}
...
...
oneflow/core/graph/exec_graph.h
浏览文件 @
32022d4d
...
...
@@ -55,7 +55,7 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
return
bn_in_op2regst_
;
}
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetMutShapePtr
4BnInOpFunc
()
const
;
std
::
function
<
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc
4BnInOpFunc
()
const
;
std
::
string
VisualStr
()
const
{
return
op_
->
op_name
();
}
...
...
oneflow/core/graph/loss_accumulate_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -22,11 +22,11 @@ void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
LossAccCompTaskNode
::
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
{
void
LossAccCompTaskNode
::
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
{
if
(
!
chain_node
()
->
op_vec
().
empty
())
{
auto
loss_regst
=
GetConsumedRegstDesc
(
"loss"
);
auto
loss_acc_regst
=
GetProducedRegstDesc
(
"loss_acc"
);
loss_acc_regst
->
Copy
Shape
From
(
loss_regst
.
get
());
loss_acc_regst
->
Copy
BlobDesc
From
(
loss_regst
.
get
());
}
}
...
...
oneflow/core/graph/loss_accumulate_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -13,7 +13,7 @@ class LossAccCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
TaskType
task_type
()
const
override
{
return
kLossAccCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
LossAccCompTaskNode
>
();
...
...
oneflow/core/graph/loss_record_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -20,7 +20,6 @@ void LossRecordCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
LossRecordCompTaskNode
::
InferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
gph
)
{
}
void
LossRecordCompTaskNode
::
InferBlobDescInProducedRegsts
(
TaskGraph
*
gph
)
{}
}
// namespace oneflow
oneflow/core/graph/loss_record_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -13,7 +13,7 @@ class LossRecordCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
bool
IsMeaningLess
()
const
override
{
return
!
GetConsumedRegstDesc
(
"loss_acc"
);
}
...
...
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -35,13 +35,13 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
MdDiffAccCompTaskNode
::
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
{
void
MdDiffAccCompTaskNode
::
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
if
(
!
chain_node
()
->
op_vec
().
empty
())
{
std
::
shared_ptr
<
RegstDesc
>
in_regst
=
GetConsumedRegstDesc
(
"model_diff"
);
std
::
shared_ptr
<
RegstDesc
>
out_regst
=
GetProducedRegstDesc
(
"model_diff_acc"
);
out_regst
->
Copy
Shape
From
(
in_regst
.
get
());
out_regst
->
Copy
BlobDesc
From
(
in_regst
.
get
());
}
}
...
...
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -19,7 +19,7 @@ class MdDiffAccCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
TaskType
task_type
()
const
override
{
return
kMdDiffAccCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
MdDiffAccCompTaskNode
>
();
...
...
oneflow/core/graph/model_save_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -32,7 +32,7 @@ void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
}
}
void
MdSaveCompTaskNode
::
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
{
void
MdSaveCompTaskNode
::
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
}
...
...
oneflow/core/graph/model_save_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -22,7 +22,7 @@ class MdSaveCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
bool
IsMeaningLess
()
const
override
{
return
!
GetConsumedRegstDesc
(
"model"
);
}
TaskType
task_type
()
const
override
{
return
kMdSaveCompTask
;
}
...
...
oneflow/core/graph/model_update_comp_task_node.cpp
浏览文件 @
32022d4d
...
...
@@ -34,17 +34,17 @@ void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
MdUpdtCompTaskNode
::
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
{
void
MdUpdtCompTaskNode
::
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
ExecNode
*
exec_node
=
exec_gph
().
SoleNode
();
auto
model_diffs_regst
=
GetConsumedRegstDesc
(
"model_diffs"
);
Shape
packed_model_diffs_shape
({
model_diffs_regst
->
CompElemCntOfAllBlob
()}
);
exec_node
->
op
()
->
Infer
Shape
4FwBlobs
(
[
&
](
const
std
::
string
&
bn_in_op
)
->
Shape
*
{
BlobDesc
packed_blob_desc
=
model_diffs_regst
->
CompPackedBlobDesc
(
);
exec_node
->
op
()
->
Infer
BlobDesc
4FwBlobs
(
[
&
](
const
std
::
string
&
bn_in_op
)
->
BlobDesc
*
{
if
(
bn_in_op
==
"model_diffs"
)
{
return
&
packed_
model_diffs_shape
;
return
&
packed_
blob_desc
;
}
else
{
return
exec_node
->
Get
MutShapePtr
4BnInOpFunc
()(
bn_in_op
);
return
exec_node
->
Get
BlobDesc
4BnInOpFunc
()(
bn_in_op
);
}
},
kDataParallel
,
0
,
0
);
...
...
oneflow/core/graph/model_update_comp_task_node.h
浏览文件 @
32022d4d
...
...
@@ -35,7 +35,7 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
gph
)
override
;
TaskType
task_type
()
const
override
{
return
kMdUpdtCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
MdUpdtCompTaskNode
>
();
...
...
oneflow/core/graph/task_graph.cpp
浏览文件 @
32022d4d
...
...
@@ -33,10 +33,9 @@ void TaskGraph::BuildExecAndEnrollLbn2Regsts() {
[
this
](
TaskNode
*
node
)
{
node
->
BuildExecAndEnrollLbn2Regsts
(
this
);
});
}
void
TaskGraph
::
InferShapeOfBlobsInProducedRegsts
()
{
TopoForEachNode
([
this
](
TaskNode
*
node
)
{
node
->
InferShapeOfBlobsInProducedRegsts
(
this
);
});
void
TaskGraph
::
InferBlobDescInProducedRegsts
()
{
TopoForEachNode
(
[
this
](
TaskNode
*
node
)
{
node
->
InferBlobDescInProducedRegsts
(
this
);
});
}
std
::
vector
<
CompTaskNode
*>
TaskGraph
::
CompTasksInChain
(
const
ChainNode
*
chain
)
{
...
...
oneflow/core/graph/task_graph.h
浏览文件 @
32022d4d
...
...
@@ -22,7 +22,7 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
const
ChainGraph
*
chain_gph
()
const
{
return
stage_gph_
->
chain_gph
();
}
std
::
vector
<
CompTaskNode
*>
CompTasksInChain
(
const
ChainNode
*
);
void
Infer
ShapeOfBlobs
InProducedRegsts
();
void
Infer
BlobDesc
InProducedRegsts
();
const
std
::
string
&
name
()
const
{
return
name_
;
}
...
...
oneflow/core/graph/task_node.h
浏览文件 @
32022d4d
...
...
@@ -42,7 +42,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
//
virtual
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
=
0
;
virtual
void
Infer
ShapeOfBlobs
InProducedRegsts
(
TaskGraph
*
)
=
0
;
virtual
void
Infer
BlobDesc
InProducedRegsts
(
TaskGraph
*
)
=
0
;
#define OVERRIDE_IF_FW_BP_FOR_FUNC(func_name) \
void func_name(TaskGraph* gph) override { \
...
...
oneflow/core/job/compiler.cpp
浏览文件 @
32022d4d
...
...
@@ -168,7 +168,7 @@ void Compiler::BuildLossGraph(
void
Compiler
::
InferShape4Regsts
()
{
for
(
auto
&
task_gph
:
ordered_task_gphs_
)
{
LOG
(
INFO
)
<<
"InferShape for "
<<
task_gph
->
name
();
task_gph
->
Infer
ShapeOfBlobs
InProducedRegsts
();
task_gph
->
Infer
BlobDesc
InProducedRegsts
();
}
}
...
...
oneflow/core/operator/boxing_op.cpp
浏览文件 @
32022d4d
...
...
@@ -31,13 +31,13 @@ std::string BoxingOp::obn2lbn(const std::string& output_bn) const {
return
GetStringFromSpecialConf
(
"lbn"
);
}
void
BoxingOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
BoxingOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
auto
boxing_conf
=
op_conf
().
boxing_conf
();
auto
in_box_case
=
boxing_conf
.
in_box_case
();
std
::
vector
<
int64_t
>
data_tmp_blob_shape_vec
=
Get
ShapePtr4BnInOp
(
input_bns
().
at
(
0
))
->
dim_vec
();
Get
BlobDesc4BnInOp
(
input_bns
().
at
(
0
))
->
shape
().
dim_vec
();
// if it is a concat-box, accumulate the dimensions on concat-axis.
// otherwise only check all boxes are in the same shape.
...
...
@@ -47,7 +47,8 @@ void BoxingOp::InferShape4FwBlobs(
CHECK
(
concat_axis
==
0
||
concat_axis
==
1
);
}
for
(
size_t
ib_idx
=
1
;
ib_idx
<
input_bns
().
size
();
++
ib_idx
)
{
auto
ib_shape_vec
=
GetShapePtr4BnInOp
(
input_bns
().
at
(
ib_idx
))
->
dim_vec
();
auto
ib_shape_vec
=
GetBlobDesc4BnInOp
(
input_bns
().
at
(
ib_idx
))
->
shape
().
dim_vec
();
for
(
size_t
i
=
0
;
i
<
ib_shape_vec
.
size
();
++
i
)
{
if
(
in_box_case
==
BoxingOpConf
::
kConcatBox
&&
i
==
concat_axis
)
{
data_tmp_blob_shape_vec
[
i
]
+=
ib_shape_vec
[
i
];
...
...
@@ -61,7 +62,8 @@ void BoxingOp::InferShape4FwBlobs(
// it is stored back if and only if this is a concat-clone box
if
(
in_box_case
==
BoxingOpConf
::
kConcatBox
&&
out_box_case
==
BoxingOpConf
::
kCloneBox
)
{
*
GetShapePtr4BnInOp
(
SoleDtbn
())
=
Shape
(
data_tmp_blob_shape_vec
);
GetBlobDesc4BnInOp
(
SoleDtbn
())
->
mut_shape
()
=
Shape
(
data_tmp_blob_shape_vec
);
}
CHECK_NE
(
out_box_case
,
BoxingOpConf
::
OUT_BOX_NOT_SET
);
if
(
out_box_case
==
BoxingOpConf
::
kDataSplitBox
)
{
...
...
@@ -70,11 +72,12 @@ void BoxingOp::InferShape4FwBlobs(
auto
output_shape_vec
=
data_tmp_blob_shape_vec
;
for
(
size_t
i
=
0
;
i
<
out_num
;
++
i
)
{
output_shape_vec
[
0
]
=
splitter
.
At
(
i
).
size
();
*
GetShapePtr4BnInOp
(
output_bns
()[
i
])
=
Shape
(
output_shape_vec
);
GetBlobDesc4BnInOp
(
output_bns
()[
i
])
->
mut_shape
()
=
Shape
(
output_shape_vec
);
}
}
else
if
(
out_box_case
==
BoxingOpConf
::
kCloneBox
)
{
for
(
auto
obn
:
output_bns
())
{
*
GetShapePtr4BnInOp
(
obn
)
=
Shape
(
data_tmp_blob_shape_vec
);
GetBlobDesc4BnInOp
(
obn
)
->
mut_shape
(
)
=
Shape
(
data_tmp_blob_shape_vec
);
}
}
else
{
UNEXPECTED_RUN
();
...
...
oneflow/core/operator/boxing_op.h
浏览文件 @
32022d4d
...
...
@@ -14,8 +14,8 @@ class BoxingOp final : public SysOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/boxing_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -36,7 +36,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
};
// do infer shape
boxing_op
->
Infer
Shape
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
boxing_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
// test results
// output_shape should be:
...
...
@@ -58,7 +58,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_op
=
ConstructOp
(
op_conf
);
// do infer shape
boxing_op
->
Infer
Shape
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
boxing_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
// test results
// output shape should be the same as input
...
...
@@ -75,7 +75,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_op
=
ConstructOp
(
op_conf
);
// do infer shape
boxing_op
->
Infer
Shape
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
boxing_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kModelParallel
,
0
,
1
);
// data_tmp_shape is {10, 17, 6, 6}, and the 17 = 4 + 4 + 4 + 5
Shape
*
data_tmp_shape_ptr
=
bn2shape_ptr
.
at
(
boxing_op
->
SoleDtbn
());
...
...
oneflow/core/operator/clone_op.cpp
浏览文件 @
32022d4d
...
...
@@ -16,12 +16,12 @@ const PbMessage& CloneOp::GetSpecialConf() const {
return
op_conf
().
clone_conf
();
}
void
CloneOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
CloneOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
input_shape_ptr
=
GetShapePtr
4BnInOp
(
SoleIbn
());
const
BlobDesc
*
input_blob_desc
=
GetBlobDesc
4BnInOp
(
SoleIbn
());
for
(
std
::
string
obn
:
output_bns
())
{
*
Get
ShapePtr4BnInOp
(
obn
)
=
*
input_shape_ptr
;
*
Get
BlobDesc4BnInOp
(
obn
)
=
*
input_blob_desc
;
}
}
...
...
oneflow/core/operator/clone_op.h
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ class CloneOp final : public SysOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/clone_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -18,7 +18,7 @@ TEST(CloneOp, clone_4x3_3_times) {
return
bn2shape_ptr
.
at
(
bn
);
};
clone_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
3
,
10
);
clone_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
3
,
10
);
Shape
*
input_shape_ptr
=
bn2shape_ptr
.
at
(
clone_op
->
SoleIbn
());
for
(
std
::
string
obn
:
clone_op
->
output_bns
())
{
...
...
oneflow/core/operator/concat_op.cpp
浏览文件 @
32022d4d
...
...
@@ -18,23 +18,26 @@ const PbMessage& ConcatOp::GetSpecialConf() const {
return
op_conf
().
concat_conf
();
}
void
ConcatOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
ConcatOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
std
::
vector
<
int64_t
>
vec
=
GetShapePtr4BnInOp
(
input_bns
().
at
(
0
))
->
dim_vec
();
std
::
vector
<
int64_t
>
vec
=
GetBlobDesc4BnInOp
(
input_bns
().
at
(
0
))
->
shape
().
dim_vec
();
for
(
size_t
ibn_idx
=
1
;
ibn_idx
<
input_bns
().
size
();
++
ibn_idx
)
{
Shape
*
ib_shape
=
GetShapePtr4BnInOp
(
input_bns
().
at
(
ibn_idx
));
const
Shape
&
ib_shape
=
GetBlobDesc4BnInOp
(
input_bns
().
at
(
ibn_idx
))
->
shape
();
int32_t
concat_axis
=
op_conf
().
concat_conf
().
axis
();
for
(
int64_t
j
=
0
;
j
<
ib_shape
->
NumAxes
();
++
j
)
{
if
(
j
==
concat_axis
||
j
==
concat_axis
+
ib_shape
->
NumAxes
())
{
vec
[
j
]
+=
ib_shape
->
At
(
j
);
for
(
int64_t
j
=
0
;
j
<
ib_shape
.
NumAxes
();
++
j
)
{
if
(
j
==
concat_axis
||
j
==
concat_axis
+
ib_shape
.
NumAxes
())
{
vec
[
j
]
+=
ib_shape
.
At
(
j
);
}
else
{
CHECK_EQ
(
vec
[
j
],
ib_shape
->
At
(
j
));
CHECK_EQ
(
vec
[
j
],
ib_shape
.
At
(
j
));
}
}
}
CHECK_EQ
(
vec
.
size
(),
GetShapePtr4BnInOp
(
input_bns
().
at
(
0
))
->
NumAxes
());
*
GetShapePtr4BnInOp
(
SoleObn
())
=
Shape
(
vec
);
CHECK_EQ
(
vec
.
size
(),
GetBlobDesc4BnInOp
(
input_bns
().
at
(
0
))
->
shape
().
NumAxes
());
GetBlobDesc4BnInOp
(
SoleObn
())
->
mut_shape
()
=
Shape
(
vec
);
}
REGISTER_OP
(
OperatorConf
::
kConcatConf
,
ConcatOp
);
...
...
oneflow/core/operator/concat_op.h
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ class ConcatOp final : public UserOperator {
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/concat_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -21,7 +21,7 @@ TEST(ConcatOp, concat_two_3x3) {
return
bn2shape_ptr
.
at
(
bn
);
};
// infershape
concat_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
concat_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
output_shape_ptr
=
fp
(
concat_op
->
SoleObn
());
ASSERT_EQ
(
*
output_shape_ptr
,
Shape
({
3
,
6
}));
...
...
oneflow/core/operator/convolution_op.cpp
浏览文件 @
32022d4d
...
...
@@ -22,15 +22,13 @@ const PbMessage& ConvolutionOp::GetSpecialConf() const {
return
op_conf
().
convolution_conf
();
}
void
ConvolutionOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
ConvolutionOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
input_shape_ptr
=
GetShapePtr4BnInOp
(
SoleIbn
());
Shape
*
output_shape_ptr
=
GetShapePtr4BnInOp
(
SoleObn
());
Shape
*
colbuf_shape_ptr
=
GetShapePtr4BnInOp
(
"col_buf"
);
const
Shape
&
input_shape
=
GetBlobDesc4BnInOp
(
SoleIbn
())
->
shape
();
auto
conv_conf
=
op_conf
().
convolution_conf
();
int64_t
batch_size
=
input_shape
_ptr
->
At
(
0
);
int64_t
c_i
=
input_shape
_ptr
->
At
(
1
);
int64_t
batch_size
=
input_shape
.
At
(
0
);
int64_t
c_i
=
input_shape
.
At
(
1
);
int32_t
out_num
=
GetInt32FromSpecialConf
(
"out_num"
);
if
(
policy
==
kModelParallel
)
{
...
...
@@ -43,32 +41,29 @@ void ConvolutionOp::InferShape4FwBlobs(
int64_t
output_size
=
1
;
std
::
vector
<
int64_t
>
output_shape_vec
=
{
batch_size
,
c_o
};
int64_t
h_len
=
(
input_shape_ptr
->
At
(
2
)
+
2
*
conv_conf
.
pad_h
()
-
conv_conf
.
kernel_size_h
())
/
conv_conf
.
stride_h
()
+
1
;
int64_t
h_len
=
(
input_shape
.
At
(
2
)
+
2
*
conv_conf
.
pad_h
()
-
conv_conf
.
kernel_size_h
())
/
conv_conf
.
stride_h
()
+
1
;
output_shape_vec
.
push_back
(
h_len
);
int64_t
w_len
=
(
input_shape_ptr
->
At
(
3
)
+
2
*
conv_conf
.
pad_w
()
-
conv_conf
.
kernel_size_w
())
/
conv_conf
.
stride_w
()
+
1
;
int64_t
w_len
=
(
input_shape
.
At
(
3
)
+
2
*
conv_conf
.
pad_w
()
-
conv_conf
.
kernel_size_w
())
/
conv_conf
.
stride_w
()
+
1
;
output_shape_vec
.
push_back
(
w_len
);
kernel_size
*=
conv_conf
.
kernel_size_h
();
kernel_size
*=
conv_conf
.
kernel_size_w
();
output_size
*=
h_len
;
output_size
*=
w_len
;
*
output_shape_ptr
=
Shape
(
output_shape_vec
);
CHECK_EQ
(
output_shape_ptr
->
NumAxes
(),
input_shape_ptr
->
NumAxes
());
*
colbuf_shape_ptr
=
Shape
({
batch_size
,
output_size
,
c_i
*
kernel_size
});
Shape
*
weight
=
GetShapePtr4BnInOp
(
"weight"
);
*
weight
=
Shape
({
c_o
,
c_i
*
kernel_size
});
GetBlobDesc4BnInOp
(
SoleObn
())
->
mut_shape
()
=
Shape
(
output_shape_vec
);
GetBlobDesc4BnInOp
(
"col_buf"
)
->
mut_shape
()
=
Shape
({
batch_size
,
output_size
,
c_i
*
kernel_size
});
GetBlobDesc4BnInOp
(
"weight"
)
->
mut_shape
()
=
Shape
({
c_o
,
c_i
*
kernel_size
});
if
(
GetBoolFromSpecialConf
(
"has_bias_term"
))
{
Shape
*
bias
=
GetShapePtr4BnInOp
(
"bias"
);
Shape
*
biasmult_shape_ptr
=
GetShapePtr4BnInOp
(
"bias_multiplier"
);
*
bias
=
Shape
({
c_o
});
*
biasmult_shape_ptr
=
Shape
({
output_size
});
GetBlobDesc4BnInOp
(
"bias"
)
->
mut_shape
()
=
Shape
({
c_o
});
GetBlobDesc4BnInOp
(
"bias_multiplier"
)
->
mut_shape
()
=
Shape
({
output_size
});
}
}
...
...
oneflow/core/operator/convolution_op.h
浏览文件 @
32022d4d
...
...
@@ -13,8 +13,8 @@ class ConvolutionOp final : public UserOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
void
FixParallelDesc
(
ParallelDesc
*
pr_desc
)
const
override
{
...
...
oneflow/core/operator/convolution_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -37,7 +37,7 @@ void TestDataParallelConvolutionOp() {
};
// infershape
convolution_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
convolution_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
output_shape_ptr
=
fp
(
convolution_op
->
SoleObn
());
...
...
@@ -71,7 +71,7 @@ void TestModelParallelConvolutionOp() {
};
// infershape
convolution_op
->
Infer
Shape
4FwBlobs
(
fp
,
kModelParallel
,
3
,
8
);
convolution_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kModelParallel
,
3
,
8
);
// test
Shape
*
output_shape_ptr
=
fp
(
convolution_op
->
SoleObn
());
...
...
oneflow/core/operator/data_loader_op.cpp
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ const PbMessage& DataLoaderOp::GetSpecialConf() const {
return
op_conf
().
data_loader_conf
();
}
void
DataLoaderOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
DataLoaderOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
// useful vars
int32_t
piece_size
=
JobDesc
::
Singleton
()
->
piece_size
();
...
...
@@ -27,9 +27,9 @@ void DataLoaderOp::InferShape4FwBlobs(
feature_shape
.
insert
(
feature_shape
.
end
(),
feature_shape_of_one_ins
.
dim_vec
().
begin
(),
feature_shape_of_one_ins
.
dim_vec
().
end
());
*
GetShapePtr4BnInOp
(
"feature"
)
=
Shape
(
feature_shape
);
GetBlobDesc4BnInOp
(
"feature"
)
->
mut_shape
(
)
=
Shape
(
feature_shape
);
// label shape
*
GetShapePtr4BnInOp
(
"label"
)
=
Shape
({
piece_size
});
GetBlobDesc4BnInOp
(
"label"
)
->
mut_shape
(
)
=
Shape
({
piece_size
});
}
REGISTER_OP
(
OperatorConf
::
kDataLoaderConf
,
DataLoaderOp
);
...
...
oneflow/core/operator/data_loader_op.h
浏览文件 @
32022d4d
...
...
@@ -14,8 +14,8 @@ class DataLoaderOp final : public SysOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/innerproduct_op.cpp
浏览文件 @
32022d4d
...
...
@@ -21,10 +21,10 @@ const PbMessage& InnerProductOp::GetSpecialConf() const {
return
op_conf
().
innerproduct_conf
();
}
void
InnerProductOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
InnerProductOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
in_shape_ptr
=
GetShapePtr4BnInOp
(
SoleIbn
()
);
const
Shape
&
in_shape
=
GetBlobDesc4BnInOp
(
SoleIbn
())
->
shape
(
);
int32_t
out_num
=
GetInt32FromSpecialConf
(
"out_num"
);
if
(
policy
==
kModelParallel
)
{
BalancedSplitter
splitter
(
out_num
,
parallel_num
);
...
...
@@ -32,22 +32,20 @@ void InnerProductOp::InferShape4FwBlobs(
}
// output bn
Shape
*
out_shape_ptr
=
GetShapePtr4BnInOp
(
SoleObn
());
*
out_shape_ptr
=
Shape
({
in_shape_ptr
->
At
(
0
),
out_num
});
GetBlobDesc4BnInOp
(
SoleObn
())
->
mut_shape
()
=
Shape
({
in_shape
.
At
(
0
),
out_num
});
// model bn
Shape
*
weight_shape_ptr
=
GetShapePtr4BnInOp
(
"weight"
);
*
weight_shape_ptr
=
Shape
({
out_num
,
in_shape_ptr
->
Count
(
1
)});
GetBlobDesc4BnInOp
(
"weight"
)
->
mut_shape
()
=
Shape
({
out_num
,
in_shape
.
Count
(
1
)});
if
(
GetBoolFromSpecialConf
(
"has_bias_term"
))
{
// model bn
Shape
*
bias_shape_ptr
=
GetShapePtr4BnInOp
(
"bias"
);
*
bias_shape_ptr
=
Shape
({
1
,
out_num
});
GetBlobDesc4BnInOp
(
"bias"
)
->
mut_shape
()
=
Shape
({
1
,
out_num
});
// model tmp bn
CHECK_EQ
(
model_tmp_bns
().
size
(),
1
);
Shape
*
bias_multiplier_shape_ptr
=
GetShapePtr4BnInOp
(
"bias_multiplier"
);
*
bias_multiplier_shape_ptr
=
Shape
({
in_shape_ptr
->
At
(
0
),
1
});
GetBlobDesc4BnInOp
(
"bias_multiplier"
)
->
mut_shape
()
=
Shape
({
in_shape
.
At
(
0
),
1
});
}
}
...
...
oneflow/core/operator/innerproduct_op.h
浏览文件 @
32022d4d
...
...
@@ -13,8 +13,8 @@ class InnerProductOp final : public UserOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
void
FixParallelDesc
(
ParallelDesc
*
pr_desc
)
const
override
{
...
...
oneflow/core/operator/innerproduct_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -29,7 +29,7 @@ void TestModelParallelInnerProductOp(bool has_bias_term) {
return
bn2shape_ptr
.
at
(
bn
);
};
ip_op
->
Infer
Shape
4FwBlobs
(
fp
,
kModelParallel
,
3
,
10
);
ip_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kModelParallel
,
3
,
10
);
BalancedSplitter
splitter
(
40
,
10
);
int
out_num
=
splitter
.
At
(
3
).
size
();
...
...
@@ -70,7 +70,7 @@ void TestDataParallelInnerProductOp(bool has_bias_term) {
return
bn2shape_ptr
.
at
(
bn
);
};
ip_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
3
,
10
);
ip_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
3
,
10
);
Shape
*
out_shape_ptr
=
bn2shape_ptr
.
at
(
ip_op
->
SoleObn
());
CHECK_EQ
(
*
out_shape_ptr
,
Shape
({
1000
,
40
}));
...
...
oneflow/core/operator/model_update_op.h
浏览文件 @
32022d4d
...
...
@@ -10,8 +10,8 @@ class ModelUpdtOp : public SysOperator {
OF_DISALLOW_COPY_AND_MOVE
(
ModelUpdtOp
);
virtual
~
ModelUpdtOp
()
=
default
;
virtual
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
virtual
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
{
// do nothing
...
...
oneflow/core/operator/momentum_model_update_op.cpp
浏览文件 @
32022d4d
...
...
@@ -15,11 +15,10 @@ const PbMessage& MomentumModelUpdateOp::GetSpecialConf() const {
return
op_conf
().
momentum_mdupdt_conf
();
}
void
MomentumModelUpdateOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
MomentumModelUpdateOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
input_shape_ptr
=
GetShapePtr4BnInOp
(
"model_diffs"
);
*
GetShapePtr4BnInOp
(
"momentum"
)
=
*
input_shape_ptr
;
TODO
();
}
REGISTER_OP
(
OperatorConf
::
kMomentumMdupdtConf
,
MomentumModelUpdateOp
);
...
...
oneflow/core/operator/momentum_model_update_op.h
浏览文件 @
32022d4d
...
...
@@ -13,8 +13,8 @@ class MomentumModelUpdateOp final : public ModelUpdtOp {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/multinomial_logistic_loss_op.cpp
浏览文件 @
32022d4d
...
...
@@ -16,11 +16,11 @@ const PbMessage& MultinomialLogisticLossOp::GetSpecialConf() const {
return
op_conf
().
multinomial_logistic_loss_conf
();
}
void
MultinomialLogisticLossOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
MultinomialLogisticLossOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
*
GetShapePtr4BnInOp
(
"loss"
)
=
Shape
({
1
});
*
GetShapePtr4BnInOp
(
"loss_buffer"
)
=
Shape
({
1
});
GetBlobDesc4BnInOp
(
"loss"
)
->
mut_shape
(
)
=
Shape
({
1
});
GetBlobDesc4BnInOp
(
"loss_buffer"
)
->
mut_shape
(
)
=
Shape
({
1
});
}
REGISTER_OP
(
OperatorConf
::
kMultinomialLogisticLossConf
,
...
...
oneflow/core/operator/multinomial_logistic_loss_op.h
浏览文件 @
32022d4d
...
...
@@ -17,8 +17,8 @@ class MultinomialLogisticLossOp final : public UserOperator {
const
PbMessage
&
GetSpecialConf
()
const
override
;
bool
IsLossOp
()
const
override
{
return
true
;
}
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -21,7 +21,7 @@ TEST(MultinomialLogisticLossOp, test_loss_op) {
return
bn2shape_ptr
.
at
(
bn
);
};
loss_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
2
,
10
);
loss_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
2
,
10
);
Shape
*
loss_shape_ptr
=
bn2shape_ptr
.
at
(
loss_op
->
SoleObn
());
Shape
*
loss_buffer_shape_ptr
=
bn2shape_ptr
.
at
(
loss_op
->
SoleDtbn
());
...
...
oneflow/core/operator/operator.h
浏览文件 @
32022d4d
...
...
@@ -2,13 +2,13 @@
#define ONEFLOW_CORE_OPERATOR_OPERATOR_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/keyword.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/operator.pb.h"
#include "oneflow/core/register/blob_desc.h"
namespace
oneflow
{
...
...
@@ -79,8 +79,8 @@ class Operator {
// Read: shape of input_blobs
// Write: shape of output_blobs, model_blobs, data_tmp_blobs, model_tmp_blobs
virtual
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
virtual
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
=
0
;
...
...
@@ -144,8 +144,8 @@ class SysOperator : public Operator {
SysOperator
()
=
default
;
virtual
~
SysOperator
()
=
default
;
virtual
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
virtual
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
{
UNEXPECTED_RUN
();
...
...
oneflow/core/operator/pooling_op.cpp
浏览文件 @
32022d4d
...
...
@@ -15,32 +15,30 @@ const PbMessage& PoolingOp::GetSpecialConf() const {
return
op_conf
().
pooling_conf
();
}
void
PoolingOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
PoolingOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
input_shape_ptr
=
GetShapePtr4BnInOp
(
SoleIbn
()
);
CHECK_EQ
(
input_shape
_ptr
->
NumAxes
(),
4
);
Shape
*
output_shape_ptr
=
GetShapePtr
4BnInOp
(
SoleObn
());
const
Shape
&
input_shape
=
GetBlobDesc4BnInOp
(
SoleIbn
())
->
shape
(
);
CHECK_EQ
(
input_shape
.
NumAxes
(),
4
);
BlobDesc
*
output_blob_desc
=
GetBlobDesc
4BnInOp
(
SoleObn
());
const
PoolingOpConf
&
pooling_conf
=
op_conf
().
pooling_conf
();
std
::
vector
<
int64_t
>
output_shape_dim_vec
=
{
input_shape
_ptr
->
At
(
0
),
input_shape
_ptr
->
At
(
1
)};
std
::
vector
<
int64_t
>
output_shape_dim_vec
=
{
input_shape
.
At
(
0
),
input_shape
.
At
(
1
)};
output_shape_dim_vec
.
push_back
((
input_shape_ptr
->
At
(
2
)
+
2
*
pooling_conf
.
pad_h
()
output_shape_dim_vec
.
push_back
((
input_shape
.
At
(
2
)
+
2
*
pooling_conf
.
pad_h
()
-
pooling_conf
.
kernel_size_h
())
/
pooling_conf
.
stride_h
()
+
1
);
output_shape_dim_vec
.
push_back
((
input_shape_ptr
->
At
(
3
)
+
2
*
pooling_conf
.
pad_w
()
output_shape_dim_vec
.
push_back
((
input_shape
.
At
(
3
)
+
2
*
pooling_conf
.
pad_w
()
-
pooling_conf
.
kernel_size_w
())
/
pooling_conf
.
stride_w
()
+
1
);
*
output_shape_ptr
=
Shape
(
output_shape_dim_vec
);
Shape
*
data_tmp_shape_ptr
=
GetShapePtr
4BnInOp
(
SoleDtbn
());
*
data_tmp_shape_ptr
=
Shape
(
output_shape_dim_vec
);
output_blob_desc
->
mut_shape
()
=
Shape
(
output_shape_dim_vec
);
BlobDesc
*
data_tmp_blob_desc
=
GetBlobDesc
4BnInOp
(
SoleDtbn
());
data_tmp_blob_desc
->
mut_shape
()
=
Shape
(
output_shape_dim_vec
);
}
REGISTER_OP
(
OperatorConf
::
kPoolingConf
,
PoolingOp
);
...
...
oneflow/core/operator/pooling_op.h
浏览文件 @
32022d4d
...
...
@@ -16,8 +16,8 @@ class PoolingOp final : public UserOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/pooling_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -30,7 +30,7 @@ TEST(PoolingOp, pool_100x64x11x11) {
return
bn2shape_ptr
.
at
(
bn
);
};
// do infer shape
pooling_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
pooling_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
output_shape_ptr
=
bn2shape_ptr
.
at
(
pooling_op
->
SoleObn
());
Shape
*
data_tmp_shape_ptr
=
bn2shape_ptr
.
at
(
pooling_op
->
SoleDtbn
());
...
...
oneflow/core/operator/record_op.h
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ class RecordOp final : public SysOperator {
const
PbMessage
&
GetSpecialConf
()
const
override
;
bool
IsRecordOp
()
const
override
{
return
true
;
}
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
{}
...
...
oneflow/core/operator/relu_op.cpp
浏览文件 @
32022d4d
...
...
@@ -14,12 +14,10 @@ const PbMessage& ReluOp::GetSpecialConf() const {
return
op_conf
().
relu_conf
();
}
void
ReluOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
ReluOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
output_shape_ptr
=
GetShapePtr4BnInOp
(
SoleObn
());
Shape
*
input_shape_ptr
=
GetShapePtr4BnInOp
(
SoleIbn
());
*
output_shape_ptr
=
*
input_shape_ptr
;
*
GetBlobDesc4BnInOp
(
SoleObn
())
=
*
GetBlobDesc4BnInOp
(
SoleIbn
());
}
REGISTER_OP
(
OperatorConf
::
kReluConf
,
ReluOp
);
...
...
oneflow/core/operator/relu_op.h
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ class ReluOp final : public UserOperator {
const
PbMessage
&
GetSpecialConf
()
const
override
;
bool
IsElemWise
()
const
override
{
return
true
;
}
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/relu_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -17,7 +17,7 @@ TEST(ReluOp, relu_3x5x4) {
return
bn2shape_ptr
.
at
(
bn
);
};
// do infer shape
relu_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
relu_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
input_shape_ptr
=
bn2shape_ptr
.
at
(
relu_op
->
SoleIbn
());
Shape
*
output_shape_ptr
=
bn2shape_ptr
.
at
(
relu_op
->
SoleObn
());
...
...
oneflow/core/operator/rmsprop_model_update_op.cpp
浏览文件 @
32022d4d
...
...
@@ -15,11 +15,10 @@ const PbMessage& RMSPropModelUpdateOp::GetSpecialConf() const {
return
op_conf
().
rmsprop_mdupdt_conf
();
}
void
RMSPropModelUpdateOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
RMSPropModelUpdateOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
Shape
*
input_shape_ptr
=
GetShapePtr4BnInOp
(
"model_diffs"
);
*
GetShapePtr4BnInOp
(
"mean_square"
)
=
*
input_shape_ptr
;
TODO
();
}
REGISTER_OP
(
OperatorConf
::
kRmspropMdupdtConf
,
RMSPropModelUpdateOp
);
...
...
oneflow/core/operator/rmsprop_model_update_op.h
浏览文件 @
32022d4d
...
...
@@ -13,8 +13,8 @@ class RMSPropModelUpdateOp final : public ModelUpdtOp {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/softmax_loss_op.cpp
浏览文件 @
32022d4d
...
...
@@ -17,16 +17,16 @@ const PbMessage& SoftmaxLossOp::GetSpecialConf() const {
return
op_conf
().
softmax_loss_conf
();
}
void
SoftmaxLossOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
SoftmaxLossOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
const
std
::
vector
<
int64_t
>
in_dim_vec
=
Get
ShapePtr4BnInOp
(
"prediction"
)
->
dim_vec
();
Get
BlobDesc4BnInOp
(
"prediction"
)
->
shape
().
dim_vec
();
CHECK_EQ
(
in_dim_vec
.
size
(),
2
);
CHECK_EQ
(
*
GetShapePtr4BnInOp
(
"label"
),
Shape
({
in_dim_vec
[
0
]}));
*
GetShapePtr4BnInOp
(
SoleObn
()
)
=
Shape
({
1
});
*
GetShapePtr4BnInOp
(
"prob"
)
=
Shape
(
in_dim_vec
);
*
GetShapePtr4BnInOp
(
"tmp_1D"
)
=
Shape
({
in_dim_vec
[
0
]});
CHECK_EQ
(
GetBlobDesc4BnInOp
(
"label"
)
->
shape
(
),
Shape
({
in_dim_vec
[
0
]}));
GetBlobDesc4BnInOp
(
SoleObn
())
->
mut_shape
(
)
=
Shape
({
1
});
GetBlobDesc4BnInOp
(
"prob"
)
->
mut_shape
(
)
=
Shape
(
in_dim_vec
);
GetBlobDesc4BnInOp
(
"tmp_1D"
)
->
mut_shape
(
)
=
Shape
({
in_dim_vec
[
0
]});
}
REGISTER_OP
(
OperatorConf
::
kSoftmaxLossConf
,
SoftmaxLossOp
);
...
...
oneflow/core/operator/softmax_loss_op.h
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,8 @@ class SoftmaxLossOp final : public UserOperator {
const
PbMessage
&
GetSpecialConf
()
const
override
;
bool
IsLossOp
()
const
override
{
return
true
;
}
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/softmax_loss_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -20,7 +20,7 @@ TEST(SoftmaxLossOp, softmax_loss_3x5) {
return
bn2shape_ptr
.
at
(
bn
);
};
// infershape
softmax_loss_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
softmax_loss_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
ASSERT_EQ
(
*
fp
(
"loss"
),
Shape
({
1
}));
ASSERT_EQ
(
*
fp
(
"prob"
),
Shape
({
3
,
5
}));
...
...
oneflow/core/operator/softmax_op.cpp
浏览文件 @
32022d4d
...
...
@@ -15,13 +15,14 @@ const PbMessage& SoftmaxOp::GetSpecialConf() const {
return
op_conf
().
softmax_conf
();
}
void
SoftmaxOp
::
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
SoftmaxOp
::
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
{
std
::
vector
<
int64_t
>
vec
=
GetShapePtr4BnInOp
(
SoleIbn
())
->
dim_vec
();
const
std
::
vector
<
int64_t
>&
vec
=
GetBlobDesc4BnInOp
(
SoleIbn
())
->
shape
().
dim_vec
();
CHECK_EQ
(
vec
.
size
(),
2
);
*
GetShapePtr4BnInOp
(
SoleObn
()
)
=
Shape
(
vec
);
*
GetShapePtr4BnInOp
(
SoleDtbn
()
)
=
Shape
({
vec
[
0
]});
GetBlobDesc4BnInOp
(
SoleObn
())
->
mut_shape
(
)
=
Shape
(
vec
);
GetBlobDesc4BnInOp
(
SoleDtbn
())
->
mut_shape
(
)
=
Shape
({
vec
[
0
]});
}
REGISTER_OP
(
OperatorConf
::
kSoftmaxConf
,
SoftmaxOp
);
...
...
oneflow/core/operator/softmax_op.h
浏览文件 @
32022d4d
...
...
@@ -14,8 +14,8 @@ class SoftmaxOp final : public UserOperator {
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
const
PbMessage
&
GetSpecialConf
()
const
override
;
void
Infer
Shape
4FwBlobs
(
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetShapePtr
4BnInOp
,
void
Infer
BlobDesc
4FwBlobs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc
4BnInOp
,
ParallelPolicy
policy
,
int64_t
parallel_id
,
int64_t
parallel_num
)
const
override
;
...
...
oneflow/core/operator/softmax_op_test.cpp
浏览文件 @
32022d4d
...
...
@@ -17,7 +17,7 @@ TEST(SoftmaxOp, softmax_3x5) {
return
bn2shape_ptr
.
at
(
bn
);
};
// infershape
softmax_op
->
Infer
Shape
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
softmax_op
->
Infer
BlobDesc
4FwBlobs
(
fp
,
kDataParallel
,
0
,
1
);
// test
Shape
*
output_shape_ptr
=
fp
(
softmax_op
->
SoleObn
());
Shape
*
tmp_shape_ptr
=
fp
(
softmax_op
->
SoleDtbn
());
...
...
oneflow/core/register/blob.h
浏览文件 @
32022d4d
...
...
@@ -6,7 +6,7 @@
namespace
oneflow
{
class
Blob
{
class
Blob
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
Blob
);
Blob
(
void
*
dptr
,
const
Shape
*
shape
)
:
dptr_
(
dptr
),
shape_
(
shape
)
{}
...
...
oneflow/core/register/blob_desc.h
0 → 100644
浏览文件 @
32022d4d
#ifndef ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
#define ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/register/blob_desc.pb.h"
namespace
oneflow
{
class
BlobDesc
final
{
public:
// OF_DISALLOW_COPY_AND_MOVE(BlobDesc);
BlobDesc
()
=
default
;
~
BlobDesc
()
=
default
;
BlobDesc
(
const
BlobDescProto
&
proto
)
{
shape_
=
Shape
(
proto
.
shape
());
}
const
Shape
&
shape
()
const
{
return
shape_
;
}
Shape
&
mut_shape
()
{
return
shape_
;
}
void
ToProto
(
BlobDescProto
*
proto
)
const
{
shape_
.
ToProto
(
proto
->
mutable_shape
());
}
private:
Shape
shape_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
oneflow/core/register/blob_desc.proto
0 → 100644
浏览文件 @
32022d4d
syntax
=
"proto3"
;
package
oneflow
;
import
"oneflow/core/common/shape.proto"
;
message
BlobDescProto
{
ShapeProto
shape
=
1
;
}
oneflow/core/register/register_desc.cpp
浏览文件 @
32022d4d
...
...
@@ -40,57 +40,57 @@ void RegstDesc::AddConsumer(const TaskNode* new_consumer) {
}
void
RegstDesc
::
CopyLbnFrom
(
const
RegstDesc
*
rhs
)
{
lbn2shape_
.
clear
(
);
for
(
const
auto
&
pair
:
rhs
->
lbn2
shape
_
)
{
CHECK
(
lbn2blob_desc_
.
empty
()
);
for
(
const
auto
&
pair
:
rhs
->
lbn2
blob_desc
_
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
auto
shape
=
of_make_unique
<
Shape
>
();
CHECK
(
lbn2shape_
.
emplace
(
lbn
,
std
::
move
(
shape
)).
second
);
CHECK
(
lbn2blob_desc_
.
emplace
(
lbn
,
of_make_unique
<
BlobDesc
>
()).
second
);
}
}
void
RegstDesc
::
Copy
Shape
From
(
const
RegstDesc
*
rhs
)
{
for
(
const
auto
&
pair
:
lbn2
shape
_
)
{
void
RegstDesc
::
Copy
BlobDesc
From
(
const
RegstDesc
*
rhs
)
{
for
(
const
auto
&
pair
:
lbn2
blob_desc
_
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
*
(
lbn2
shape_
.
at
(
lbn
))
=
rhs
->
GetShape
(
lbn
);
*
(
lbn2
blob_desc_
.
at
(
lbn
))
=
rhs
->
GetBlobDesc
(
lbn
);
}
}
void
RegstDesc
::
EnrollLbn
(
const
std
::
string
&
lbn
)
{
std
::
unique_ptr
<
Shape
>
ptr
(
new
Shape
);
CHECK
(
lbn2shape_
.
emplace
(
lbn
,
std
::
move
(
ptr
)).
second
)
<<
lbn
;
CHECK
(
lbn2blob_desc_
.
emplace
(
lbn
,
of_make_unique
<
BlobDesc
>
()).
second
)
<<
lbn
;
}
const
Shape
&
RegstDesc
::
GetShape
(
const
std
::
string
&
lbn
)
const
{
return
*
(
lbn2
shape
_
.
at
(
lbn
));
const
BlobDesc
&
RegstDesc
::
GetBlobDesc
(
const
std
::
string
&
lbn
)
const
{
return
*
(
lbn2
blob_desc
_
.
at
(
lbn
));
}
Shape
*
RegstDesc
::
GetMutShapePtr
(
const
std
::
string
&
lbn
)
{
return
lbn2
shape
_
.
at
(
lbn
).
get
();
BlobDesc
*
RegstDesc
::
GetMutBlobDesc
(
const
std
::
string
&
lbn
)
{
return
lbn2
blob_desc
_
.
at
(
lbn
).
get
();
}
void
RegstDesc
::
ForEachLbn
(
std
::
function
<
void
(
const
std
::
string
&
)
>
func
)
const
{
for
(
const
auto
&
p
:
lbn2
shape
_
)
{
func
(
p
.
first
);
}
for
(
const
auto
&
p
:
lbn2
blob_desc
_
)
{
func
(
p
.
first
);
}
}
void
RegstDesc
::
EraseZeroSizeBlob
()
{
EraseIf
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>
(
&
lbn2
shape
_
,
[](
HashMap
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>::
iterator
it
)
{
return
it
->
second
->
elem_cnt
()
==
0
;
EraseIf
<
std
::
string
,
std
::
unique_ptr
<
BlobDesc
>>
(
&
lbn2
blob_desc
_
,
[](
HashMap
<
std
::
string
,
std
::
unique_ptr
<
BlobDesc
>>::
iterator
it
)
{
return
it
->
second
->
shape
().
elem_cnt
()
==
0
;
});
}
int64_t
RegstDesc
::
CompElemCntOfAllBlob
()
const
{
int64_t
sum
=
0
;
for
(
const
auto
&
pair
:
lbn2shape_
)
{
sum
+=
pair
.
second
->
elem_cnt
();
}
for
(
const
auto
&
pair
:
lbn2blob_desc_
)
{
sum
+=
pair
.
second
->
shape
().
elem_cnt
();
}
return
sum
;
}
std
::
string
RegstDesc
::
DebugStr
()
const
{
std
::
stringstream
ss
;
ss
<<
"{"
;
for
(
const
auto
&
pair
:
lbn2
shape
_
)
{
ss
<<
"{"
<<
pair
.
first
<<
":"
<<
pair
.
second
->
DebugStr
()
<<
"}"
;
for
(
const
auto
&
pair
:
lbn2
blob_desc
_
)
{
ss
<<
"{"
<<
pair
.
first
<<
":"
<<
pair
.
second
->
shape
().
DebugStr
()
<<
"}"
;
}
ss
<<
"}"
;
return
ss
.
str
();
...
...
@@ -104,10 +104,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
ret
->
add_consumer_task_id
(
consumer
->
task_id
());
}
}
for
(
const
auto
&
pair
:
lbn2
shape
_
)
{
PbMapPair
<
std
::
string
,
Shape
Proto
>
pb_pair
(
pair
.
first
);
for
(
const
auto
&
pair
:
lbn2
blob_desc
_
)
{
PbMapPair
<
std
::
string
,
BlobDesc
Proto
>
pb_pair
(
pair
.
first
);
pair
.
second
->
ToProto
(
&
(
pb_pair
.
second
));
ret
->
mutable_lbn2
shape
()
->
insert
(
pb_pair
);
ret
->
mutable_lbn2
blob_desc
()
->
insert
(
pb_pair
);
}
ret
->
set_register_num
(
register_num_
);
*
(
ret
->
mutable_mem_case
())
=
InferMemCase
();
...
...
@@ -140,4 +140,9 @@ MemoryCase RegstDesc::InferMemCase() const {
return
mem_case
;
}
BlobDesc
RegstDesc
::
CompPackedBlobDesc
()
const
{
BlobDesc
packed_blob_desc
;
packed_blob_desc
.
mut_shape
()
=
Shape
({
CompElemCntOfAllBlob
()});
}
}
// namespace oneflow
oneflow/core/register/register_desc.h
浏览文件 @
32022d4d
#ifndef ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
#define ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace
oneflow
{
// Regst : Register
// Contig : Contiguous
class
TaskNode
;
class
RegstDesc
final
{
...
...
@@ -27,28 +23,29 @@ class RegstDesc final {
const
HashSet
<
const
TaskNode
*>&
consumers
()
const
{
return
consumers_
;
}
void
AddConsumer
(
const
TaskNode
*
);
// Lbn and
Shape
// Lbn and
BlobDesc
void
CopyLbnFrom
(
const
RegstDesc
*
);
void
Copy
Shape
From
(
const
RegstDesc
*
);
void
Copy
BlobDesc
From
(
const
RegstDesc
*
);
void
EnrollLbn
(
const
std
::
string
&
lbn
);
const
Shape
&
GetShape
(
const
std
::
string
&
lbn
)
const
;
Shape
*
GetMutShapePtr
(
const
std
::
string
&
lbn
);
const
BlobDesc
&
GetBlobDesc
(
const
std
::
string
&
lbn
)
const
;
BlobDesc
*
GetMutBlobDesc
(
const
std
::
string
&
lbn
);
void
ForEachLbn
(
std
::
function
<
void
(
const
std
::
string
&
)
>
func
)
const
;
size_t
NumOfLbn
()
const
{
return
lbn2
shape
_
.
size
();
}
size_t
NumOfLbn
()
const
{
return
lbn2
blob_desc
_
.
size
();
}
//
void
EraseZeroSizeBlob
();
int64_t
CompElemCntOfAllBlob
()
const
;
std
::
string
DebugStr
()
const
;
void
ToProto
(
RegstDescProto
*
)
const
;
MemoryCase
InferMemCase
()
const
;
BlobDesc
CompPackedBlobDesc
()
const
;
private:
int64_t
CompElemCntOfAllBlob
()
const
;
int64_t
regst_desc_id_
;
const
TaskNode
*
producer_
;
HashSet
<
const
TaskNode
*>
consumers_
;
HashMap
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>
lbn2shape
_
;
HashMap
<
std
::
string
,
std
::
unique_ptr
<
BlobDesc
>>
lbn2blob_desc
_
;
int64_t
register_num_
;
};
...
...
oneflow/core/register/register_desc.proto
浏览文件 @
32022d4d
syntax
=
"proto3"
;
package
oneflow
;
import
"oneflow/core/
common/shape
.proto"
;
import
"oneflow/core/
register/blob_desc
.proto"
;
import
"oneflow/core/memory/memory_case.proto"
;
message
RegstDescProto
{
int64
regst_desc_id
=
1
;
int64
producer_task_id
=
2
;
repeated
int64
consumer_task_id
=
3
;
map
<
string
,
ShapeProto
>
lbn2shape
=
4
;
map
<
string
,
BlobDescProto
>
lbn2blob_desc
=
4
;
int64
register_num
=
5
;
MemoryCase
mem_case
=
6
;
}
oneflow/core/register/register_manager.cpp
浏览文件 @
32022d4d
...
...
@@ -19,12 +19,12 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
}
int64_t
elem_cnt
=
0
;
std
::
vector
<
std
::
string
>
lbns
;
lbns
.
reserve
(
regst_desc_proto
.
lbn2
shape
().
size
());
for
(
const
auto
&
pair
:
regst_desc_proto
.
lbn2
shape
())
{
const
Shape
*
shape_ptr
=
runtime_regst_desc
->
Get
ShapePtrFromLbn
(
pair
.
first
);
lbns
.
reserve
(
regst_desc_proto
.
lbn2
blob_desc
().
size
());
for
(
const
auto
&
pair
:
regst_desc_proto
.
lbn2
blob_desc
())
{
const
Shape
&
shape_ptr
=
runtime_regst_desc
->
Get
BlobDescFromLbn
(
pair
.
first
)
->
shape
(
);
lbns
.
push_back
(
pair
.
first
);
elem_cnt
+=
shape_ptr
->
elem_cnt
();
elem_cnt
+=
shape_ptr
.
elem_cnt
();
}
std
::
sort
(
lbns
.
begin
(),
lbns
.
end
());
std
::
pair
<
char
*
,
std
::
function
<
void
()
>>
allocation
=
...
...
@@ -33,7 +33,8 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
int64_t
blob_idx
=
0
;
for
(
const
std
::
string
&
lbn
:
lbns
)
{
const
Shape
*
shape_ptr
=
runtime_regst_desc
->
GetShapePtrFromLbn
(
lbn
);
const
Shape
*
shape_ptr
=
&
(
runtime_regst_desc
->
GetBlobDescFromLbn
(
lbn
)
->
shape
());
auto
blob_ptr
=
of_make_unique
<
Blob
>
(
allocation
.
first
+
blob_idx
,
shape_ptr
);
CHECK
(
regst
->
lbn2blob_
.
emplace
(
lbn
,
std
::
move
(
blob_ptr
)).
second
);
...
...
oneflow/core/register/runtime_register_desc.cpp
浏览文件 @
32022d4d
...
...
@@ -15,8 +15,9 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& regst_desc_proto) {
consumers_actor_id_
.
push_back
(
IDMgr
::
Singleton
()
->
ActorId4TaskId
(
task_id
));
}
for
(
const
auto
&
pair
:
regst_desc_proto
.
lbn2shape
())
{
CHECK
(
lbn2shape_
.
emplace
(
pair
.
first
,
of_make_unique
<
Shape
>
(
pair
.
second
))
for
(
const
auto
&
pair
:
regst_desc_proto
.
lbn2blob_desc
())
{
CHECK
(
lbn2blob_desc_
.
emplace
(
pair
.
first
,
of_make_unique
<
BlobDesc
>
(
pair
.
second
))
.
second
);
}
mem_case_
=
regst_desc_proto
.
mem_case
();
...
...
oneflow/core/register/runtime_register_desc.h
浏览文件 @
32022d4d
#ifndef ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/memory/memory_case.pb.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace
oneflow
{
...
...
@@ -23,15 +23,15 @@ class RtRegstDesc {
int64_t
register_num
()
const
{
return
register_num_
;
}
const
MemoryCase
&
mem_case
()
const
{
return
mem_case_
;
}
const
Shape
*
GetShapePtr
FromLbn
(
const
std
::
string
&
lbn
)
const
{
return
lbn2
shape
_
.
at
(
lbn
).
get
();
const
BlobDesc
*
GetBlobDesc
FromLbn
(
const
std
::
string
&
lbn
)
const
{
return
lbn2
blob_desc
_
.
at
(
lbn
).
get
();
}
private:
int64_t
regst_desc_id_
;
int64_t
producer_actor_id_
;
std
::
vector
<
int64_t
>
consumers_actor_id_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>
lbn2shape
_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
BlobDesc
>>
lbn2blob_desc
_
;
int64_t
register_num_
;
MemoryCase
mem_case_
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录