Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
875946ff
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
875946ff
编写于
7月 18, 2017
作者:
Y
Yu Yang
提交者:
GitHub
7月 18, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2935 from reyoung/feature/create_op_use_cpp_params
Change `in_out_idxs_` to shared_ptr
上级
2db1b68d
c1219a53
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
34 addition
and
30 deletion
+34
-30
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+25
-8
paddle/framework/operator.cc
paddle/framework/operator.cc
+8
-18
paddle/framework/operator.h
paddle/framework/operator.h
+1
-4
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
875946ff
...
@@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization.
...
@@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization.
class
OpRegistry
{
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
@@ -212,6 +213,17 @@ class OpRegistry {
...
@@ -212,6 +213,17 @@ class OpRegistry {
op_proto
.
IsInitialized
(),
op_proto
.
IsInitialized
(),
"Fail to initialize %s's OpProto, because %s is not initialized"
,
"Fail to initialize %s's OpProto, because %s is not initialized"
,
op_type
,
op_proto
.
InitializationErrorString
());
op_type
,
op_proto
.
InitializationErrorString
());
VarIndexMaps
()[
op_type
].
reset
(
new
VarIndexMap
());
auto
&
varmap
=
*
VarIndexMaps
()[
op_type
];
int
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
inputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
outputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
}
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
...
@@ -220,7 +232,6 @@ class OpRegistry {
...
@@ -220,7 +232,6 @@ class OpRegistry {
OperatorPtr
op
(
creators
().
at
(
op_type
)());
OperatorPtr
op
(
creators
().
at
(
op_type
)());
//! Fill op's data member. Not use constructor because it will be noising
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
//! for Op developer.
const
OpProto
&
op_proto
=
protos
().
at
(
op_type
);
op
->
type_
=
op_desc
.
type
();
op
->
type_
=
op_desc
.
type
();
// set op's inputs_ from desc.
// set op's inputs_ from desc.
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
...
@@ -240,25 +251,31 @@ class OpRegistry {
...
@@ -240,25 +251,31 @@ class OpRegistry {
//! Convert Temporary variable name to an unique variable name.
//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName
(
op
.
get
());
GenerateTempVariableName
(
op
.
get
());
// set argument offsets stored in op.
//! set argument offsets stored in op.
CreateInOutOffsetMap
(
op
,
op_proto
);
{
auto
var_index_it
=
VarIndexMaps
().
find
(
op_type
);
if
(
var_index_it
!=
VarIndexMaps
().
end
())
{
op
->
in_out_idxs_
=
var_index_it
->
second
;
}
}
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
//! method do nothing.
op
->
Init
();
op
->
Init
();
return
op
;
return
op
;
}
}
// init op.in_out_idxs_ to accelerate argument's offset lookup.
static
void
CreateInOutOffsetMap
(
OperatorPtr
op
,
const
OpProto
&
proto
)
{
op
->
CreateInOutOffsetMap
(
proto
);
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
return
protos_
;
};
};
private:
private:
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
for
(
auto
&
outname
:
op
->
outputs_
)
{
...
...
paddle/framework/operator.cc
浏览文件 @
875946ff
...
@@ -19,21 +19,10 @@ limitations under the License. */
...
@@ -19,21 +19,10 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
OperatorBase
::
CreateInOutOffsetMap
(
const
OpProto
&
proto
)
{
PADDLE_ENFORCE
(
in_out_idxs_
.
empty
(),
"duplicate call CreateInOutOffsetMap"
);
for
(
int
i
=
0
;
i
<
proto
.
inputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
inputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
for
(
int
i
=
0
;
i
<
proto
.
outputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
outputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
return
inputs_
[
it
->
second
];
return
inputs_
[
it
->
second
];
...
@@ -46,7 +35,7 @@ const std::string& OperatorBase::Input(const std::string& name) const {
...
@@ -46,7 +35,7 @@ const std::string& OperatorBase::Input(const std::string& name) const {
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
...
@@ -54,8 +43,9 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
...
@@ -54,8 +43,9 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
}
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
return
outputs_
[
it
->
second
];
return
outputs_
[
it
->
second
];
...
@@ -68,7 +58,7 @@ const std::string& OperatorBase::Output(const std::string& name) const {
...
@@ -68,7 +58,7 @@ const std::string& OperatorBase::Output(const std::string& name) const {
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
...
...
paddle/framework/operator.h
浏览文件 @
875946ff
...
@@ -82,16 +82,13 @@ class OperatorBase {
...
@@ -82,16 +82,13 @@ class OperatorBase {
// TODO add a vector_view to prevent memory copy.
// TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
// init in_out_idxs_ to accelerate argument's offset lookup.
void
CreateInOutOffsetMap
(
const
OpProto
&
proto
);
public:
public:
std
::
string
type_
;
std
::
string
type_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
// store the arguments' offset described in op_desc.
// store the arguments' offset described in op_desc.
std
::
unordered_map
<
std
::
string
,
int
>
in_out_idxs_
;
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
int
>
>
in_out_idxs_
;
};
};
class
KernelContext
{
class
KernelContext
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录