Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6768b310
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6768b310
编写于
8月 11, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix compile error
上级
3e11e4c6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
24 addition
and
20 deletion
+24
-20
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+5
-5
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+16
-13
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+3
-2
未找到文件。
paddle/framework/grad_op_builder.cc
浏览文件 @
6768b310
...
...
@@ -50,7 +50,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
std
::
vector
<
std
::
string
>&
dst_inout
=
dst_type
==
OpArgType
::
IN
?
dst_op
->
inputs_
:
dst_op
->
outputs_
;
std
::
vector
<
int
>*
dst_format
=
GetOpFormat
(
dst_op
,
dst_type
);
const
OpProto
&
proto
=
OpRegistry
::
protos
().
at
(
src_op
->
type
_
);
const
OpProto
&
proto
=
*
(
OpRegistry
::
op_info_map
().
at
(
src_op
->
type_
).
proto
_
);
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
...
...
@@ -76,13 +76,13 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
auto
it
=
op_info_map
().
find
(
op
->
type_
);
auto
it
=
OpRegistry
::
op_info_map
().
find
(
op
->
type_
);
PADDLE_ENFORCE
(
it
!=
OpRegistry
::
op_info_map
().
end
(),
"'%s' has not been registered."
,
op
->
type
);
"'%s' has not been registered."
,
op
->
type
_
);
std
::
string
grad_op_type
=
it
->
second
.
grad_op_type_
;
PADDLE_ENFORCE
(
!
grad_op_type
.
empty
(),
"'%s' has no gradient operator."
,
op
->
type
);
it
=
op_info_map
().
find
(
grad_op_type
);
op
->
type
_
);
it
=
OpRegistry
::
op_info_map
().
find
(
grad_op_type
);
PADDLE_ENFORCE
(
it
!=
OpRegistry
::
op_info_map
().
end
(),
"'%s' has not been registered."
,
grad_op_type
);
OperatorBase
*
grad_op
=
it
->
second
.
creator_
();
...
...
paddle/framework/op_registry.h
浏览文件 @
6768b310
...
...
@@ -175,17 +175,20 @@ Add a mark to which output is temporary is helpful for future optimization.
bool
has_temporary_output_
{
false
};
};
class
NOPMaker
:
public
OpProtoAndCheckerMaker
{};
class
NOPMaker
:
public
OpProtoAndCheckerMaker
{
public:
NOPMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{}
};
struct
OpInfo
{
std
::
function
creator_
;
std
::
function
<
OperatorBase
*
()
>
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
VarNameList
=
std
::
vector
<
std
::
string
>
;
...
...
@@ -201,28 +204,28 @@ class OpRegistry {
if
(
std
::
type_index
(
typeid
(
ProtoMakerType
))
!=
std
::
type_index
(
typeid
(
NOPMaker
)))
{
op_info
.
proto_
=
new
OpProto
;
op_info
.
op_
checker_
=
new
OpAttrChecker
;
auto
maker
=
ProtoMakerType
(
op_info
.
proto_
,
op_info
.
op_
checker_
);
op_info
.
checker_
=
new
OpAttrChecker
;
auto
maker
=
ProtoMakerType
(
op_info
.
proto_
,
op_info
.
checker_
);
maker
.
Validate
();
*
op_info
.
proto_
->
mutable_type
()
=
op_type
;
PADDLE_ENFORCE
(
op_info
.
proto_
->
IsInitialized
(),
"Fail to initialize %s's OpProto, because %s is not initialized"
,
op_type
,
op_info
.
proto_
->
InitializationErrorString
());
//
======will be refactored in following PRs============
//
//
======will be refactored in following PRs============
//
VarIndexMaps
()[
op_type
].
reset
(
new
VarIndexMap
());
auto
&
varmap
=
*
VarIndexMaps
()[
op_type
];
int
idx
=
0
;
for
(
auto
&
var
:
op_
proto
.
inputs
())
{
for
(
auto
&
var
:
op_
info
.
proto_
->
inputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
idx
=
0
;
for
(
auto
&
var
:
op_
proto
.
outputs
())
{
for
(
auto
&
var
:
op_
info
.
proto_
->
outputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
//
================================================
//
//
================================================
//
}
op_info_map
.
insert
(
std
::
make_pair
(
op_type
,
op_info
));
op_info_map
()
.
insert
(
std
::
make_pair
(
op_type
,
op_info
));
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
...
...
@@ -281,8 +284,8 @@ class OpRegistry {
return
grad_op
;
}
static
std
::
unordered_map
<
const
std
::
string
,
const
OpInfo
>&
op_info_map
()
{
static
std
::
unordered_map
<
const
std
::
string
,
const
OpInfo
>
op_info_map_
;
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
op_info_map
()
{
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>
op_info_map_
;
return
op_info_map_
;
}
...
...
@@ -321,7 +324,7 @@ class Registrar {
template
<
typename
OpType
,
typename
ProtoMakerType
>
class
OpRegistrar
:
public
Registrar
{
public:
OpRegistrar
(
const
char
*
op_type
)
{
OpRegistrar
(
op_type
,
""
);
}
explicit
OpRegistrar
(
const
char
*
op_type
)
{
OpRegistrar
(
op_type
,
""
);
}
OpRegistrar
(
const
char
*
op_type
,
const
char
*
grad_op_type
)
{
OpRegistry
::
RegisterOp
<
OpType
,
ProtoMakerType
>
(
op_type
,
grad_op_type
);
}
...
...
paddle/framework/operator_test.cc
浏览文件 @
6768b310
...
...
@@ -188,8 +188,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
}
// namespace framework
}
// namespace paddle
REGISTER_OP
(
op_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
op_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
<
float
,
float
>
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录