Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
78af6e60
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
78af6e60
编写于
8月 09, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add OutputVars method to get all outputs or outputs without intermediate
上级
b368c6ca
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
71 addition
and
55 deletion
+71
-55
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+2
-23
paddle/framework/operator.cc
paddle/framework/operator.cc
+10
-2
paddle/framework/operator.h
paddle/framework/operator.h
+31
-0
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+19
-16
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-0
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+5
-14
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
78af6e60
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
...
...
@@ -127,7 +128,7 @@ class OpRegistry {
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
p
rotos
()[
op_type
];
OpProto
&
op_proto
=
OpP
rotos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
...
...
@@ -135,17 +136,6 @@ class OpRegistry {
op_proto
.
IsInitialized
(),
"Fail to initialize %s's OpProto, because %s is not initialized"
,
op_type
,
op_proto
.
InitializationErrorString
());
VarIndexMaps
()[
op_type
].
reset
(
new
VarIndexMap
());
auto
&
varmap
=
*
VarIndexMaps
()[
op_type
];
int
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
inputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
outputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
}
template
<
typename
GradOpType
>
...
...
@@ -212,22 +202,11 @@ class OpRegistry {
return
grad_op
;
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
return
grad_ops_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
return
op_creators_
;
...
...
paddle/framework/operator.cc
浏览文件 @
78af6e60
...
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif
static
std
::
unordered_map
<
std
::
string
,
OpProto
>*
g_op_protos
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
()
{
if
(
g_op_protos
==
nullptr
)
{
g_op_protos
=
new
std
::
unordered_map
<
std
::
string
,
OpProto
>
();
}
return
*
g_op_protos
;
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
inputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
...
...
paddle/framework/operator.h
浏览文件 @
78af6e60
...
...
@@ -50,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
return
var_name
+
kGradVarSuffix
;
}
extern
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
();
class
OperatorBase
;
class
InferShapeContext
;
class
ExecutionContext
;
...
...
@@ -103,6 +105,35 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy.
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
// push all outputs into ret_val
for
(
auto
&
o
:
outputs_
)
{
ret_val
.
reserve
(
ret_val
.
size
()
+
o
.
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
o
.
second
.
begin
(),
o
.
second
.
end
());
}
return
ret_val
;
}
auto
it
=
OpProtos
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
OpProtos
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
if
(
out
!=
outputs_
.
end
())
{
ret_val
.
reserve
(
ret_val
.
size
()
+
out
->
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
out
->
second
.
begin
(),
out
->
second
.
end
());
}
}
return
ret_val
;
}
public:
std
::
string
type_
;
// NOTE: in case of OpGrad, inputs_ contains:
...
...
paddle/operators/net_op.cc
浏览文件 @
78af6e60
...
...
@@ -21,19 +21,20 @@
namespace
paddle
{
namespace
operators
{
const
char
NetOp
::
kAll
[]
=
"all"
;
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
if
(
!
calc
)
return
;
std
::
set
<
std
::
string
>
input_set
;
std
::
set
<
std
::
string
>
output_set
;
std
::
set
<
std
::
string
>
temp_output
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
for
(
auto
&
var_name
:
ipt
.
second
)
{
if
(
!
Contains
(
output_set
,
var_name
))
{
// Not other op's output
input_set
.
insert
(
var_name
);
}
else
{
temp_output
.
insert
(
var_name
);
intermediate_outputs_
.
insert
(
var_name
);
}
}
}
...
...
@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
}
}
}
auto
&
inputs
=
inputs_
[
"all"
];
auto
&
inputs
=
inputs_
[
kAll
];
inputs
.
reserve
(
input_set
.
size
());
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs
));
auto
&
outputs
=
outputs_
[
"all"
];
auto
&
outputs
=
outputs_
[
kAll
];
outputs
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs
));
//! TODO figure out how to generate temporary_index in Network.
std
::
vector
<
int
>
tmp_index
;
tmp_index
.
reserve
(
temp_output
.
size
());
int
output_len
=
static_cast
<
int
>
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
output_len
;
++
i
)
{
if
(
Contains
(
temp_output
,
outputs
[
i
]))
{
tmp_index
.
push_back
(
i
);
}
}
attrs_
[
"temporary_index"
]
=
tmp_index
;
}
std
::
string
NetOp
::
DebugString
()
const
{
...
...
@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
std
::
vector
<
std
::
string
>
NetOp
::
OutputVars
(
bool
has_intermediate
)
const
{
if
(
has_intermediate
)
{
return
this
->
outputs_
.
at
(
kAll
);
}
auto
&
all
=
this
->
outputs_
.
at
(
kAll
);
std
::
vector
<
std
::
string
>
ret_val
;
for
(
auto
&
each
:
all
)
{
if
(
!
Contains
(
intermediate_outputs_
,
each
))
{
ret_val
.
push_back
(
each
);
}
}
return
ret_val
;
}
}
// namespace operators
}
// namespace paddle
paddle/operators/net_op.h
浏览文件 @
78af6e60
...
...
@@ -36,6 +36,8 @@ namespace operators {
*/
class
NetOp
:
public
framework
::
OperatorBase
{
public:
static
const
char
kAll
[];
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
...
...
@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
std
::
string
DebugString
()
const
override
;
bool
IsNetOp
()
const
override
;
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
template
<
typename
T
,
typename
KeyType
>
static
bool
Contains
(
T
container
,
KeyType
key
)
{
...
...
paddle/operators/net_op_test.cc
浏览文件 @
78af6e60
...
...
@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
net
->
CompleteAddOp
();
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net
->
inputs_
.
at
(
"__all__"
));
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
"__all__"
));
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
.
at
(
"__all__"
)[
tmp_idx
[
0
]]);
net
->
inputs_
.
at
(
NetOp
::
kAll
));
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
NetOp
::
kAll
));
Scope
scope
;
platform
::
CPUDeviceContext
dev_ctx
;
auto
final_outs
=
net
->
OutputVars
(
false
);
net
->
InferShape
(
scope
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
platform
::
EnforceNotMet
);
ASSERT_EQ
(
final_outs
.
size
(),
1UL
);
ASSERT_EQ
(
final_outs
[
0
],
"z"
);
}
TEST
(
NetOp
,
insert_op
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录