Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b2806135
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b2806135
编写于
10月 03, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change Interface to unique_ptr
上级
495a80a7
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
24 addition
and
16 deletion
+24
-16
doc/design/register_grad_op.md
doc/design/register_grad_op.md
+3
-3
paddle/framework/grad_op_desc_maker.h
paddle/framework/grad_op_desc_maker.h
+16
-12
paddle/framework/op_info.h
paddle/framework/op_info.h
+1
-1
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+4
-0
未找到文件。
doc/design/register_grad_op.md
浏览文件 @
b2806135
...
...
@@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
```
cpp
struct
OpInfo
{
std
::
function
<
std
::
vector
<
OpDescBind
>
(
const
OpDescBind
&
)
>
grad_op_maker_
;
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>
>
(
const
OpDescBind
&
)
>
grad_op_maker_
;
...
};
```
...
...
@@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
class
GradOpDescMakerBase
{
public:
GradOpDescMakerBase
(
const
OpDescBind
&
);
virtual
std
::
vector
<
OpDescBind
>
operator
()()
const
=
0
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>
>
operator
()()
const
=
0
;
};
```
We can convert
`GradOpDescMakerBase`
to
`std::function<std::vector<
OpDescBind
>(const OpDescBind&)>`
by
We can convert
`GradOpDescMakerBase`
to
`std::function<std::vector<
std::unique_ptr<OpDescBind>
>(const OpDescBind&)>`
by
```
cpp
using
GradOpMaker
=
...;
...
...
paddle/framework/grad_op_desc_maker.h
浏览文件 @
b2806135
...
...
@@ -24,7 +24,7 @@ class GradOpDescMakerBase {
explicit
GradOpDescMakerBase
(
const
OpDescBind
&
fwd_op
)
:
fwd_op_
(
fwd_op
)
{}
virtual
~
GradOpDescMakerBase
()
=
default
;
virtual
std
::
vector
<
OpDescBind
>
operator
()()
const
=
0
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>
>
operator
()()
const
=
0
;
protected:
static
std
::
vector
<
std
::
string
>
ToGradNames
(
...
...
@@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
OpDescBind
>
operator
()()
const
{
return
{
this
->
Apply
()};
}
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
operator
()()
const
{
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
retv
;
retv
.
emplace_back
(
this
->
Apply
());
return
retv
;
}
protected:
virtual
OpDescBind
Apply
()
const
=
0
;
virtual
std
::
unique_ptr
<
OpDescBind
>
Apply
()
const
=
0
;
};
class
DefaultGradOpDescMaker
:
public
SingleGradOpDescMaker
{
...
...
@@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
virtual
OpDescBind
Apply
()
const
{
OpDescBind
grad
;
grad
.
SetType
(
this
->
GradOpType
());
virtual
std
::
unique_ptr
<
OpDescBind
>
Apply
()
const
{
auto
*
grad
=
new
OpDescBind
()
;
grad
->
SetType
(
this
->
GradOpType
());
for
(
auto
&
input_param
:
this
->
InputNames
())
{
grad
.
SetInput
(
input_param
,
this
->
Input
(
input_param
));
grad
.
SetOutput
(
GradVarName
(
input_param
),
this
->
InputGrad
(
input_param
));
grad
->
SetInput
(
input_param
,
this
->
Input
(
input_param
));
grad
->
SetOutput
(
GradVarName
(
input_param
),
this
->
InputGrad
(
input_param
));
}
for
(
auto
&
output_param
:
this
->
OutputNames
())
{
grad
.
SetInput
(
output_param
,
this
->
Output
(
output_param
));
grad
.
SetInput
(
GradVarName
(
output_param
),
this
->
OutputGrad
(
output_param
));
grad
->
SetInput
(
output_param
,
this
->
Output
(
output_param
));
grad
->
SetInput
(
GradVarName
(
output_param
),
this
->
OutputGrad
(
output_param
));
}
grad
.
SetAttrMap
(
this
->
Attrs
());
grad
->
SetAttrMap
(
this
->
Attrs
());
return
grad
;
return
std
::
unique_ptr
<
OpDescBind
>
(
grad
)
;
}
virtual
std
::
string
GradOpType
()
const
{
...
...
paddle/framework/op_info.h
浏览文件 @
b2806135
...
...
@@ -28,7 +28,7 @@ namespace framework {
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
std
::
function
<
std
::
vector
<
OpDescBind
>
(
const
OpDescBind
&
)
>
grad_op_maker_
;
GradOpMakerFN
grad_op_maker_
;
OpProto
*
proto_
{
nullptr
};
OpAttrChecker
*
checker_
{
nullptr
};
...
...
paddle/framework/type_defs.h
浏览文件 @
b2806135
...
...
@@ -20,6 +20,7 @@
namespace
paddle
{
namespace
framework
{
class
OperatorBase
;
class
OpDescBind
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
// The order should be as same as framework.proto
...
...
@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
using
GradOpMakerFN
=
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
(
const
OpDescBind
&
)
>
;
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录