Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8b86624b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b86624b
编写于
6月 23, 2017
作者:
Z
Zhaolong Xing
提交者:
GitHub
6月 23, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2354 from NHZlX/improve_pruning
Improve pruning module
上级
bf7a278e
1d6b8595
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
107 addition
and
91 deletion
+107
-91
paddle/parameter/ParameterUpdaterHook.cpp
paddle/parameter/ParameterUpdaterHook.cpp
+56
-84
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+3
-1
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+5
-5
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+41
-1
python/paddle/v2/attr.py
python/paddle/v2/attr.py
+2
-0
未找到文件。
paddle/parameter/ParameterUpdaterHook.cpp
浏览文件 @
8b86624b
...
...
@@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic>
#include <fstream>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
...
...
@@ -29,106 +31,76 @@ namespace paddle {
/**
* The static pruning hook
*
*
Static means user load a mask map before training started. This map will
*
define which link/weight between neural is disabled
.
*
Static means user specify a sparsity_ratio before training started, and the
*
network will prune the parameters based on the sparsity_ratio. More details
*
can be found https://arxiv.org/pdf/1506.02626.pdf
.
*/
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
public:
/**
* The Mask Map Header.
* The map file started with this header.
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct
StaticMaskHeader
{
uint32_t
version
;
size_t
size
;
}
__attribute__
((
__packed__
));
explicit
StaticPruningHook
(
const
std
::
string
&
mask_filename
)
:
initCount_
(
0
)
{
bool
ok
=
this
->
loadMaskFile
(
mask_filename
);
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
"Fail to load mask file "
<<
mask_filename
<<
" in current directory, searching in init_model_path"
;
std
::
string
combineMaskFilename
=
path
::
join
(
FLAGS_init_model_path
,
mask_filename
);
CHECK
(
this
->
loadMaskFile
(
combineMaskFilename
))
<<
"Cannot load "
<<
mask_filename
<<
" in ./"
<<
mask_filename
<<
" and "
<<
combineMaskFilename
;
}
VLOG
(
3
)
<<
mask_filename
<<
" mask size = "
<<
this
->
mask_
.
size
();
explicit
StaticPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
:
initCount_
(
0
)
{
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
}
void
update
(
Parameter
*
para
)
{
static
bool
sortPairAscend
(
const
std
::
pair
<
real
,
size_t
>
&
pair1
,
const
std
::
pair
<
real
,
size_t
>
&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
}
void
update
(
Parameter
*
para
)
{
updateThreadChecker_
.
check
();
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
if
(
vec
)
{
vec
->
dotMul
(
*
maskVec_
);
}
}
void
init
(
Parameter
*
para
)
{
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
void
generateMask
(
Parameter
*
para
)
{
VectorPtr
maskTemp
=
Vector
::
create
(
para
->
getSize
(),
false
);
maskTemp
->
zeroMem
();
real
*
maskTempData
=
maskTemp
->
getData
();
size_t
nonZeroNum
=
para
->
getSize
()
*
(
1
-
sparsityRatio_
);
auto
maskVec
=
Vector
::
create
(
this
->
mask_
.
size
(),
false
);
{
// Initialize maskVec with float mask vector
real
*
dataPtr
=
maskVec
->
getData
();
size_t
i
=
0
;
for
(
bool
m
:
mask_
)
{
dataPtr
[
i
++
]
=
m
?
1.0
:
0.0
;
}
}
VectorPtr
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
VectorPtr
paraCpuCopy
=
Vector
::
create
(
para
->
getSize
(),
false
);
paraCpuCopy
->
copyFrom
(
*
paraVec
);
std
::
vector
<
std
::
pair
<
real
,
size_t
>>
param
;
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
param
.
push_back
(
std
::
make_pair
(
fabs
(
paraCpuCopy
->
getData
()[
i
]),
i
));
std
::
partial_sort
(
param
.
begin
(),
param
.
begin
()
+
nonZeroNum
,
param
.
end
(),
sortPairAscend
);
for
(
size_t
i
=
0
;
i
<
nonZeroNum
;
i
++
)
maskTempData
[
param
[
i
].
second
]
=
1.0
;
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if
(
para
->
useGpu
())
{
maskVec_
=
Vector
::
create
(
this
->
mask_
.
s
ize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
mask
Vec
);
maskVec_
=
Vector
::
create
(
para
->
getS
ize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
mask
Temp
);
}
else
{
maskVec_
=
mask
Vec
;
maskVec_
=
mask
Temp
;
}
auto
&
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
vec
->
dotMul
(
*
maskVec_
);
}
private:
bool
loadMaskFile
(
const
std
::
string
&
mask_filename
)
{
std
::
ifstream
fin
;
fin
.
open
(
mask_filename
);
if
(
fin
.
is_open
())
{
StaticMaskHeader
header
;
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
header
),
sizeof
(
StaticMaskHeader
));
CHECK_EQ
(
header
.
version
,
0UL
);
mask_
.
resize
(
header
.
size
);
uint8_t
buf
;
for
(
size_t
i
=
0
;
i
<
header
.
size
;
++
i
,
buf
<<=
1
)
{
if
(
i
%
8
==
0
)
{
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
buf
),
sizeof
(
uint8_t
));
}
mask_
[
i
]
=
buf
&
0x80
;
}
fin
.
close
();
return
true
;
}
else
{
return
false
;
}
void
init
(
Parameter
*
para
)
{
generateMask
(
para
);
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
auto
&
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
paraVec
->
dotMul
(
*
maskVec_
);
}
private:
SameThreadChecker
updateThreadChecker_
;
std
::
atomic
<
size_t
>
initCount_
;
VectorPtr
maskVec_
;
std
::
vector
<
bool
>
mask
_
;
real
sparsityRatio
_
;
};
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
...
...
@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/
class
StringIntPairHasher
{
public:
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
return
intHasher_
(
strHasher_
(
k
.
first
)
+
k
.
second
);
}
...
...
@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/**
* ParameterUpdaterHook actually factory method.
*/
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
if
(
type
==
"pruning"
)
{
if
(
config
.
has_purning_mask_filename
())
{
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
}
return
new
StaticPruningHook
(
config
);
}
LOG
(
FATAL
)
<<
"Unknown Hook type: "
<<
type
;
return
nullptr
;
}
std
::
shared_ptr
<
IParameterUpdaterHook
>
IParameterUpdaterHook
::
create
(
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
std
::
pair
<
std
::
string
,
int
>
key
=
{
paramConfig
.
name
(),
idx
};
return
g_hookCache_
.
get
(
key
,
[
&
]
{
return
createImpl
(
paramConfig
.
update_hooks
(
idx
));
});
...
...
proto/ParameterConfig.proto
浏览文件 @
8b86624b
...
...
@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
}
message
ParameterUpdaterHookConfig
{
// hook type such as 'pruning'
required
string
type
=
1
;
optional
string
purning_mask_filename
=
2
;
// this represents the ratio of zero element to be set by the Parameter
optional
double
sparsity_ratio
=
2
[
default
=
0.6
];
}
message
ParameterConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
8b86624b
...
...
@@ -3139,11 +3139,11 @@ def Layer(name, type, **xargs):
@
config_func
def
ParameterHook
(
type
,
**
kwargs
):
if
type
==
'pruning'
:
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
assert
mask_filename
is
not
None
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
hook
.
purning_mask_filename
=
mask_filename
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
if
sparsity_ratio
is
not
None
:
hook
.
sparsity_ratio
=
sparsity_ratio
return
hook
else
:
return
None
...
...
@@ -3251,13 +3251,13 @@ def Parameter(name,
if
update_hooks
is
not
None
:
if
hasattr
(
update_hooks
,
'__call__'
):
update_hooks
=
update_hooks
(
para
.
name
)
update_hooks
=
update_hooks
()
if
isinstance
(
update_hooks
,
list
):
for
hook
in
update_hooks
:
para
.
update_hooks
.
extend
([
hook
])
else
:
para
.
update_hooks
.
extend
(
update_hooks
)
para
.
update_hooks
.
extend
(
[
update_hooks
]
)
g_parameter_map
[
name
]
=
para
if
initializer
is
not
None
:
...
...
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
8b86624b
...
...
@@ -14,7 +14,8 @@
from
paddle.trainer.config_parser
import
*
__all__
=
[
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
'HookAttr'
,
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
]
...
...
@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
return
False
class
HookAttribute
(
object
):
"""
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
def
__init__
(
self
,
type
,
sparsity_ratio
=
None
):
self
.
type
=
type
self
.
sparsity_ratio
=
sparsity_ratio
if
self
.
sparsity_ratio
is
not
None
:
assert
is_compatible_with
(
self
.
sparsity_ratio
,
float
),
'sparisity_ratio must be float type'
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparsity_ratio must be a float between [0, 1] '
def
__call__
(
self
):
return
ParameterHook
(
self
.
type
,
sparsity_ratio
=
self
.
sparsity_ratio
)
class
ParameterAttribute
(
object
):
"""
Parameter Attributes object. To fine-tuning network training process, user
...
...
@@ -114,6 +149,7 @@ class ParameterAttribute(object):
momentum
=
None
,
gradient_clipping_threshold
=
None
,
sparse_update
=
False
,
update_hooks
=
None
,
initializer
=
None
):
self
.
attr
=
{}
...
...
@@ -169,6 +205,9 @@ class ParameterAttribute(object):
if
initializer
is
not
None
:
self
.
attr
[
'initializer'
]
=
initializer
if
update_hooks
:
self
.
attr
[
'update_hooks'
]
=
update_hooks
def
set_default_parameter_name
(
self
,
name
):
"""
Set default parameter name. If parameter not set, then will use default
...
...
@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
return
attr
.
attr
HookAttr
=
HookAttribute
ParamAttr
=
ParameterAttribute
ExtraAttr
=
ExtraLayerAttribute
python/paddle/v2/attr.py
浏览文件 @
8b86624b
...
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__
=
[
"Param"
,
"Extra"
,
"Hook"
,
]
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Hook
=
paddle
.
trainer_config_helpers
.
attrs
.
HookAttribute
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录