提交 5413af8d 编写于 作者: X xzl

imporve pruning module

上级 da83d286
...@@ -25,6 +25,9 @@ limitations under the License. */ ...@@ -25,6 +25,9 @@ limitations under the License. */
#include "paddle/utils/Flags.h" #include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
using std::vector;
using std::pair;
namespace paddle { namespace paddle {
/** /**
...@@ -131,6 +134,73 @@ private: ...@@ -131,6 +134,73 @@ private:
std::vector<bool> mask_; std::vector<bool> mask_;
}; };
class DynamicPruningHook : public IParameterUpdaterHook {
public:
explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig)
: initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio();
}
static bool sortPairAscend(const pair<real, size_t>& pair1,
const pair<real, size_t>& pair2) {
return pair1.first > pair2.first;
}
void update(Parameter* para) {
updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
vector<pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
std::sort(param.begin(), param.end(), sortPairAscend);
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
dataPtr[param[i].second] = 1.0;
}
void init(Parameter* para) {
generateMask(para);
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke "
"in same ParamterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_;
};
IParameterUpdaterHook::IParameterUpdaterHook() {} IParameterUpdaterHook::IParameterUpdaterHook() {}
IParameterUpdaterHook::~IParameterUpdaterHook() {} IParameterUpdaterHook::~IParameterUpdaterHook() {}
...@@ -156,8 +226,7 @@ private: ...@@ -156,8 +226,7 @@ private:
static WeakKVCache<std::pair<std::string, int>, static WeakKVCache<std::pair<std::string, int>,
IParameterUpdaterHook, IParameterUpdaterHook,
StringIntPairHasher> StringIntPairHasher> g_hookCache_;
g_hookCache_;
/** /**
* ParameterUpdaterHook actually factory method. * ParameterUpdaterHook actually factory method.
...@@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>, ...@@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>,
static IParameterUpdaterHook* createImpl( static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) { const ParameterUpdaterHookConfig& config) {
auto& type = config.type(); auto& type = config.type();
if (type == "pruning") { if (type == "pruning_static") {
if (config.has_purning_mask_filename()) { if (config.has_purning_mask_filename())
return new StaticPruningHook(config.purning_mask_filename()); return new StaticPruningHook(config.purning_mask_filename());
} else
LOG(FATAL) << "There must be mask_filename parameter for " << type
<< " Hook";
} else if (type == "pruning") {
if (config.has_sparsity_ratio())
return new DynamicPruningHook(config);
else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook";
} }
LOG(FATAL) << "Unknown Hook type: " << type;
return nullptr; return nullptr;
} }
......
...@@ -26,7 +26,9 @@ enum ParameterInitStrategy { ...@@ -26,7 +26,9 @@ enum ParameterInitStrategy {
message ParameterUpdaterHookConfig { message ParameterUpdaterHookConfig {
required string type = 1; required string type = 1;
//hook type such as 'pruning', 'pruning_static'
optional string purning_mask_filename = 2; optional string purning_mask_filename = 2;
optional double sparsity_ratio = 3;
} }
message ParameterConfig { message ParameterConfig {
......
...@@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs): ...@@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs):
@config_func @config_func
def ParameterHook(type, **kwargs): def ParameterHook(type, **kwargs):
if type == 'pruning': if type == 'pruning_static':
hook = ParameterUpdaterHookConfig()
hook.type = type
mask_filename = kwargs.get('mask_filename', None) mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None assert mask_filename is not None
hook.pruning_mask_filename = mask_filename
return hook
elif type == 'pruning':
hook = ParameterUpdaterHookConfig() hook = ParameterUpdaterHookConfig()
hook.type = type hook.type = type
hook.purning_mask_filename = mask_filename sparsity_ratio = kwargs.get('sparsity_ratio', None)
assert sparsity_ratio is not None
hook.sparsity_ratio = sparsity_ratio
return hook return hook
else: else:
return None return None
...@@ -3283,13 +3290,13 @@ def Parameter(name, ...@@ -3283,13 +3290,13 @@ def Parameter(name,
if update_hooks is not None: if update_hooks is not None:
if hasattr(update_hooks, '__call__'): if hasattr(update_hooks, '__call__'):
update_hooks = update_hooks(para.name) update_hooks = update_hooks()
if isinstance(update_hooks, list): if isinstance(update_hooks, list):
for hook in update_hooks: for hook in update_hooks:
para.update_hooks.extend([hook]) para.update_hooks.extend([hook])
else: else:
para.update_hooks.extend(update_hooks) para.update_hooks.extend([update_hooks])
g_parameter_map[name] = para g_parameter_map[name] = para
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
from paddle.trainer.config_parser import * from paddle.trainer.config_parser import *
__all__ = [ __all__ = [
'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute',
'ExtraLayerAttribute'
] ]
...@@ -55,6 +56,42 @@ def is_compatible_with(x, Type): ...@@ -55,6 +56,42 @@ def is_compatible_with(x, Type):
return False return False
class HookAttribute(object):
"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation. Such as pruning operation, It will cut off
redundant parameters in the network before training. More detail can see
here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning', 'pruning_static'
:type type: string
:param mask_file: Must be specified if hook type is 'pruning_static',
the network reads the mask from the file to determine which parameters should be cut off
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1
"""
def __init__(self, type, mask_filename=None, sparsity_ratio=None):
self.type = type
self.mask_filename = mask_filename
self.sparsity_ratio = sparsity_ratio
assert is_compatible_with(self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
def __call__(self):
return ParameterHook(
self.type,
mask_filename=self.mask_filename,
sparsity_ratio=self.sparsity_ratio)
class ParameterAttribute(object): class ParameterAttribute(object):
""" """
Parameter Attributes object. To fine-tuning network training process, user Parameter Attributes object. To fine-tuning network training process, user
...@@ -109,7 +146,8 @@ class ParameterAttribute(object): ...@@ -109,7 +146,8 @@ class ParameterAttribute(object):
learning_rate=None, learning_rate=None,
momentum=None, momentum=None,
gradient_clipping_threshold=None, gradient_clipping_threshold=None,
sparse_update=False): sparse_update=False,
update_hooks=None):
self.attr = {} self.attr = {}
if is_static: if is_static:
...@@ -162,6 +200,9 @@ class ParameterAttribute(object): ...@@ -162,6 +200,9 @@ class ParameterAttribute(object):
self.attr['gradient_clipping_threshold'] = \ self.attr['gradient_clipping_threshold'] = \
gradient_clipping_threshold gradient_clipping_threshold
if update_hooks:
self.attr['update_hooks'] = update_hooks
def set_default_parameter_name(self, name): def set_default_parameter_name(self, name):
""" """
Set default parameter name. If parameter not set, then will use default Set default parameter name. If parameter not set, then will use default
...@@ -237,5 +278,6 @@ class ExtraLayerAttribute(object): ...@@ -237,5 +278,6 @@ class ExtraLayerAttribute(object):
return attr.attr return attr.attr
HookAttr = HookAttribute
ParamAttr = ParameterAttribute ParamAttr = ParameterAttribute
ExtraAttr = ExtraLayerAttribute ExtraAttr = ExtraLayerAttribute
...@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs ...@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__ = [ __all__ = [
"Param", "Param",
"Extra", "Extra",
"Hook",
] ]
Param = paddle.trainer_config_helpers.attrs.ParameterAttribute Param = paddle.trainer_config_helpers.attrs.ParameterAttribute
Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute
Hook = paddle.trainer_config_helpers.attrs.HookAttribute
for each in paddle.trainer_config_helpers.attrs.__all__: for each in paddle.trainer_config_helpers.attrs.__all__:
globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each) globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册