Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
4ecca347
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
4ecca347
编写于
4月 18, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement register_desc
上级
42402a1b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
37 addition
and
16 deletion
+37
-16
oneflow/graph/model_update_task_graph.cpp
oneflow/graph/model_update_task_graph.cpp
+1
-1
oneflow/graph/register_desc.cpp
oneflow/graph/register_desc.cpp
+25
-1
oneflow/graph/register_desc.h
oneflow/graph/register_desc.h
+11
-14
未找到文件。
oneflow/graph/model_update_task_graph.cpp
浏览文件 @
4ecca347
...
@@ -72,7 +72,7 @@ void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask(
...
@@ -72,7 +72,7 @@ void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask(
RegstDesc
*
model_diff_regst
=
bp_task
->
GetProducedRegstDesc
(
"model_diff"
);
RegstDesc
*
model_diff_regst
=
bp_task
->
GetProducedRegstDesc
(
"model_diff"
);
RegstDesc
*
model_regst
=
update_task
->
GetProducedRegstDesc
(
"model"
);
RegstDesc
*
model_regst
=
update_task
->
GetProducedRegstDesc
(
"model"
);
// complete update task
// complete update task
model_regst
->
CopyLbn
AndShape
(
model_diff_regst
);
model_regst
->
CopyLbn
2ShapeMap
(
model_diff_regst
);
ExecNode
*
update_exec
=
update_task
->
exec_gph
().
SoleNode
();
ExecNode
*
update_exec
=
update_task
->
exec_gph
().
SoleNode
();
const
std
::
string
&
ibn
=
update_exec
->
op
()
->
SoleIbn
();
const
std
::
string
&
ibn
=
update_exec
->
op
()
->
SoleIbn
();
if
(
update_task
->
in_edges
().
empty
())
{
if
(
update_task
->
in_edges
().
empty
())
{
...
...
oneflow/graph/register_desc.cpp
浏览文件 @
4ecca347
...
@@ -8,6 +8,30 @@ RegstDesc::RegstDesc() {
...
@@ -8,6 +8,30 @@ RegstDesc::RegstDesc() {
producer_
=
nullptr
;
producer_
=
nullptr
;
}
}
const
char
*
ContigRegstDesc
::
kAllLbn
=
"OfReservedAllLbn"
;
void
RegstDesc
::
CopyLbn2ShapeMap
(
const
RegstDesc
*
rhs
)
{
for
(
const
auto
&
pair
:
rhs
->
lbn2shape_
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
std
::
unique_ptr
<
Shape
>
shape
(
new
Shape
);
*
shape
=
*
(
pair
.
second
);
CHECK
(
lbn2shape_
.
insert
(
std
::
make_pair
(
lbn
,
std
::
move
(
shape
))).
second
);
}
}
Shape
*
RegstDesc
::
EnrollLbn
(
const
std
::
string
&
lbn
)
{
Shape
*
raw_ptr
=
new
Shape
;
std
::
unique_ptr
<
Shape
>
uptr
(
raw_ptr
);
CHECK
(
lbn2shape_
.
insert
(
std
::
make_pair
(
lbn
,
std
::
move
(
uptr
))).
second
);
return
raw_ptr
;
}
const
Shape
&
RegstDesc
::
GetShape
(
const
std
::
string
&
lbn
)
{
return
*
(
lbn2shape_
.
at
(
lbn
));
}
Shape
*
RegstDesc
::
GetMutShapePtr
(
const
std
::
string
&
lbn
)
{
return
lbn2shape_
.
at
(
lbn
).
get
();
}
const
char
*
RegstDesc
::
kAllLbn
=
"OfReservedAllLbn"
;
}
// namespace oneflow
}
// namespace oneflow
oneflow/graph/register_desc.h
浏览文件 @
4ecca347
...
@@ -6,9 +6,10 @@
...
@@ -6,9 +6,10 @@
namespace
oneflow
{
namespace
oneflow
{
class
TaskNode
;
// Regst : Register
// Contig : Contiguous
// Regst : Register
class
TaskNode
;
class
RegstDesc
{
class
RegstDesc
{
public:
public:
...
@@ -16,32 +17,28 @@ class RegstDesc {
...
@@ -16,32 +17,28 @@ class RegstDesc {
RegstDesc
();
RegstDesc
();
virtual
~
RegstDesc
()
=
default
;
virtual
~
RegstDesc
()
=
default
;
//
//
Producer
const
TaskNode
*
GetProducer
()
const
{
return
producer_
;
}
const
TaskNode
*
GetProducer
()
const
{
return
producer_
;
}
void
SetProducer
(
const
TaskNode
*
task_node
)
{
producer_
=
task_node
;
}
void
SetProducer
(
const
TaskNode
*
task_node
)
{
producer_
=
task_node
;
}
void
AddSubscriber
(
const
TaskNode
*
task_node
)
{
CHECK
(
subscribers_
.
insert
(
task_node
).
second
);
}
void
CopyLbnAndShape
(
const
RegstDesc
*
)
{
TODO
();
}
Shape
*
EnrollLbn
(
const
std
::
string
&
lbn
)
{
TODO
();
}
// Lbn and Shape
const
Shape
&
GetShape
(
const
std
::
string
&
lbn
)
{
TODO
();
}
void
CopyLbn2ShapeMap
(
const
RegstDesc
*
);
Shape
*
GetMutShapePtr
(
const
std
::
string
&
lbn
)
{
TODO
();
}
Shape
*
EnrollLbn
(
const
std
::
string
&
lbn
);
const
Shape
&
GetShape
(
const
std
::
string
&
lbn
);
Shape
*
GetMutShapePtr
(
const
std
::
string
&
lbn
);
static
const
char
*
kAllLbn
;
private:
private:
int32_t
regst_desc_id_
;
int32_t
regst_desc_id_
;
const
TaskNode
*
producer_
;
const
TaskNode
*
producer_
;
std
::
unordered_set
<
const
TaskNode
*>
subscribers_
;
HashMap
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>
lbn2shape_
;
HashMap
<
std
::
string
,
std
::
unique_ptr
<
Shape
>>
lbn2shape_
;
};
};
// Contiguous
class
ContigRegstDesc
final
:
public
RegstDesc
{
class
ContigRegstDesc
final
:
public
RegstDesc
{
public:
public:
static
const
char
*
kAllLbn
;
OF_DISALLOW_COPY_AND_MOVE
(
ContigRegstDesc
);
OF_DISALLOW_COPY_AND_MOVE
(
ContigRegstDesc
);
ContigRegstDesc
()
=
default
;
ContigRegstDesc
()
=
default
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录