Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
edb541f2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
edb541f2
编写于
8月 14, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix compile errors
上级
3e6e5c92
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
15 addition
and
12 deletion
+15
-12
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+4
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+10
-10
paddle/framework/operator.cc
paddle/framework/operator.cc
+1
-1
未找到文件。
paddle/framework/grad_op_builder.cc
浏览文件 @
edb541f2
...
...
@@ -25,8 +25,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const
auto
&
src_inout
=
src_type
==
OpArgType
::
IN
?
src_op
->
inputs_
:
src_op
->
outputs_
;
auto
&
dst_inout
=
*
vars
;
const
OpProto
*
proto
=
OpRegistry
::
op_info_map
().
at
(
src_op
->
type_
).
proto_
;
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
src_type
==
OpArgType
::
IN
?
proto
->
inputs
()
:
proto
->
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
if
(
arg
.
no_gradient
()
&&
!
is_grad
)
continue
;
const
std
::
string
src_name
=
arg
.
name
();
...
...
@@ -43,6 +44,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
auto
it
=
OpRegistry
::
op_info_map
().
find
(
op
->
type_
);
PADDLE_ENFORCE
(
it
!=
OpRegistry
::
op_info_map
().
end
(),
"'%s' has not been registered."
,
op
->
type_
);
PADDLE_ENFORCE
(
it
->
second
.
proto_
!=
nullptr
,
"'%s' has no OpProto."
,
op
->
type_
);
std
::
string
grad_op_type
=
it
->
second
.
grad_op_type_
;
PADDLE_ENFORCE
(
!
grad_op_type
.
empty
(),
"'%s' has no gradient operator."
,
op
->
type_
);
...
...
paddle/framework/op_registry.h
浏览文件 @
edb541f2
...
...
@@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker {
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{}
};
struct
OpInfo
{
std
::
function
<
OperatorBase
*
()
>
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
class
OpRegistry
{
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
...
...
@@ -140,6 +133,13 @@ class OpRegistry {
const
VarNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
public:
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
template
<
typename
OpType
,
typename
ProtoMakerType
,
typename
GradOpType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
...
...
@@ -175,9 +175,9 @@ class OpRegistry {
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarName
List
&
inputs
,
const
VarName
List
&
outputs
,
const
AttributeMap
&
attrs
)
{
const
VarName
Map
&
inputs
,
const
VarName
Map
&
outputs
,
AttributeMap
attrs
)
{
auto
it
=
op_info_map
().
find
(
type
);
PADDLE_ENFORCE
(
it
!=
op_info_map
().
end
(),
"Operator '%s' has not been registered."
,
type
);
...
...
paddle/framework/operator.cc
浏览文件 @
edb541f2
...
...
@@ -152,7 +152,7 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
proto_
.
outputs
())
{
for
(
auto
&
o
:
it
->
second
.
proto_
->
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录