Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c53abcdf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c53abcdf
编写于
1月 14, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(mge): minor improvements related to grad
GitOrigin-RevId: 102467d79d148b52f4dfefadeb3e6a7d7a0d2ad6
上级
0a3ca253
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
86 addition
and
11 deletion
+86
-11
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+1
-11
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+12
-0
imperative/src/impl/backward_graph_opt.cpp
imperative/src/impl/backward_graph_opt.cpp
+4
-0
imperative/src/impl/ops/backward_graph.cpp
imperative/src/impl/ops/backward_graph.cpp
+65
-0
imperative/src/include/megbrain/imperative/ops/backward_graph.h
...tive/src/include/megbrain/imperative/ops/backward_graph.h
+4
-0
未找到文件。
imperative/python/src/grad.cpp
浏览文件 @
c53abcdf
...
@@ -155,17 +155,7 @@ struct BackwardGraphWithClosure {
...
@@ -155,17 +155,7 @@ struct BackwardGraphWithClosure {
}
}
if
(
null_grad
)
return
;
if
(
null_grad
)
return
;
ApplyContext
ctx
;
auto
igrads
=
apply
(
backward_graph
->
backward
,
args
,
nargs
);
ctx
.
op
=
backward_graph
->
backward
;
ctx
.
flags
=
is_tracing
?
Flags
::
TRACE
:
0
;
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
ctx
.
flags
|=
args
[
i
]
->
m_flags
;
mgb_assert
(
args
[
i
]);
}
auto
igrads
=
apply
(
ctx
);
auto
&&
it
=
igrads
.
begin
();
auto
&&
it
=
igrads
.
begin
();
for
(
auto
[
i
,
p
]
:
views
::
enumerate
(
backward_graph
->
input_has_grad
))
{
for
(
auto
[
i
,
p
]
:
views
::
enumerate
(
backward_graph
->
input_has_grad
))
{
if
(
p
)
{
if
(
p
)
{
...
...
imperative/python/src/tensor.h
浏览文件 @
c53abcdf
...
@@ -252,6 +252,18 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
...
@@ -252,6 +252,18 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
return
apply
(
ctx
);
return
apply
(
ctx
);
}
}
inline
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
ApplyContext
ctx
;
ctx
.
op
=
std
::
move
(
op
);
ctx
.
flags
=
is_tracing
?
Tensor
::
Flags
::
TRACE
:
0
;
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
ctx
.
flags
|=
args
[
i
]
->
m_flags
;
}
return
apply
(
ctx
);
}
void
init_tensor
(
pybind11
::
module
);
void
init_tensor
(
pybind11
::
module
);
extern
PyObject
*
cpp_apply_with_tracing
,
*
cpp_apply_compiled_mode
;
extern
PyObject
*
cpp_apply_with_tracing
,
*
cpp_apply_compiled_mode
;
...
...
imperative/src/impl/backward_graph_opt.cpp
浏览文件 @
c53abcdf
...
@@ -111,4 +111,8 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
...
@@ -111,4 +111,8 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
}
}
}
}
}
}
if
(
!
fgraph
.
outputs
.
size
())
{
precomp
.
reset
();
}
}
}
imperative/src/impl/ops/backward_graph.cpp
浏览文件 @
c53abcdf
...
@@ -9,7 +9,11 @@
...
@@ -9,7 +9,11 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include <sstream>
#include <range/v3/all.hpp>
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "../op_trait.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
mgb
{
...
@@ -66,6 +70,67 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
...
@@ -66,6 +70,67 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
return
{
ret
,
validated
};
return
{
ret
,
validated
};
}
}
std
::
string
BackwardGraph
::
InternalGraph
::
repr
()
{
std
::
ostringstream
buf
;
buf
<<
"("
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
inputs
[
i
];
}
buf
<<
") => {
\n
"
;
auto
fmt_const
=
[](
size_t
i
,
TensorPtr
&
t
)
{
if
(
t
->
shape
().
ndim
==
1
&&
t
->
shape
()[
0
]
==
1
)
{
auto
&&
v
=
t
->
get_value
();
if
(
v
.
dtype
()
==
dtype
::
Float32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
dt_float32
>
());
}
else
if
(
v
.
dtype
()
==
dtype
::
Int32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
int32_t
>
());
}
}
return
std
::
string
(
"%c"
)
+
std
::
to_string
(
i
);
};
std
::
unordered_map
<
size_t
,
std
::
string
>
const_reps
;
for
(
auto
&&
[
i
,
t
]
:
constants
)
{
const_reps
.
emplace
(
i
,
fmt_const
(
i
,
t
));
}
for
(
auto
&
[
op
,
ins
,
outs
]
:
exprs
)
{
buf
<<
" "
;
if
(
outs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outs
[
i
];
}
buf
<<
" = "
;
}
if
(
auto
*
p
=
op
->
try_cast_final
<
OprAttr
>
())
{
buf
<<
p
->
type
;
}
else
{
buf
<<
op
->
dyn_typeinfo
()
->
name
;
}
for
(
size_t
i
:
ins
)
{
buf
<<
" "
;
auto
&&
it
=
const_reps
.
find
(
i
);
if
(
it
!=
const_reps
.
end
())
{
buf
<<
it
->
second
;
}
else
{
buf
<<
"%"
<<
i
;
}
}
buf
<<
"
\n
"
;
}
buf
<<
" "
;
if
(
outputs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outputs
[
i
];
}
}
else
{
buf
<<
"()"
;
}
buf
<<
"
\n
}
\n
"
;
return
buf
.
str
();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardGraph
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardGraph
);
namespace
{
namespace
{
...
...
imperative/src/include/megbrain/imperative/ops/backward_graph.h
浏览文件 @
c53abcdf
...
@@ -71,6 +71,8 @@ public:
...
@@ -71,6 +71,8 @@ public:
}
}
return
ret
;
return
ret
;
}
}
std
::
string
repr
();
};
};
const
InternalGraph
&
graph
()
const
{
const
InternalGraph
&
graph
()
const
{
...
@@ -93,6 +95,8 @@ public:
...
@@ -93,6 +95,8 @@ public:
return
false
;
return
false
;
}
}
std
::
string
repr
()
{
return
m_graph
.
repr
();}
private:
private:
InternalGraph
m_graph
;
InternalGraph
m_graph
;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录