Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5413af8d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5413af8d
编写于
6月 02, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
imporve pruning module
上级
da83d286
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
144 addition
and
11 deletion
+144
-11
paddle/parameter/ParameterUpdaterHook.cpp
paddle/parameter/ParameterUpdaterHook.cpp
+85
-5
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+2
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+11
-4
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+44
-2
python/paddle/v2/attr.py
python/paddle/v2/attr.py
+2
-0
未找到文件。
paddle/parameter/ParameterUpdaterHook.cpp
浏览文件 @
5413af8d
...
@@ -25,6 +25,9 @@ limitations under the License. */
...
@@ -25,6 +25,9 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Util.h"
using
std
::
vector
;
using
std
::
pair
;
namespace
paddle
{
namespace
paddle
{
/**
/**
...
@@ -131,6 +134,73 @@ private:
...
@@ -131,6 +134,73 @@ private:
std
::
vector
<
bool
>
mask_
;
std
::
vector
<
bool
>
mask_
;
};
};
class
DynamicPruningHook
:
public
IParameterUpdaterHook
{
public:
explicit
DynamicPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
:
initCount_
(
0
)
{
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
}
static
bool
sortPairAscend
(
const
pair
<
real
,
size_t
>&
pair1
,
const
pair
<
real
,
size_t
>&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
}
void
update
(
Parameter
*
para
)
{
updateThreadChecker_
.
check
();
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
if
(
vec
)
{
vec
->
dotMul
(
*
maskVec_
);
}
}
void
generateMask
(
Parameter
*
para
)
{
VectorPtr
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
maskTemp_
=
Vector
::
create
(
para
->
getSize
(),
false
);
maskTemp_
->
zeroMem
();
real
*
dataPtr
=
maskTemp_
->
getData
();
VectorPtr
vecCpu
=
Vector
::
create
(
para
->
getSize
(),
false
);
vecCpu
->
copyFrom
(
*
vec
);
vector
<
pair
<
real
,
size_t
>>
param
;
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
param
.
push_back
(
std
::
make_pair
(
fabs
(
vecCpu
->
getData
()[
i
]),
i
));
std
::
sort
(
param
.
begin
(),
param
.
end
(),
sortPairAscend
);
for
(
size_t
i
=
0
;
i
<
para
->
getSize
()
*
sparsityRatio_
;
i
++
)
dataPtr
[
param
[
i
].
second
]
=
1.0
;
}
void
init
(
Parameter
*
para
)
{
generateMask
(
para
);
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the DynamicPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if
(
para
->
useGpu
())
{
maskVec_
=
Vector
::
create
(
para
->
getSize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
maskTemp_
);
}
else
{
maskVec_
=
maskTemp_
;
}
auto
&
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
vec
->
dotMul
(
*
maskVec_
);
}
private:
SameThreadChecker
updateThreadChecker_
;
std
::
atomic
<
size_t
>
initCount_
;
VectorPtr
maskVec_
;
VectorPtr
maskTemp_
;
real
sparsityRatio_
;
};
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
IParameterUpdaterHook
::~
IParameterUpdaterHook
()
{}
IParameterUpdaterHook
::~
IParameterUpdaterHook
()
{}
...
@@ -156,8 +226,7 @@ private:
...
@@ -156,8 +226,7 @@ private:
static
WeakKVCache
<
std
::
pair
<
std
::
string
,
int
>
,
static
WeakKVCache
<
std
::
pair
<
std
::
string
,
int
>
,
IParameterUpdaterHook
,
IParameterUpdaterHook
,
StringIntPairHasher
>
StringIntPairHasher
>
g_hookCache_
;
g_hookCache_
;
/**
/**
* ParameterUpdaterHook actually factory method.
* ParameterUpdaterHook actually factory method.
...
@@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>,
...
@@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>,
static
IParameterUpdaterHook
*
createImpl
(
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
auto
&
type
=
config
.
type
();
if
(
type
==
"pruning"
)
{
if
(
type
==
"pruning
_static
"
)
{
if
(
config
.
has_purning_mask_filename
())
{
if
(
config
.
has_purning_mask_filename
())
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
else
LOG
(
FATAL
)
<<
"There must be mask_filename parameter for "
<<
type
<<
" Hook"
;
}
else
if
(
type
==
"pruning"
)
{
if
(
config
.
has_sparsity_ratio
())
return
new
DynamicPruningHook
(
config
);
else
LOG
(
FATAL
)
<<
"There must be sparsity_ratio parameter for "
<<
type
<<
" Hook"
;
}
}
}
LOG
(
FATAL
)
<<
"Unknown Hook type: "
<<
type
;
return
nullptr
;
return
nullptr
;
}
}
...
...
proto/ParameterConfig.proto
浏览文件 @
5413af8d
...
@@ -26,7 +26,9 @@ enum ParameterInitStrategy {
...
@@ -26,7 +26,9 @@ enum ParameterInitStrategy {
message
ParameterUpdaterHookConfig
{
message
ParameterUpdaterHookConfig
{
required
string
type
=
1
;
required
string
type
=
1
;
//hook type such as 'pruning', 'pruning_static'
optional
string
purning_mask_filename
=
2
;
optional
string
purning_mask_filename
=
2
;
optional
double
sparsity_ratio
=
3
;
}
}
message
ParameterConfig
{
message
ParameterConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
5413af8d
...
@@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs):
...
@@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs):
@
config_func
@
config_func
def
ParameterHook
(
type
,
**
kwargs
):
def
ParameterHook
(
type
,
**
kwargs
):
if
type
==
'pruning'
:
if
type
==
'pruning_static'
:
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
assert
mask_filename
is
not
None
assert
mask_filename
is
not
None
hook
.
pruning_mask_filename
=
mask_filename
return
hook
elif
type
==
'pruning'
:
hook
=
ParameterUpdaterHookConfig
()
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
hook
.
type
=
type
hook
.
purning_mask_filename
=
mask_filename
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
assert
sparsity_ratio
is
not
None
hook
.
sparsity_ratio
=
sparsity_ratio
return
hook
return
hook
else
:
else
:
return
None
return
None
...
@@ -3283,13 +3290,13 @@ def Parameter(name,
...
@@ -3283,13 +3290,13 @@ def Parameter(name,
if
update_hooks
is
not
None
:
if
update_hooks
is
not
None
:
if
hasattr
(
update_hooks
,
'__call__'
):
if
hasattr
(
update_hooks
,
'__call__'
):
update_hooks
=
update_hooks
(
para
.
name
)
update_hooks
=
update_hooks
()
if
isinstance
(
update_hooks
,
list
):
if
isinstance
(
update_hooks
,
list
):
for
hook
in
update_hooks
:
for
hook
in
update_hooks
:
para
.
update_hooks
.
extend
([
hook
])
para
.
update_hooks
.
extend
([
hook
])
else
:
else
:
para
.
update_hooks
.
extend
(
update_hooks
)
para
.
update_hooks
.
extend
(
[
update_hooks
]
)
g_parameter_map
[
name
]
=
para
g_parameter_map
[
name
]
=
para
...
...
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
5413af8d
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
from
paddle.trainer.config_parser
import
*
from
paddle.trainer.config_parser
import
*
__all__
=
[
__all__
=
[
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
'HookAttr'
,
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
]
]
...
@@ -55,6 +56,42 @@ def is_compatible_with(x, Type):
...
@@ -55,6 +56,42 @@ def is_compatible_with(x, Type):
return
False
return
False
class
HookAttribute
(
object
):
"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation. Such as pruning operation, It will cut off
redundant parameters in the network before training. More detail can see
here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning', 'pruning_static'
:type type: string
:param mask_file: Must be specified if hook type is 'pruning_static',
the network reads the mask from the file to determine which parameters should be cut off
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1
"""
def
__init__
(
self
,
type
,
mask_filename
=
None
,
sparsity_ratio
=
None
):
self
.
type
=
type
self
.
mask_filename
=
mask_filename
self
.
sparsity_ratio
=
sparsity_ratio
assert
is_compatible_with
(
self
.
sparsity_ratio
,
float
),
'sparisity_ratio must be float type'
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparisity must be a flaot between [0, 1] '
def
__call__
(
self
):
return
ParameterHook
(
self
.
type
,
mask_filename
=
self
.
mask_filename
,
sparsity_ratio
=
self
.
sparsity_ratio
)
class
ParameterAttribute
(
object
):
class
ParameterAttribute
(
object
):
"""
"""
Parameter Attributes object. To fine-tuning network training process, user
Parameter Attributes object. To fine-tuning network training process, user
...
@@ -109,7 +146,8 @@ class ParameterAttribute(object):
...
@@ -109,7 +146,8 @@ class ParameterAttribute(object):
learning_rate
=
None
,
learning_rate
=
None
,
momentum
=
None
,
momentum
=
None
,
gradient_clipping_threshold
=
None
,
gradient_clipping_threshold
=
None
,
sparse_update
=
False
):
sparse_update
=
False
,
update_hooks
=
None
):
self
.
attr
=
{}
self
.
attr
=
{}
if
is_static
:
if
is_static
:
...
@@ -162,6 +200,9 @@ class ParameterAttribute(object):
...
@@ -162,6 +200,9 @@ class ParameterAttribute(object):
self
.
attr
[
'gradient_clipping_threshold'
]
=
\
self
.
attr
[
'gradient_clipping_threshold'
]
=
\
gradient_clipping_threshold
gradient_clipping_threshold
if
update_hooks
:
self
.
attr
[
'update_hooks'
]
=
update_hooks
def
set_default_parameter_name
(
self
,
name
):
def
set_default_parameter_name
(
self
,
name
):
"""
"""
Set default parameter name. If parameter not set, then will use default
Set default parameter name. If parameter not set, then will use default
...
@@ -237,5 +278,6 @@ class ExtraLayerAttribute(object):
...
@@ -237,5 +278,6 @@ class ExtraLayerAttribute(object):
return
attr
.
attr
return
attr
.
attr
HookAttr
=
HookAttribute
ParamAttr
=
ParameterAttribute
ParamAttr
=
ParameterAttribute
ExtraAttr
=
ExtraLayerAttribute
ExtraAttr
=
ExtraLayerAttribute
python/paddle/v2/attr.py
浏览文件 @
5413af8d
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__
=
[
__all__
=
[
"Param"
,
"Param"
,
"Extra"
,
"Extra"
,
"Hook"
,
]
]
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Hook
=
paddle
.
trainer_config_helpers
.
attrs
.
HookAttribute
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录