Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a31ff363
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a31ff363
编写于
10月 11, 2017
作者:
Y
Yang Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
prune pass dummy test
上级
b504a234
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
337 addition
and
0 deletion
+337
-0
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+3
-0
paddle/framework/framework.proto
paddle/framework/framework.proto
+1
-0
paddle/framework/prune.cc
paddle/framework/prune.cc
+107
-0
paddle/framework/prune.h
paddle/framework/prune.h
+26
-0
paddle/framework/prune_test.cc
paddle/framework/prune_test.cc
+200
-0
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
a31ff363
...
...
@@ -49,5 +49,8 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame
# cc_test(executor_test SRCS executor_test.cc DEPS executor)
#endif()
cc_library
(
prune SRCS prune.cc
)
cc_test
(
prune_test SRCS prune_test.cc DEPS prune recurrent_op device_context
)
cc_library
(
tensor_array SRCS tensor_array.cc DEPS lod_tensor
)
cc_test
(
tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place
)
paddle/framework/framework.proto
浏览文件 @
a31ff363
...
...
@@ -55,6 +55,7 @@ message OpDesc {
repeated
Var
inputs
=
1
;
repeated
Var
outputs
=
2
;
repeated
Attr
attrs
=
4
;
required
bool
is_target
=
5
[
default
=
false
];
};
// OpProto describes a C++ framework::OperatorBase derived class.
...
...
paddle/framework/prune.cc
0 → 100644
浏览文件 @
a31ff363
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include <algorithm>
#include <set>
#include <string>
#include <vector>
#include <glog/logging.h>
namespace
paddle
{
namespace
framework
{
const
std
::
string
kFeedOpType
=
"feed"
;
const
std
::
string
kFetchOpType
=
"fetch"
;
bool
HasDependentVar
(
const
OpDesc
&
op_desc
,
const
std
::
set
<
std
::
string
>&
dependent_vars
)
{
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
if
(
dependent_vars
.
count
(
argu
)
!=
0
)
{
return
true
;
}
}
}
return
false
;
}
void
Prune
(
const
ProgramDesc
&
input
,
ProgramDesc
&
output
,
int
id
)
{
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
auto
&
block
=
input
.
blocks
(
0
);
auto
&
ops
=
block
.
ops
();
bool
expect_feed
=
true
;
for
(
auto
&
op_desc
:
ops
)
{
PADDLE_ENFORCE
(
op_desc
.
type
()
!=
kFeedOpType
||
expect_feed
,
"All FeedOps are at the beginning of the ProgramDesc"
);
expect_feed
=
(
op_desc
.
type
()
==
kFeedOpType
);
}
bool
expect_fetch
=
true
;
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
PADDLE_ENFORCE
(
op_desc
.
type
()
!=
kFetchOpType
||
expect_fetch
,
"All FetchOps must at the end of the ProgramDesc"
);
expect_fetch
=
(
op_desc
.
type
()
==
kFetchOpType
);
}
std
::
set
<
std
::
string
>
dependent_vars
;
std
::
vector
<
bool
>
should_run
;
for
(
auto
op_iter
=
ops
.
rbegin
();
op_iter
!=
ops
.
rend
();
++
op_iter
)
{
auto
&
op_desc
=
*
op_iter
;
if
(
op_desc
.
is_target
()
||
HasDependentVar
(
op_desc
,
dependent_vars
))
{
// erase its output to the dependency graph
for
(
auto
&
var
:
op_desc
.
outputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
dependent_vars
.
erase
(
argu
);
}
}
// insert its input to the dependency graph
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
dependent_vars
.
insert
(
argu
);
}
}
should_run
.
push_back
(
true
);
}
else
{
should_run
.
push_back
(
false
);
}
}
// since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector
std
::
reverse
(
should_run
.
begin
(),
should_run
.
end
());
output
=
input
;
auto
*
op_field
=
output
.
mutable_blocks
(
id
)
->
mutable_ops
();
op_field
->
Clear
();
for
(
size_t
i
=
0
;
i
<
should_run
.
size
();
++
i
)
{
if
(
should_run
[
i
])
{
*
op_field
->
Add
()
=
input
.
blocks
(
id
).
ops
(
i
);
}
}
// return should_run;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/prune.h
0 → 100644
浏览文件 @
a31ff363
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
void
Prune
(
const
ProgramDesc
&
input
,
ProgramDesc
&
output
,
int
id
);
}
// namespace framework
}
// namespace paddle
paddle/framework/prune_test.cc
0 → 100644
浏览文件 @
a31ff363
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include <gtest/gtest.h>
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
framework
{
using
DeviceContext
=
platform
::
DeviceContext
;
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
);
AddInput
(
"b"
,
"Bias of Add"
);
AddOutput
(
"Out"
,
"Out of Add"
);
AddComment
(
"Add Op"
);
}
};
class
RowWiseAddGradMaker
:
public
SingleGradOpDescMaker
{
public:
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
OpDescBind
>
Apply
()
const
override
{
auto
grad_op
=
new
OpDescBind
();
grad_op
->
SetInput
(
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
GradVarName
(
"b"
),
InputGrad
(
"b"
));
grad_op
->
SetType
(
"rowwise_add_grad"
);
return
std
::
unique_ptr
<
OpDescBind
>
(
grad_op
);
}
};
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"A"
);
AddInput
(
"Y"
,
"B"
);
AddOutput
(
"Out"
,
"Out"
);
AddAttr
<
int
>
(
"x_num_col_dims"
,
""
).
SetDefault
(
1
).
EqualGreaterThan
(
1
);
AddAttr
<
int
>
(
"y_num_col_dims"
,
""
).
SetDefault
(
1
).
EqualGreaterThan
(
1
);
AddComment
(
"Mul"
);
}
};
class
SigmoidOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X"
);
AddOutput
(
"Out"
,
"Y"
);
AddComment
(
"Sigmoid"
);
}
};
class
NoGradOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
NoGradOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X input"
);
AddOutput
(
"Out"
,
"Y output"
);
AddComment
(
"NoGradOp, same input output. no Grad"
);
}
};
class
ManyOutputOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
ManyOutputOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"x"
,
"x"
);
AddOutput
(
"y"
,
"y"
);
AddOutput
(
"z"
,
"z"
);
AddComment
(
""
);
}
};
class
FillZeroOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
FillZeroOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"x"
);
AddOutput
(
"Y"
,
"out"
);
AddComment
(
""
);
}
};
class
SumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
SumOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"the input tensors of sum operator."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"the output tensor of sum operator."
);
AddComment
(
""
);
}
};
class
MultInOutOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MultInOutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"x"
);
AddInput
(
"H"
,
"h"
);
AddOutput
(
"Y"
,
"y"
);
AddOutput
(
"Z"
,
"z"
);
AddComment
(
""
);
}
};
}
// namespace framework
}
// namespace paddle
namespace
f
=
paddle
::
framework
;
namespace
ops
=
paddle
::
operators
;
using
EnforceNotMet
=
paddle
::
platform
::
EnforceNotMet
;
REGISTER_OPERATOR
(
rowwise_add
,
f
::
NOP
,
f
::
RowWiseAddOpMaker
,
f
::
RowWiseAddGradMaker
);
REGISTER_OPERATOR
(
rowwise_add_grad
,
f
::
NOP
);
REGISTER_OP
(
mul
,
f
::
NOP
,
f
::
MulOpMaker
,
mul_grad
,
f
::
NOP
);
REGISTER_OP
(
sigmoid
,
f
::
NOP
,
f
::
SigmoidOpMaker
,
sigmoid_grad
,
f
::
NOP
);
REGISTER_OP_WITHOUT_GRADIENT
(
nograd
,
f
::
NOP
,
f
::
NoGradOpMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
fill_zeros_like
,
f
::
NOP
,
f
::
FillZeroOpMaker
);
REGISTER_OP
(
sum
,
f
::
NOP
,
f
::
SumOpMaker
,
sum_grad
,
f
::
NOP
);
REGISTER_OP
(
many_output_op
,
f
::
NOP
,
f
::
ManyOutputOpMaker
,
many_output_op_grad
,
f
::
NOP
);
REGISTER_OP
(
mult_in_out
,
f
::
NOP
,
f
::
MultInOutOpMaker
,
mult_in_out_grad
,
f
::
NOP
);
void
AddOp
(
const
std
::
string
&
type
,
const
f
::
VariableNameMap
&
inputs
,
const
f
::
VariableNameMap
&
outputs
,
f
::
AttributeMap
attrs
,
paddle
::
framework
::
BlockDescBind
*
block
)
{
// insert output
for
(
auto
kv
:
outputs
)
{
for
(
auto
v
:
kv
.
second
)
{
auto
var
=
block
->
NewVar
(
v
);
var
->
SetDataType
(
paddle
::
framework
::
DataType
::
FP32
);
}
}
// insert op
auto
op
=
block
->
AppendOp
();
op
->
SetType
(
type
);
for
(
auto
&
kv
:
inputs
)
{
op
->
SetInput
(
kv
.
first
,
kv
.
second
);
}
for
(
auto
&
kv
:
outputs
)
{
op
->
SetOutput
(
kv
.
first
,
kv
.
second
);
}
op
->
SetAttrMap
(
attrs
);
}
f
::
ProgramDesc
*
GetNewProgramDesc
()
{
auto
*
program_desc
=
new
f
::
ProgramDesc
();
auto
*
root_block
=
program_desc
->
add_blocks
();
root_block
->
set_idx
(
0
);
root_block
->
set_parent_idx
(
-
1
);
return
program_desc
;
}
TEST
(
Prune
,
one_operator
)
{
f
::
ProgramDesc
*
program_desc
=
GetNewProgramDesc
();
f
::
ProgramDescBind
&
program
=
f
::
ProgramDescBind
::
Instance
(
program_desc
);
f
::
BlockDescBind
*
block
=
program
.
Block
(
0
);
AddOp
(
"mul"
,
{{
"X"
,
{
"a"
}},
{
"Y"
,
{
"w1"
}}},
{{
"Out"
,
{
"b"
}}},
{},
block
);
f
::
ProgramDesc
*
pdesc
=
program
.
Proto
();
f
::
ProgramDesc
pruned
;
Prune
(
*
pdesc
,
pruned
,
0
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
0
);
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
0
)
->
set_is_target
(
true
);
Prune
(
*
pdesc
,
pruned
,
0
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
1
);
}
TEST
(
Prune
,
simple_optimize
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录