Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
麻辣皮子熊
Paddle
提交
5d33ef61
P
Paddle
项目概览
麻辣皮子熊
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5d33ef61
编写于
8月 14, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change op_register and grad_op_builder
上级
b2e3824e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
32 addition
and
18 deletion
+32
-18
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+22
-16
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+10
-2
未找到文件。
paddle/framework/grad_op_builder.cc
浏览文件 @
5d33ef61
...
@@ -13,22 +13,22 @@ express or implied. See the License for the specific language governing
...
@@ -13,22 +13,22 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
enum
class
OpArgType
{
IN
,
OUT
};
enum
class
OpArgType
{
IN
,
OUT
};
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
OperatorBase
*
dst_op
,
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
const
OpArgType
&
src_type
,
const
OpArgType
&
dst_type
,
bool
is_grad
)
{
static
VarNameMap
TransOpArg
(
const
OperatorBase
*
src_op
,
const
OpArgType
&
src_type
,
const
OpArgType
&
dst_type
,
bool
is_grad
)
{
const
auto
&
src_inout
=
const
auto
&
src_inout
=
src_type
==
OpArgType
::
IN
?
src_op
->
inputs_
:
src_op
->
outputs_
;
src_type
==
OpArgType
::
IN
?
src_op
->
Inputs
()
:
src_op
->
Outputs
();
auto
&
dst_inout
=
VarNameMap
dst_inout
;
dst_type
==
OpArgType
::
IN
?
dst_op
->
inputs_
:
dst_op
->
outputs_
;
const
OpProto
&
proto
=
OpProtos
().
at
(
src_op
->
type_
);
const
OpProto
&
proto
=
OpProtos
().
at
(
src_op
->
Type
()
);
const
auto
&
src_arg_list
=
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
for
(
const
auto
&
arg
:
src_arg_list
)
{
...
@@ -41,17 +41,23 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
...
@@ -41,17 +41,23 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
dst_inout
[
dst_name
].
emplace_back
(
s
);
dst_inout
[
dst_name
].
emplace_back
(
s
);
}
}
}
}
return
dst_inout
;
}
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op
->
type_
);
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op
->
Type
());
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
auto
I
=
TransOpArg
(
op
,
OpArgType
::
IN
,
OpArgType
::
IN
,
false
);
// I
grad_op
->
type_
=
grad_op_type
;
auto
O
=
TransOpArg
(
op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
false
);
// O
grad_op
->
attrs_
=
op
->
attrs_
;
auto
OG
=
TransOpArg
(
op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
true
);
// OG
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
IN
,
false
);
// I
auto
IG
=
TransOpArg
(
op
,
OpArgType
::
IN
,
OpArgType
::
OUT
,
true
);
// IG
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
false
);
// O
// TODO(merge I/O/OG)
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
true
);
// OG
VarNameMap
GradIn
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
OUT
,
true
);
// IG
GradIn
.
insert
(
I
.
begin
(),
I
.
end
());
GradIn
.
insert
(
O
.
begin
(),
O
.
end
());
GradIn
.
insert
(
OG
.
begin
(),
OG
.
end
());
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)(
grad_op_type
,
GradIn
,
IG
,
op
->
Attrs
());
return
grad_op
;
return
grad_op
;
}
}
...
...
paddle/framework/op_registry.h
浏览文件 @
5d33ef61
...
@@ -128,7 +128,11 @@ class OpRegistry {
...
@@ -128,7 +128,11 @@ class OpRegistry {
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
op_creators
()[
op_type
]
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
OpType
(
type
,
inputs
,
outputs
,
attrs
);
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
OpProtos
()[
op_type
];
OpProto
&
op_proto
=
OpProtos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
...
@@ -143,7 +147,11 @@ class OpRegistry {
...
@@ -143,7 +147,11 @@ class OpRegistry {
template
<
typename
GradOpType
>
template
<
typename
GradOpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
,
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
const
std
::
string
&
grad_op_type
)
{
op_creators
()[
grad_op_type
]
=
[]
{
return
new
GradOpType
;
};
op_creators
()[
grad_op_type
]
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
GradOpType
(
type
,
inputs
,
outputs
,
attrs
);
};
grad_ops
()[
op_type
]
=
grad_op_type
;
grad_ops
()[
op_type
]
=
grad_op_type
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录