Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8b86624b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b86624b
编写于
6月 23, 2017
作者:
Z
Zhaolong Xing
提交者:
GitHub
6月 23, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2354 from NHZlX/improve_pruning
Improve pruning module
上级
bf7a278e
1d6b8595
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
107 addition
and
91 deletion
+107
-91
paddle/parameter/ParameterUpdaterHook.cpp
paddle/parameter/ParameterUpdaterHook.cpp
+56
-84
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+3
-1
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+5
-5
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+41
-1
python/paddle/v2/attr.py
python/paddle/v2/attr.py
+2
-0
未找到文件。
paddle/parameter/ParameterUpdaterHook.cpp
浏览文件 @
8b86624b
...
...
@@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic>
#include <fstream>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
...
...
@@ -29,106 +31,76 @@ namespace paddle {
/**
* The static pruning hook
*
*
Static means user load a mask map before training started. This map will
*
define which link/weight between neural is disabled
.
*
Static means user specify a sparsity_ratio before training started, and the
*
network will prune the parameters based on the sparsity_ratio. More details
*
can be found https://arxiv.org/pdf/1506.02626.pdf
.
*/
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
public:
/**
* The Mask Map Header.
* The map file started with this header.
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct
StaticMaskHeader
{
uint32_t
version
;
size_t
size
;
}
__attribute__
((
__packed__
));
explicit
StaticPruningHook
(
const
std
::
string
&
mask_filename
)
:
initCount_
(
0
)
{
bool
ok
=
this
->
loadMaskFile
(
mask_filename
);
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
"Fail to load mask file "
<<
mask_filename
<<
" in current directory, searching in init_model_path"
;
std
::
string
combineMaskFilename
=
path
::
join
(
FLAGS_init_model_path
,
mask_filename
);
CHECK
(
this
->
loadMaskFile
(
combineMaskFilename
))
<<
"Cannot load "
<<
mask_filename
<<
" in ./"
<<
mask_filename
<<
" and "
<<
combineMaskFilename
;
explicit
StaticPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
:
initCount_
(
0
)
{
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
}
VLOG
(
3
)
<<
mask_filename
<<
" mask size = "
<<
this
->
mask_
.
size
();
static
bool
sortPairAscend
(
const
std
::
pair
<
real
,
size_t
>
&
pair1
,
const
std
::
pair
<
real
,
size_t
>
&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
}
void
update
(
Parameter
*
para
)
{
void
update
(
Parameter
*
para
)
{
updateThreadChecker_
.
check
();
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
if
(
vec
)
{
vec
->
dotMul
(
*
maskVec_
);
}
}
void
init
(
Parameter
*
para
)
{
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
void
generateMask
(
Parameter
*
para
)
{
VectorPtr
maskTemp
=
Vector
::
create
(
para
->
getSize
(),
false
);
maskTemp
->
zeroMem
();
real
*
maskTempData
=
maskTemp
->
getData
();
size_t
nonZeroNum
=
para
->
getSize
()
*
(
1
-
sparsityRatio_
);
auto
maskVec
=
Vector
::
create
(
this
->
mask_
.
size
(),
false
);
{
// Initialize maskVec with float mask vector
real
*
dataPtr
=
maskVec
->
getData
();
size_t
i
=
0
;
for
(
bool
m
:
mask_
)
{
dataPtr
[
i
++
]
=
m
?
1.0
:
0.0
;
}
}
VectorPtr
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
VectorPtr
paraCpuCopy
=
Vector
::
create
(
para
->
getSize
(),
false
);
paraCpuCopy
->
copyFrom
(
*
paraVec
);
std
::
vector
<
std
::
pair
<
real
,
size_t
>>
param
;
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
param
.
push_back
(
std
::
make_pair
(
fabs
(
paraCpuCopy
->
getData
()[
i
]),
i
));
std
::
partial_sort
(
param
.
begin
(),
param
.
begin
()
+
nonZeroNum
,
param
.
end
(),
sortPairAscend
);
for
(
size_t
i
=
0
;
i
<
nonZeroNum
;
i
++
)
maskTempData
[
param
[
i
].
second
]
=
1.0
;
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if
(
para
->
useGpu
())
{
maskVec_
=
Vector
::
create
(
this
->
mask_
.
s
ize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
mask
Vec
);
maskVec_
=
Vector
::
create
(
para
->
getS
ize
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
mask
Temp
);
}
else
{
maskVec_
=
mask
Vec
;
maskVec_
=
mask
Temp
;
}
auto
&
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
vec
->
dotMul
(
*
maskVec_
);
}
private:
bool
loadMaskFile
(
const
std
::
string
&
mask_filename
)
{
std
::
ifstream
fin
;
fin
.
open
(
mask_filename
);
if
(
fin
.
is_open
())
{
StaticMaskHeader
header
;
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
header
),
sizeof
(
StaticMaskHeader
));
CHECK_EQ
(
header
.
version
,
0UL
);
mask_
.
resize
(
header
.
size
);
uint8_t
buf
;
for
(
size_t
i
=
0
;
i
<
header
.
size
;
++
i
,
buf
<<=
1
)
{
if
(
i
%
8
==
0
)
{
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
buf
),
sizeof
(
uint8_t
));
}
mask_
[
i
]
=
buf
&
0x80
;
}
fin
.
close
();
return
true
;
}
else
{
return
false
;
}
void
init
(
Parameter
*
para
)
{
generateMask
(
para
);
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
auto
&
paraVec
=
para
->
getBuf
(
PARAMETER_VALUE
);
paraVec
->
dotMul
(
*
maskVec_
);
}
private:
SameThreadChecker
updateThreadChecker_
;
std
::
atomic
<
size_t
>
initCount_
;
VectorPtr
maskVec_
;
std
::
vector
<
bool
>
mask
_
;
real
sparsityRatio
_
;
};
IParameterUpdaterHook
::
IParameterUpdaterHook
()
{}
...
...
@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/
class
StringIntPairHasher
{
public:
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
size_t
operator
()(
const
std
::
pair
<
std
::
string
,
int
>
&
k
)
const
{
return
intHasher_
(
strHasher_
(
k
.
first
)
+
k
.
second
);
}
...
...
@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/**
* ParameterUpdaterHook actually factory method.
*/
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
if
(
type
==
"pruning"
)
{
if
(
config
.
has_purning_mask_filename
())
{
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
}
return
new
StaticPruningHook
(
config
);
}
LOG
(
FATAL
)
<<
"Unknown Hook type: "
<<
type
;
return
nullptr
;
}
std
::
shared_ptr
<
IParameterUpdaterHook
>
IParameterUpdaterHook
::
create
(
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
const
ParameterConfig
&
paramConfig
,
int
idx
)
{
std
::
pair
<
std
::
string
,
int
>
key
=
{
paramConfig
.
name
(),
idx
};
return
g_hookCache_
.
get
(
key
,
[
&
]
{
return
createImpl
(
paramConfig
.
update_hooks
(
idx
));
});
...
...
proto/ParameterConfig.proto
浏览文件 @
8b86624b
...
...
@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
}
message
ParameterUpdaterHookConfig
{
// hook type such as 'pruning'
required
string
type
=
1
;
optional
string
purning_mask_filename
=
2
;
// this represents the ratio of zero element to be set by the Parameter
optional
double
sparsity_ratio
=
2
[
default
=
0.6
];
}
message
ParameterConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
8b86624b
...
...
@@ -3139,11 +3139,11 @@ def Layer(name, type, **xargs):
@
config_func
def
ParameterHook
(
type
,
**
kwargs
):
if
type
==
'pruning'
:
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
assert
mask_filename
is
not
None
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
hook
.
purning_mask_filename
=
mask_filename
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
if
sparsity_ratio
is
not
None
:
hook
.
sparsity_ratio
=
sparsity_ratio
return
hook
else
:
return
None
...
...
@@ -3251,13 +3251,13 @@ def Parameter(name,
if
update_hooks
is
not
None
:
if
hasattr
(
update_hooks
,
'__call__'
):
update_hooks
=
update_hooks
(
para
.
name
)
update_hooks
=
update_hooks
()
if
isinstance
(
update_hooks
,
list
):
for
hook
in
update_hooks
:
para
.
update_hooks
.
extend
([
hook
])
else
:
para
.
update_hooks
.
extend
(
update_hooks
)
para
.
update_hooks
.
extend
(
[
update_hooks
]
)
g_parameter_map
[
name
]
=
para
if
initializer
is
not
None
:
...
...
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
8b86624b
...
...
@@ -14,7 +14,8 @@
from
paddle.trainer.config_parser
import
*
__all__
=
[
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
'HookAttr'
,
'ParamAttr'
,
'ExtraAttr'
,
'ParameterAttribute'
,
'ExtraLayerAttribute'
]
...
...
@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
return
False
class
HookAttribute
(
object
):
"""
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
def
__init__
(
self
,
type
,
sparsity_ratio
=
None
):
self
.
type
=
type
self
.
sparsity_ratio
=
sparsity_ratio
if
self
.
sparsity_ratio
is
not
None
:
assert
is_compatible_with
(
self
.
sparsity_ratio
,
float
),
'sparisity_ratio must be float type'
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparsity_ratio must be a float between [0, 1] '
def
__call__
(
self
):
return
ParameterHook
(
self
.
type
,
sparsity_ratio
=
self
.
sparsity_ratio
)
class
ParameterAttribute
(
object
):
"""
Parameter Attributes object. To fine-tuning network training process, user
...
...
@@ -114,6 +149,7 @@ class ParameterAttribute(object):
momentum
=
None
,
gradient_clipping_threshold
=
None
,
sparse_update
=
False
,
update_hooks
=
None
,
initializer
=
None
):
self
.
attr
=
{}
...
...
@@ -169,6 +205,9 @@ class ParameterAttribute(object):
if
initializer
is
not
None
:
self
.
attr
[
'initializer'
]
=
initializer
if
update_hooks
:
self
.
attr
[
'update_hooks'
]
=
update_hooks
def
set_default_parameter_name
(
self
,
name
):
"""
Set default parameter name. If parameter not set, then will use default
...
...
@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
return
attr
.
attr
HookAttr
=
HookAttribute
ParamAttr
=
ParameterAttribute
ExtraAttr
=
ExtraLayerAttribute
python/paddle/v2/attr.py
浏览文件 @
8b86624b
...
...
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__
=
[
"Param"
,
"Extra"
,
"Hook"
,
]
Param
=
paddle
.
trainer_config_helpers
.
attrs
.
ParameterAttribute
Extra
=
paddle
.
trainer_config_helpers
.
attrs
.
ExtraLayerAttribute
Hook
=
paddle
.
trainer_config_helpers
.
attrs
.
HookAttribute
for
each
in
paddle
.
trainer_config_helpers
.
attrs
.
__all__
:
globals
()[
each
]
=
getattr
(
paddle
.
trainer_config_helpers
.
attrs
,
each
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录