Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
78af6e60
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
78af6e60
编写于
8月 09, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add OutputVars method to get all outputs or outputs without intermediate
上级
b368c6ca
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
71 addition
and
55 deletion
+71
-55
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+2
-23
paddle/framework/operator.cc
paddle/framework/operator.cc
+10
-2
paddle/framework/operator.h
paddle/framework/operator.h
+31
-0
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+19
-16
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-0
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+5
-14
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
78af6e60
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -127,7 +128,7 @@ class OpRegistry {
...
@@ -127,7 +128,7 @@ class OpRegistry {
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
p
rotos
()[
op_type
];
OpProto
&
op_proto
=
OpP
rotos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
*
op_proto
.
mutable_type
()
=
op_type
;
...
@@ -135,17 +136,6 @@ class OpRegistry {
...
@@ -135,17 +136,6 @@ class OpRegistry {
op_proto
.
IsInitialized
(),
op_proto
.
IsInitialized
(),
"Fail to initialize %s's OpProto, because %s is not initialized"
,
"Fail to initialize %s's OpProto, because %s is not initialized"
,
op_type
,
op_proto
.
InitializationErrorString
());
op_type
,
op_proto
.
InitializationErrorString
());
VarIndexMaps
()[
op_type
].
reset
(
new
VarIndexMap
());
auto
&
varmap
=
*
VarIndexMaps
()[
op_type
];
int
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
inputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
outputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
}
}
template
<
typename
GradOpType
>
template
<
typename
GradOpType
>
...
@@ -212,22 +202,11 @@ class OpRegistry {
...
@@ -212,22 +202,11 @@ class OpRegistry {
return
grad_op
;
return
grad_op
;
}
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
return
grad_ops_
;
return
grad_ops_
;
}
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
return
op_creators_
;
return
op_creators_
;
...
...
paddle/framework/operator.cc
浏览文件 @
78af6e60
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
...
@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
}
#endif
#endif
static
std
::
unordered_map
<
std
::
string
,
OpProto
>*
g_op_protos
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
()
{
if
(
g_op_protos
==
nullptr
)
{
g_op_protos
=
new
std
::
unordered_map
<
std
::
string
,
OpProto
>
();
}
return
*
g_op_protos
;
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
inputs_
.
find
(
name
);
auto
it
=
inputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
...
...
paddle/framework/operator.h
浏览文件 @
78af6e60
...
@@ -50,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
...
@@ -50,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
return
var_name
+
kGradVarSuffix
;
return
var_name
+
kGradVarSuffix
;
}
}
extern
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
();
class
OperatorBase
;
class
OperatorBase
;
class
InferShapeContext
;
class
InferShapeContext
;
class
ExecutionContext
;
class
ExecutionContext
;
...
@@ -103,6 +105,35 @@ class OperatorBase {
...
@@ -103,6 +105,35 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy.
//! TODO add a vector_view to prevent memory copy.
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
// push all outputs into ret_val
for
(
auto
&
o
:
outputs_
)
{
ret_val
.
reserve
(
ret_val
.
size
()
+
o
.
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
o
.
second
.
begin
(),
o
.
second
.
end
());
}
return
ret_val
;
}
auto
it
=
OpProtos
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
OpProtos
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
if
(
out
!=
outputs_
.
end
())
{
ret_val
.
reserve
(
ret_val
.
size
()
+
out
->
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
out
->
second
.
begin
(),
out
->
second
.
end
());
}
}
return
ret_val
;
}
public:
public:
std
::
string
type_
;
std
::
string
type_
;
// NOTE: in case of OpGrad, inputs_ contains:
// NOTE: in case of OpGrad, inputs_ contains:
...
...
paddle/operators/net_op.cc
浏览文件 @
78af6e60
...
@@ -21,19 +21,20 @@
...
@@ -21,19 +21,20 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
const
char
NetOp
::
kAll
[]
=
"all"
;
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
add_op_done_
=
true
;
if
(
!
calc
)
return
;
if
(
!
calc
)
return
;
std
::
set
<
std
::
string
>
input_set
;
std
::
set
<
std
::
string
>
input_set
;
std
::
set
<
std
::
string
>
output_set
;
std
::
set
<
std
::
string
>
output_set
;
std
::
set
<
std
::
string
>
temp_output
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
for
(
auto
&
var_name
:
ipt
.
second
)
{
for
(
auto
&
var_name
:
ipt
.
second
)
{
if
(
!
Contains
(
output_set
,
var_name
))
{
// Not other op's output
if
(
!
Contains
(
output_set
,
var_name
))
{
// Not other op's output
input_set
.
insert
(
var_name
);
input_set
.
insert
(
var_name
);
}
else
{
}
else
{
temp_output
.
insert
(
var_name
);
intermediate_outputs_
.
insert
(
var_name
);
}
}
}
}
}
}
...
@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
...
@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
}
}
}
}
}
}
auto
&
inputs
=
inputs_
[
"all"
];
auto
&
inputs
=
inputs_
[
kAll
];
inputs
.
reserve
(
input_set
.
size
());
inputs
.
reserve
(
input_set
.
size
());
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs
));
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs
));
auto
&
outputs
=
outputs_
[
"all"
];
auto
&
outputs
=
outputs_
[
kAll
];
outputs
.
reserve
(
output_set
.
size
());
outputs
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs
));
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs
));
//! TODO figure out how to generate temporary_index in Network.
std
::
vector
<
int
>
tmp_index
;
tmp_index
.
reserve
(
temp_output
.
size
());
int
output_len
=
static_cast
<
int
>
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
output_len
;
++
i
)
{
if
(
Contains
(
temp_output
,
outputs
[
i
]))
{
tmp_index
.
push_back
(
i
);
}
}
attrs_
[
"temporary_index"
]
=
tmp_index
;
}
}
std
::
string
NetOp
::
DebugString
()
const
{
std
::
string
NetOp
::
DebugString
()
const
{
...
@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
...
@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
std
::
vector
<
std
::
string
>
NetOp
::
OutputVars
(
bool
has_intermediate
)
const
{
if
(
has_intermediate
)
{
return
this
->
outputs_
.
at
(
kAll
);
}
auto
&
all
=
this
->
outputs_
.
at
(
kAll
);
std
::
vector
<
std
::
string
>
ret_val
;
for
(
auto
&
each
:
all
)
{
if
(
!
Contains
(
intermediate_outputs_
,
each
))
{
ret_val
.
push_back
(
each
);
}
}
return
ret_val
;
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/net_op.h
浏览文件 @
78af6e60
...
@@ -36,6 +36,8 @@ namespace operators {
...
@@ -36,6 +36,8 @@ namespace operators {
*/
*/
class
NetOp
:
public
framework
::
OperatorBase
{
class
NetOp
:
public
framework
::
OperatorBase
{
public:
public:
static
const
char
kAll
[];
/**
/**
* Infer all the operators' input and output variables' shapes, will be called
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
* before every mini-batch
...
@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
...
@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
bool
IsNetOp
()
const
override
;
bool
IsNetOp
()
const
override
;
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
private:
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
template
<
typename
T
,
typename
KeyType
>
template
<
typename
T
,
typename
KeyType
>
static
bool
Contains
(
T
container
,
KeyType
key
)
{
static
bool
Contains
(
T
container
,
KeyType
key
)
{
...
...
paddle/operators/net_op_test.cc
浏览文件 @
78af6e60
...
@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
...
@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net
->
inputs_
.
at
(
"__all__"
));
net
->
inputs_
.
at
(
NetOp
::
kAll
));
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
"__all__"
));
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
NetOp
::
kAll
));
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
.
at
(
"__all__"
)[
tmp_idx
[
0
]]);
Scope
scope
;
auto
final_outs
=
net
->
OutputVars
(
false
);
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
ASSERT_EQ
(
final_outs
.
size
(),
1UL
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
final_outs
[
0
],
"z"
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
platform
::
EnforceNotMet
);
}
}
TEST
(
NetOp
,
insert_op
)
{
TEST
(
NetOp
,
insert_op
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录