Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
13048516
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
13048516
编写于
4月 14, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
special ibn2lbn
上级
49523253
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
48 addition
and
68 deletion
+48
-68
oneflow/graph/comp_task_node.cc
oneflow/graph/comp_task_node.cc
+16
-48
oneflow/graph/comp_task_node.h
oneflow/graph/comp_task_node.h
+0
-2
oneflow/graph/task_node.cpp
oneflow/graph/task_node.cpp
+1
-2
oneflow/operator/clone_op.h
oneflow/operator/clone_op.h
+1
-1
oneflow/operator/concat_op.h
oneflow/operator/concat_op.h
+1
-1
oneflow/operator/convolution_op.h
oneflow/operator/convolution_op.h
+1
-1
oneflow/operator/copy_op.h
oneflow/operator/copy_op.h
+1
-1
oneflow/operator/innerproduct_op.h
oneflow/operator/innerproduct_op.h
+1
-1
oneflow/operator/multinomial_logistic_loss_op.h
oneflow/operator/multinomial_logistic_loss_op.h
+1
-1
oneflow/operator/operator.cpp
oneflow/operator/operator.cpp
+9
-1
oneflow/operator/operator.h
oneflow/operator/operator.h
+12
-5
oneflow/operator/pooling_op.h
oneflow/operator/pooling_op.h
+1
-1
oneflow/operator/relu_op.h
oneflow/operator/relu_op.h
+1
-1
oneflow/operator/softmax_op.h
oneflow/operator/softmax_op.h
+1
-1
oneflow/operator/split_op.h
oneflow/operator/split_op.h
+1
-1
未找到文件。
oneflow/graph/comp_task_node.cc
浏览文件 @
13048516
...
...
@@ -24,15 +24,16 @@ void CompTaskNode::DataFwBuildExecAndProducedRegsts(Path* path) {
if
(
GetBpNode
()
!=
nullptr
)
{
FwAddCopyInOp
(
&
extern_in_lbn2consumers
);
}
FwAddCloneOp
();
mut_exec_gph
().
UpdateSourceAndSink
();
// data regst
std
::
unique_ptr
<
RegstDesc
>
data_regst
(
new
DisContigRegstDesc
);
BindProducedRegstAndOutEdge
(
data_regst
.
get
(),
SoleOutEdge
());
EnrollProducedRegstDesc
(
"data"
,
std
::
move
(
data_regst
));
FwSetDataRegstDesc
(
lbn2producer
,
extern_in_lbn2consumers
);
// model_tmp regst
std
::
unique_ptr
<
RegstDesc
>
model_tmp_regst
(
new
DisContigRegstDesc
);
EnrollProducedRegstDesc
(
"model_tmp"
,
std
::
move
(
model_tmp_regst
));
FwSetModelTmpRegstDesc
();
FwSetDataRegstDesc
(
lbn2producer
,
extern_in_lbn2consumers
);
}
void
CompTaskNode
::
ModelUpdateFwBuildExecAndProducedRegsts
(
Path
*
path
)
{
...
...
@@ -125,52 +126,6 @@ void CompTaskNode::FwAddCopyInOp(Lbn2NodeIbnVecMap* extern_in_lbn2consumers) {
}
}
void
CompTaskNode
::
FwAddCloneOp
()
{
std
::
vector
<
CloneInfo
>
clone_info_vec
;
CollectCloneInfoVec
(
&
clone_info_vec
);
for
(
const
CloneInfo
&
clone_info
:
clone_info_vec
)
{
AddOneCloneNode
(
clone_info
);
}
}
void
CompTaskNode
::
FwCollectCloneInfoVec
(
std
::
vector
<
CloneInfo
>*
clone_info_vec
)
{
}
void
CompTaskNode
::
FwAddOneCloneNode
(
const
CloneInfo
&
clone_info
)
{
ExecNode
*
clone_node
=
mut_exec_gph
().
NewFinalNode
();
clone_node
->
mut_op
()
=
clone_info
.
clone_op
;
// InEdge
ExecEdge
*
in_edge
=
mut_exec_gph
().
NewExecEdge
();
in_edge
->
set_lbn
(
clone_info
.
lbn
);
in_edge
->
mut_dst_bn
()
=
clone_node
->
op
()
->
SoleIbn
();
in_edge
->
mut_src_bn
()
=
clone_info
.
edges
().
front
()
->
obn
();
Connect
(
clone_info
.
pred_node
,
in_edge
,
clone_node
);
// OutEdge
CHECK_EQ
(
clone_node
->
op
()
->
output_bns
().
size
(),
clone_info
.
edges
.
size
());
for
(
size_t
i
=
0
;
i
<
clone_info
.
edges
.
size
();
++
i
)
{
const
std
::
string
&
obn
=
clone_node
->
op
()
->
output_bns
().
at
(
i
);
ExecEdge
*
out_edge
=
clone_info
.
edges
.
at
(
i
);
ExecNode
*
dst_node
=
out_edge
->
dst_node
();
DisConnect
(
out_edge
);
out_edge
->
mut_src_bn
()
=
obn
;
Connect
(
clone_node
,
out_edge
,
dst_node
);
}
}
void
CompTaskNode
::
FwSetModelTmpRegstDesc
()
{
RegstDesc
*
model_tmp_regst
=
GetProducedRegstDesc
(
"model_tmp"
);
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
Shape
*
ptr
=
model_tmp_regst
->
EnrollWithLbn
(
lbn
);
node
->
op
()
->
SetShapePtr
(
mtbn
,
ptr
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
}
node
->
op
()
->
InferShape4Mtb
();
}
}
void
CompTaskNode
::
FwSetDataRegstDesc
(
const
Lbn2NodeObnMap
&
lbn2producer
,
const
Lbn2NodeIbnVecMap
&
extern_in_lbn2consumers
)
{
...
...
@@ -222,6 +177,19 @@ void CompTaskNode::FwSetDataRegstDesc(
}
}
void
CompTaskNode
::
FwSetModelTmpRegstDesc
()
{
RegstDesc
*
model_tmp_regst
=
GetProducedRegstDesc
(
"model_tmp"
);
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
Shape
*
ptr
=
model_tmp_regst
->
EnrollWithLbn
(
lbn
);
node
->
op
()
->
SetShapePtr
(
mtbn
,
ptr
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
}
node
->
op
()
->
InferShape4Mtb
();
}
}
void
CompTaskNode
::
BpBuildExecAndProducedRegsts
(
Path
*
path
)
{
const
ExecGraph
&
fw_gph
=
GetFwNode
()
->
exec_gph
();
const
ExecNode
*
cp_in_node
=
fw_gph
.
source_node
().
SoleOutEdge
()
->
dst_node
();
...
...
oneflow/graph/comp_task_node.h
浏览文件 @
13048516
...
...
@@ -43,8 +43,6 @@ class CompTaskNode : public TaskNode {
Lbn2NodeObnMap
*
lbn2producer
,
Lbn2NodeIbnVecMap
*
extern_in_lbn2consumers
);
void
FwAddCopyInOp
(
Lbn2NodeIbnVecMap
*
extern_in_lbn2consumers
);
void
FwAddCloneOp
();
void
FwBindOutEdgeAndRegst
();
void
FwSetProducedRegstDescs
(
const
Lbn2NodeObnMap
&
lbn2producer
,
const
Lbn2NodeIbnVecMap
&
extern_in_lbn2consumers
);
...
...
oneflow/graph/task_node.cpp
浏览文件 @
13048516
...
...
@@ -88,8 +88,7 @@ void TaskNode::EnrollProducedRegstDesc(
void
TaskNode
::
SubscribeRegstDescInnerPath
()
{
for
(
const
TaskEdge
*
edge
:
in_edges
())
{
RegstDesc
*
regst
=
GetRelatedRegst
(
edge
);
Subscribe
(
regst
);
Subscribe
(
GetRelatedRegst
(
edge
));
}
}
...
...
oneflow/operator/clone_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class CloneOp final : public SysOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
std
::
string
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
return
GetValueFromPbOpConf
(
"lbn"
);
}
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
...
...
oneflow/operator/concat_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class ConcatOp final : public SysOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
std
::
string
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
return
GetValueFromPbOpConf
(
"lbn"
);
}
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
...
...
oneflow/operator/convolution_op.h
浏览文件 @
13048516
...
...
@@ -13,7 +13,7 @@ class ConvolutionOp final : public UserOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/copy_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class CopyOp final : public SysOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
std
::
string
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
return
ibn2lbn_
.
at
(
input_bn
);
}
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
...
...
oneflow/operator/innerproduct_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class InnerProductOp final : public UserOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/multinomial_logistic_loss_op.h
浏览文件 @
13048516
...
...
@@ -17,7 +17,7 @@ class MultinomialLogisticLossOp : public UserOperator {
bool
IsLossOp
()
const
override
{
return
true
;
}
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/operator.cpp
浏览文件 @
13048516
...
...
@@ -27,6 +27,14 @@ std::string Operator::odbn2lbn(const std::string& output_diff_bn) const {
std
::
string
Operator
::
mdbn2lbn
(
const
std
::
string
&
model_diff_bn
)
const
{
return
mbn2lbn
(
GenUnDiffBn
(
model_diff_bn
));
}
std
::
string
Operator
::
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
{
auto
it
=
special_ibn2lbn_
.
find
(
input_bn
);
if
(
it
==
special_ibn2lbn_
.
end
())
{
return
normal_ibn2lbn
(
input_bn
);
}
else
{
return
it
->
second
;
}
}
std
::
string
Operator
::
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
{
return
GetValueFromPbMessage
(
*
pb_op_conf_
,
k
);
...
...
@@ -79,7 +87,7 @@ void Operator::EnrollBn(std::vector<std::string>* bn_vec,
CHECK
(
bn_in_op2shape_ptr_
.
emplace
(
bn
,
nullptr
).
second
);
}
std
::
string
UserOperator
::
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
{
std
::
string
UserOperator
::
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
{
return
GetValueFromPbOpConf
(
input_bn
);
}
std
::
string
UserOperator
::
obn2lbn
(
const
std
::
string
&
output_bn
)
const
{
...
...
oneflow/operator/operator.h
浏览文件 @
13048516
...
...
@@ -28,11 +28,15 @@ class Operator {
std
::
string
idbn2lbn
(
const
std
::
string
&
input_diff_bn
)
const
;
std
::
string
odbn2lbn
(
const
std
::
string
&
output_diff_bn
)
const
;
std
::
string
mdbn2lbn
(
const
std
::
string
&
model_diff_bn
)
const
;
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
;
virtual
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
=
0
;
virtual
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
=
0
;
virtual
std
::
string
mtbn2lbn
(
const
std
::
string
&
model_tmp_bn
)
const
=
0
;
virtual
std
::
string
mbn2lbn
(
const
std
::
string
&
model_bn
)
const
=
0
;
void
AddSpecialIbn2Lbn
(
const
std
::
string
&
ibn
,
const
std
::
string
&
lbn
)
{
CHECK
(
special_ibn2lbn_
.
emplace
(
ibn
,
lbn
).
second
);
}
// Getters
const
std
::
string
&
op_name
()
const
{
return
op_name_
;
}
...
...
@@ -58,11 +62,12 @@ class Operator {
void
SetShapePtr
(
const
std
::
string
&
bn_in_op
,
Shape
*
ptr
)
const
;
void
SetNull4AllShapePtr
()
const
;
virtual
void
InferShape4ObAndDtbFromIb
()
const
=
0
;
virtual
void
InferShape4M
bAndM
tb
()
const
=
0
;
virtual
void
InferShape4Mtb
()
const
=
0
;
protected:
std
::
string
&
mut_op_name
()
{
return
op_name_
;
}
std
::
unique_ptr
<
PbMessage
>&
mut_pb_op_conf
()
{
return
pb_op_conf_
;
}
virtual
std
::
string
normal_ibn2lbn
(
const
std
::
string
&
input_bn
)
const
=
0
;
// enroll data blobs
void
EnrollDataTmpBn
(
const
std
::
string
&
dtbn
);
...
...
@@ -82,6 +87,8 @@ class Operator {
std
::
string
op_name_
;
std
::
unique_ptr
<
PbMessage
>
pb_op_conf_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
special_ibn2lbn_
;
// blob name in op
std
::
vector
<
std
::
string
>
data_tmp_bns_
;
std
::
vector
<
std
::
string
>
input_bns_
;
...
...
@@ -102,7 +109,7 @@ class UserOperator : public Operator {
UserOperator
()
=
default
;
virtual
~
UserOperator
()
=
default
;
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
;
std
::
string
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
;
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
;
std
::
string
mtbn2lbn
(
const
std
::
string
&
model_tmp_bn
)
const
override
;
std
::
string
mbn2lbn
(
const
std
::
string
&
model_bn
)
const
override
;
...
...
@@ -120,14 +127,14 @@ class SysOperator : public Operator {
UNEXPECTED_RUN(); \
}
SET_UNEXPECTED
(
ibn2lbn
);
SET_UNEXPECTED
(
normal_
ibn2lbn
);
SET_UNEXPECTED
(
obn2lbn
);
SET_UNEXPECTED
(
mtbn2lbn
);
SET_UNEXPECTED
(
mbn2lbn
);
#undef SET_UNEXPECTED
void
InferShape4M
bAndM
tb
()
const
override
{
UNEXPECTED_RUN
();
}
void
InferShape4Mtb
()
const
override
{
UNEXPECTED_RUN
();
}
private:
};
...
...
oneflow/operator/pooling_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class PoolingOp final : public UserOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/relu_op.h
浏览文件 @
13048516
...
...
@@ -15,7 +15,7 @@ class ReluOp final : public UserOperator {
bool
IsElemWise
()
const
override
{
return
true
;
}
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/softmax_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class SoftmaxOp : public UserOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4M
bAndM
tb
()
const
override
{
TODO
();
}
void
InferShape4Mtb
()
const
override
{
TODO
();
}
private:
...
...
oneflow/operator/split_op.h
浏览文件 @
13048516
...
...
@@ -14,7 +14,7 @@ class SplitOp final : public SysOperator {
void
Init
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
std
::
string
normal_
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
return
GetValueFromPbOpConf
(
"lbn"
);
}
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录