Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
59b3df31
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
59b3df31
编写于
8月 20, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extract OpInfo into a library
Fix cycle dependencies, Fix #3583.
上级
0d9846f3
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
152 addition
and
95 deletion
+152
-95
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+2
-2
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+2
-2
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+10
-10
paddle/framework/op_info.cc
paddle/framework/op_info.cc
+30
-0
paddle/framework/op_info.h
paddle/framework/op_info.h
+42
-0
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+18
-19
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+8
-27
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+14
-13
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+5
-4
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+1
-1
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+2
-3
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+4
-4
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+6
-4
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
59b3df31
...
@@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
...
@@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
cc_library
(
operator SRCS operator.cc DEPS
framework_proto device_context tensor scope attribut
e
)
cc_library
(
operator SRCS operator.cc DEPS
op_info device_context tensor scop
e
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
)
...
...
paddle/framework/backward_test.cc
浏览文件 @
59b3df31
...
@@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
...
@@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class
FcOp
:
public
operators
::
NetOp
{
class
FcOp
:
public
operators
::
NetOp
{
public:
public:
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
FcOp
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
59b3df31
...
@@ -20,11 +20,11 @@ namespace framework {
...
@@ -20,11 +20,11 @@ namespace framework {
enum
class
OpArgType
{
IN
,
OUT
};
enum
class
OpArgType
{
IN
,
OUT
};
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
const
OpArgType
&
src_type
,
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
const
OpArgType
&
src_type
,
bool
is_grad
,
OperatorBase
::
Var
NameMap
*
vars
)
{
bool
is_grad
,
Variable
NameMap
*
vars
)
{
const
auto
&
src_inout
=
const
auto
&
src_inout
=
src_type
==
OpArgType
::
IN
?
src_op
->
Inputs
()
:
src_op
->
Outputs
();
src_type
==
OpArgType
::
IN
?
src_op
->
Inputs
()
:
src_op
->
Outputs
();
auto
&
dst_inout
=
*
vars
;
auto
&
dst_inout
=
*
vars
;
const
OpProto
*
proto
=
Op
Registry
::
op_info_m
ap
().
at
(
src_op
->
Type
()).
proto_
;
const
OpProto
*
proto
=
Op
InfoM
ap
().
at
(
src_op
->
Type
()).
proto_
;
const
auto
&
src_arg_list
=
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
->
inputs
()
:
proto
->
outputs
();
src_type
==
OpArgType
::
IN
?
proto
->
inputs
()
:
proto
->
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
for
(
const
auto
&
arg
:
src_arg_list
)
{
...
@@ -40,25 +40,25 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
...
@@ -40,25 +40,25 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
}
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
auto
it
=
Op
Registry
::
op_info_m
ap
().
find
(
op
->
Type
());
auto
it
=
Op
InfoM
ap
().
find
(
op
->
Type
());
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_map
().
end
()
,
PADDLE_ENFORCE
(
it
!=
Op
InfoMap
().
end
(),
"'%s' has not been registered."
,
"'%s' has not been registered."
,
op
->
Type
());
op
->
Type
());
PADDLE_ENFORCE
(
it
->
second
.
proto_
!=
nullptr
,
"'%s' has no OpProto."
,
PADDLE_ENFORCE
(
it
->
second
.
proto_
!=
nullptr
,
"'%s' has no OpProto."
,
op
->
Type
());
op
->
Type
());
std
::
string
grad_op_type
=
it
->
second
.
grad_op_type_
;
std
::
string
grad_op_type
=
it
->
second
.
grad_op_type_
;
PADDLE_ENFORCE
(
!
grad_op_type
.
empty
(),
"'%s' has no gradient operator."
,
PADDLE_ENFORCE
(
!
grad_op_type
.
empty
(),
"'%s' has no gradient operator."
,
op
->
Type
());
op
->
Type
());
OperatorBase
::
Var
NameMap
inputs
;
Variable
NameMap
inputs
;
OperatorBase
::
Var
NameMap
outputs
;
Variable
NameMap
outputs
;
TransOpArg
(
op
,
OpArgType
::
IN
,
false
,
&
inputs
);
// I
TransOpArg
(
op
,
OpArgType
::
IN
,
false
,
&
inputs
);
// I
TransOpArg
(
op
,
OpArgType
::
OUT
,
false
,
&
inputs
);
// O
TransOpArg
(
op
,
OpArgType
::
OUT
,
false
,
&
inputs
);
// O
TransOpArg
(
op
,
OpArgType
::
OUT
,
true
,
&
inputs
);
// OG
TransOpArg
(
op
,
OpArgType
::
OUT
,
true
,
&
inputs
);
// OG
TransOpArg
(
op
,
OpArgType
::
IN
,
true
,
&
outputs
);
// IG
TransOpArg
(
op
,
OpArgType
::
IN
,
true
,
&
outputs
);
// IG
it
=
Op
Registry
::
op_info_m
ap
().
find
(
grad_op_type
);
it
=
Op
InfoM
ap
().
find
(
grad_op_type
);
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_map
().
end
()
,
PADDLE_ENFORCE
(
it
!=
Op
InfoMap
().
end
(),
"'%s' has not been registered."
,
"'%s' has not been registered."
,
grad_op_type
);
grad_op_type
);
return
it
->
second
.
creator_
(
grad_op_type
,
inputs
,
outputs
,
op
->
Attrs
());
return
it
->
second
.
creator_
(
grad_op_type
,
inputs
,
outputs
,
op
->
Attrs
());
}
}
...
...
paddle/framework/op_info.cc
0 → 100644
浏览文件 @
59b3df31
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_info.h"
namespace
paddle
{
namespace
framework
{
static
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>*
g_op_info_map
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>&
OpInfoMap
()
{
if
(
g_op_info_map
==
nullptr
)
{
g_op_info_map
=
new
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>
();
}
return
*
g_op_info_map
;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_info.h
0 → 100644
浏览文件 @
59b3df31
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
namespace
paddle
{
namespace
framework
{
class
OperatorBase
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
extern
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
OpInfoMap
();
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.cc
浏览文件 @
59b3df31
...
@@ -19,32 +19,20 @@ limitations under the License. */
...
@@ -19,32 +19,20 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
std
::
string
&
type
,
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
VarNameMap
&
inputs
,
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
VariableNameMap
&
outputs
,
AttributeMap
attrs
)
{
AttributeMap
attrs
)
{
auto
it
=
OpInfoMap
().
find
(
type
);
auto
it
=
op_info_map
().
find
(
type
);
PADDLE_ENFORCE
(
it
!=
OpInfoMap
().
end
(),
PADDLE_ENFORCE
(
it
!=
op_info_map
().
end
(),
"Operator '%s' has not been registered."
,
type
);
"Operator '%s' has not been registered."
,
type
);
it
->
second
.
checker_
->
Check
(
attrs
);
it
->
second
.
checker_
->
Check
(
attrs
);
auto
op
=
it
->
second
.
creator_
(
type
,
inputs
,
outputs
,
attrs
);
auto
op
=
it
->
second
.
creator_
(
type
,
inputs
,
outputs
,
attrs
);
return
std
::
unique_ptr
<
OperatorBase
>
(
op
);
return
std
::
unique_ptr
<
OperatorBase
>
(
op
);
}
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
VariableNameMap
ConvertOpDescVarsToVarNameMap
(
VarNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VarNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
}
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
OperatorBase
::
VarNameMap
OpRegistry
::
ConvertOpDescVarsToVarNameMap
(
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
)
{
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
)
{
VarNameMap
ret_val
;
Var
iable
NameMap
ret_val
;
for
(
auto
&
var
:
op_desc_vars
)
{
for
(
auto
&
var
:
op_desc_vars
)
{
auto
&
var_names
=
ret_val
[
var
.
parameter
()];
auto
&
var_names
=
ret_val
[
var
.
parameter
()];
auto
&
var_names_in_proto
=
var
.
arguments
();
auto
&
var_names_in_proto
=
var
.
arguments
();
...
@@ -55,6 +43,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
...
@@ -55,6 +43,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
return
ret_val
;
return
ret_val
;
}
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
&
op_desc
)
{
VariableNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VariableNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
}
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateGradOp
(
const
OperatorBase
&
op
)
{
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateGradOp
(
const
OperatorBase
&
op
)
{
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
return
std
::
unique_ptr
<
OperatorBase
>
(
BuildGradOp
(
&
op
));
return
std
::
unique_ptr
<
OperatorBase
>
(
BuildGradOp
(
&
op
));
...
...
paddle/framework/op_registry.h
浏览文件 @
59b3df31
...
@@ -23,6 +23,7 @@ limitations under the License. */
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
...
@@ -30,28 +31,16 @@ namespace paddle {
...
@@ -30,28 +31,16 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
class
OpRegistry
{
class
OpRegistry
{
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
const
std
::
string
&
/*type*/
,
const
VarNameMap
&
/*inputs*/
,
const
VarNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
public:
public:
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
template
<
typename
OpType
,
typename
ProtoMakerType
,
typename
GradOpType
>
template
<
typename
OpType
,
typename
ProtoMakerType
,
typename
GradOpType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
,
static
void
RegisterOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
const
std
::
string
&
grad_op_type
)
{
PADDLE_ENFORCE
(
op_info_m
ap
().
count
(
op_type
)
==
0
,
PADDLE_ENFORCE
(
OpInfoM
ap
().
count
(
op_type
)
==
0
,
"'%s' is registered more than once."
,
op_type
);
"'%s' is registered more than once."
,
op_type
);
OpInfo
op_info
;
OpInfo
op_info
;
op_info
.
creator_
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
op_info
.
creator_
=
[](
const
VarNameMap
&
out
puts
,
const
std
::
string
&
type
,
const
VariableNameMap
&
in
puts
,
const
AttributeMap
&
attrs
)
{
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
OpType
(
type
,
inputs
,
outputs
,
attrs
);
return
new
OpType
(
type
,
inputs
,
outputs
,
attrs
);
};
};
op_info
.
grad_op_type_
=
grad_op_type
;
op_info
.
grad_op_type_
=
grad_op_type
;
...
@@ -70,7 +59,7 @@ class OpRegistry {
...
@@ -70,7 +59,7 @@ class OpRegistry {
op_info
.
proto_
=
nullptr
;
op_info
.
proto_
=
nullptr
;
op_info
.
checker_
=
nullptr
;
op_info
.
checker_
=
nullptr
;
}
}
op_info_m
ap
().
insert
(
std
::
make_pair
(
op_type
,
op_info
));
OpInfoM
ap
().
insert
(
std
::
make_pair
(
op_type
,
op_info
));
// register gradient op
// register gradient op
if
(
!
grad_op_type
.
empty
())
{
if
(
!
grad_op_type
.
empty
())
{
RegisterOp
<
GradOpType
,
NOPMaker
,
NOP
>
(
grad_op_type
,
""
);
RegisterOp
<
GradOpType
,
NOPMaker
,
NOP
>
(
grad_op_type
,
""
);
...
@@ -78,21 +67,13 @@ class OpRegistry {
...
@@ -78,21 +67,13 @@ class OpRegistry {
}
}
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
Var
iable
NameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
Var
iable
NameMap
&
outputs
,
AttributeMap
attrs
);
AttributeMap
attrs
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
VarNameMap
ConvertOpDescVarsToVarNameMap
(
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
);
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
op_info_map
()
{
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>
op_info_map_
;
return
op_info_map_
;
}
};
};
class
Registrar
{
class
Registrar
{
...
...
paddle/framework/operator.cc
浏览文件 @
59b3df31
...
@@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name,
...
@@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
OperatorBase
::
OperatorBase
(
const
std
::
string
&
type
,
OperatorBase
::
OperatorBase
(
const
std
::
string
&
type
,
const
OperatorBase
::
Var
NameMap
&
inputs
,
const
Variable
NameMap
&
inputs
,
const
OperatorBase
::
Var
NameMap
&
outputs
,
const
Variable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
AttributeMap
&
attrs
)
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
...
@@ -141,9 +141,9 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
...
@@ -141,9 +141,9 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
}
}
return
ret_val
;
return
ret_val
;
}
}
auto
it
=
Op
Registry
::
op_info_m
ap
().
find
(
type_
);
auto
it
=
Op
InfoM
ap
().
find
(
type_
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_m
ap
().
end
(),
it
!=
Op
InfoM
ap
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
type_
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
...
paddle/framework/operator.h
浏览文件 @
59b3df31
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
...
@@ -62,10 +63,8 @@ class ExecutionContext;
...
@@ -62,10 +63,8 @@ class ExecutionContext;
*/
*/
class
OperatorBase
{
class
OperatorBase
{
public:
public:
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
OperatorBase
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
virtual
~
OperatorBase
()
{}
virtual
~
OperatorBase
()
{}
...
@@ -93,8 +92,8 @@ class OperatorBase {
...
@@ -93,8 +92,8 @@ class OperatorBase {
/// rename inputs outputs name
/// rename inputs outputs name
void
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
void
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
const
VarNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
Var
iable
NameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VarNameMap
&
Outputs
()
const
{
return
outputs_
;
}
const
Var
iable
NameMap
&
Outputs
()
const
{
return
outputs_
;
}
//! Get a input with argument's name described in `op_proto`
//! Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
//! Get a input which has multiple variables.
...
@@ -122,11 +121,11 @@ class OperatorBase {
...
@@ -122,11 +121,11 @@ class OperatorBase {
// I (Inputs)opear
// I (Inputs)opear
// O (Outputs)
// O (Outputs)
// OG (Output Gradients)
// OG (Output Gradients)
VarNameMap
inputs_
;
Var
iable
NameMap
inputs_
;
// NOTE: in case of OpGrad, outputs_ contains
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
// IG (Inputs Gradients)
VarNameMap
outputs_
;
Var
iable
NameMap
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
};
};
...
@@ -142,9 +141,11 @@ class OperatorBase {
...
@@ -142,9 +141,11 @@ class OperatorBase {
// You can also use
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
CLS(const std::string& type, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
const ::paddle::framework::VariableNameMap& inputs, \
const ::paddle::framework::VariableNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
: PARENT_CLS(type, inputs, outputs, attrs) {}
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
...
@@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase {
using
OpKernelMap
=
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
OperatorWithKernel
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{
void
InferShape
(
const
Scope
&
scope
)
const
override
{
...
...
paddle/framework/operator_test.cc
浏览文件 @
59b3df31
...
@@ -23,8 +23,8 @@ static int op_run_num = 0;
...
@@ -23,8 +23,8 @@ static int op_run_num = 0;
class
OpWithoutKernelTest
:
public
OperatorBase
{
class
OpWithoutKernelTest
:
public
OperatorBase
{
public:
public:
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
...
@@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) {
...
@@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) {
class
OperatorClone
:
public
paddle
::
framework
::
OperatorBase
{
class
OperatorClone
:
public
paddle
::
framework
::
OperatorBase
{
public:
public:
DEFINE_OP_CLONE_METHOD
(
OperatorClone
);
DEFINE_OP_CLONE_METHOD
(
OperatorClone
);
OperatorClone
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
OperatorClone
(
const
std
::
string
&
type
,
const
VarNameMap
&
outputs
,
const
paddle
::
framework
::
VariableNameMap
&
inputs
,
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
paddle
::
framework
::
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
paddle
::
framework
::
Scope
&
scope
)
const
override
{}
...
...
paddle/framework/pybind.cc
浏览文件 @
59b3df31
...
@@ -138,7 +138,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -138,7 +138,7 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
//! Python str. If you want a str object, you should cast them in Python.
m
.
def
(
"get_all_op_protos"
,
[]()
->
std
::
vector
<
py
::
bytes
>
{
m
.
def
(
"get_all_op_protos"
,
[]()
->
std
::
vector
<
py
::
bytes
>
{
auto
&
op_info_map
=
Op
Registry
::
op_info_m
ap
();
auto
&
op_info_map
=
Op
InfoM
ap
();
std
::
vector
<
py
::
bytes
>
ret_values
;
std
::
vector
<
py
::
bytes
>
ret_values
;
for
(
auto
it
=
op_info_map
.
begin
();
it
!=
op_info_map
.
end
();
++
it
)
{
for
(
auto
it
=
op_info_map
.
begin
();
it
!=
op_info_map
.
end
();
++
it
)
{
const
OpProto
*
proto
=
it
->
second
.
proto_
;
const
OpProto
*
proto
=
it
->
second
.
proto_
;
...
...
paddle/operators/net_op.cc
浏览文件 @
59b3df31
...
@@ -81,9 +81,8 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
...
@@ -81,9 +81,8 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return
ret_val
;
return
ret_val
;
}
}
NetOp
::
NetOp
(
const
std
::
string
&
type
,
NetOp
::
NetOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
...
...
paddle/operators/net_op.h
浏览文件 @
59b3df31
...
@@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase {
...
@@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase {
public:
public:
static
const
char
kAll
[];
static
const
char
kAll
[];
NetOp
()
:
framework
::
OperatorBase
(
"plain_net"
,
{},
{},
{})
{}
NetOp
()
:
framework
::
OperatorBase
(
"plain_net"
,
{},
{},
{})
{}
NetOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
NetOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
NetOp
(
const
NetOp
&
o
)
:
framework
::
OperatorBase
(
o
.
type_
,
{},
{},
o
.
attrs_
)
{
NetOp
(
const
NetOp
&
o
)
:
framework
::
OperatorBase
(
o
.
type_
,
{},
{},
o
.
attrs_
)
{
this
->
ops_
.
reserve
(
o
.
ops_
.
size
());
this
->
ops_
.
reserve
(
o
.
ops_
.
size
());
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
59b3df31
...
@@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
...
@@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
RecurrentOp
::
RecurrentOp
(
const
std
::
string
&
type
,
RecurrentOp
::
RecurrentOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
Var
NameMap
&
inputs
,
const
framework
::
Variable
NameMap
&
inputs
,
const
framework
::
OperatorBase
::
Var
NameMap
&
outputs
,
const
framework
::
Variable
NameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
...
@@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
...
@@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
}
}
RecurrentGradientOp
::
RecurrentGradientOp
(
RecurrentGradientOp
::
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
Var
NameMap
&
inputs
,
const
std
::
string
&
type
,
const
framework
::
Variable
NameMap
&
inputs
,
const
framework
::
OperatorBase
::
Var
NameMap
&
outputs
,
const
framework
::
Variable
NameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
...
...
paddle/operators/recurrent_op.h
浏览文件 @
59b3df31
...
@@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm {
...
@@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm {
class
RecurrentOp
:
public
framework
::
OperatorBase
{
class
RecurrentOp
:
public
framework
::
OperatorBase
{
public:
public:
RecurrentOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
RecurrentOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentOp
(
const
RecurrentOp
&
o
)
RecurrentOp
(
const
RecurrentOp
&
o
)
:
framework
::
OperatorBase
(
:
framework
::
OperatorBase
(
...
@@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase {
...
@@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase {
class
RecurrentGradientOp
:
public
framework
::
OperatorBase
{
class
RecurrentGradientOp
:
public
framework
::
OperatorBase
{
public:
public:
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
const
framework
::
AttributeMap
&
attrs
);
RecurrentGradientOp
(
const
RecurrentGradientOp
&
o
)
RecurrentGradientOp
(
const
RecurrentGradientOp
&
o
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录