From 5413af8d11eb8662598e5f89d5f4fa284e1c6fdf Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 2 Jun 2017 14:28:20 +0800 Subject: [PATCH] imporve pruning module --- paddle/parameter/ParameterUpdaterHook.cpp | 90 +++++++++++++++++-- proto/ParameterConfig.proto | 2 + python/paddle/trainer/config_parser.py | 15 +++- python/paddle/trainer_config_helpers/attrs.py | 46 +++++++++- python/paddle/v2/attr.py | 2 + 5 files changed, 144 insertions(+), 11 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index f826e8448c6..76cc3ecad14 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -25,6 +25,9 @@ limitations under the License. */ #include "paddle/utils/Flags.h" #include "paddle/utils/Util.h" +using std::vector; +using std::pair; + namespace paddle { /** @@ -131,6 +134,73 @@ private: std::vector mask_; }; +class DynamicPruningHook : public IParameterUpdaterHook { +public: + explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig) + : initCount_(0) { + sparsityRatio_ = hookConfig.sparsity_ratio(); + } + + static bool sortPairAscend(const pair& pair1, + const pair& pair2) { + return pair1.first > pair2.first; + } + + void update(Parameter* para) { + updateThreadChecker_.check(); + auto& vec = para->getBuf(PARAMETER_GRADIENT); + if (vec) { + vec->dotMul(*maskVec_); + } + } + + void generateMask(Parameter* para) { + VectorPtr vec = para->getBuf(PARAMETER_VALUE); + maskTemp_ = Vector::create(para->getSize(), false); + maskTemp_->zeroMem(); + real* dataPtr = maskTemp_->getData(); + + VectorPtr vecCpu = Vector::create(para->getSize(), false); + vecCpu->copyFrom(*vec); + vector> param; + + for (size_t i = 0; i < para->getSize(); i++) + param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); + std::sort(param.begin(), param.end(), sortPairAscend); + + for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++) + dataPtr[param[i].second] = 1.0; + } + + void init(Parameter* para) { + generateMask(para); + size_t initCount = this->initCount_.fetch_add(1); + CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke " + "in same ParamterUpdater"; + VLOG(3) << "Initialize Parameter " << para; + SetDevice device(para->getDeviceId()); + + // Currently just use a mask vector for hack. + // @TODO(yuyang18): Implemented the mask operation in vector. + if (para->useGpu()) { + maskVec_ = Vector::create(para->getSize(), para->useGpu()); + maskVec_->copyFrom(*maskTemp_); + } else { + maskVec_ = maskTemp_; + } + + auto& vec = para->getBuf(PARAMETER_VALUE); + vec->dotMul(*maskVec_); + } + +private: + SameThreadChecker updateThreadChecker_; + std::atomic initCount_; + VectorPtr maskVec_; + VectorPtr maskTemp_; + real sparsityRatio_; +}; + IParameterUpdaterHook::IParameterUpdaterHook() {} IParameterUpdaterHook::~IParameterUpdaterHook() {} @@ -156,8 +226,7 @@ private: static WeakKVCache, IParameterUpdaterHook, - StringIntPairHasher> - g_hookCache_; + StringIntPairHasher> g_hookCache_; /** * ParameterUpdaterHook actually factory method. @@ -165,11 +234,22 @@ static WeakKVCache, static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); - if (type == "pruning") { - if (config.has_purning_mask_filename()) { + if (type == "pruning_static") { + if (config.has_purning_mask_filename()) return new StaticPruningHook(config.purning_mask_filename()); - } + else + LOG(FATAL) << "There must be mask_filename parameter for " << type + << " Hook"; + + } else if (type == "pruning") { + if (config.has_sparsity_ratio()) + return new DynamicPruningHook(config); + else + LOG(FATAL) << "There must be sparsity_ratio parameter for " << type + << " Hook"; } + + LOG(FATAL) << "Unknown Hook type: " << type; return nullptr; } diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index cbcd0af598d..61f4b037cf0 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -26,7 +26,9 @@ enum ParameterInitStrategy { message ParameterUpdaterHookConfig { required string type = 1; + //hook type such as 'pruning', 'pruning_static' optional string purning_mask_filename = 2; + optional double sparsity_ratio = 3; } message ParameterConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 9fe8794691e..d80590210f2 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs): @config_func def ParameterHook(type, **kwargs): - if type == 'pruning': + if type == 'pruning_static': + hook = ParameterUpdaterHookConfig() + hook.type = type mask_filename = kwargs.get('mask_filename', None) assert mask_filename is not None + hook.pruning_mask_filename = mask_filename + return hook + elif type == 'pruning': hook = ParameterUpdaterHookConfig() hook.type = type - hook.purning_mask_filename = mask_filename + sparsity_ratio = kwargs.get('sparsity_ratio', None) + assert sparsity_ratio is not None + hook.sparsity_ratio = sparsity_ratio return hook else: return None @@ -3283,13 +3290,13 @@ def Parameter(name, if update_hooks is not None: if hasattr(update_hooks, '__call__'): - update_hooks = update_hooks(para.name) + update_hooks = update_hooks() if isinstance(update_hooks, list): for hook in update_hooks: para.update_hooks.extend([hook]) else: - para.update_hooks.extend(update_hooks) + para.update_hooks.extend([update_hooks]) g_parameter_map[name] = para diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index d1167a234ca..011147a3685 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -14,7 +14,8 @@ from paddle.trainer.config_parser import * __all__ = [ - 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' + 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', + 'ExtraLayerAttribute' ] @@ -55,6 +56,42 @@ def is_compatible_with(x, Type): return False +class HookAttribute(object): + """ + Hook Attribute object. The hook is an auxiliary operation that occurs + during network propagation. Such as pruning operation, It will cut off + redundant parameters in the network before training. More detail can see + here paddle/parameter/ParameterUpdaterHook.cpp + NOTE: IT IS A HIGH LEVEL USER INTERFACE. + + :param type: Hook type, eg: 'pruning', 'pruning_static' + :type type: string + + :param mask_file: Must be specified if hook type is 'pruning_static', + the network reads the mask from the file to determine which parameters should be cut off + :type mask_file: string + + :param sparsity_ratio: Must be specified if hook type is 'pruning', + the network will hold the sparsity_ratio maximum parameters, and cut off the rest. + :type sparsity_ratio: float number between 0 and 1 + + """ + + def __init__(self, type, mask_filename=None, sparsity_ratio=None): + self.type = type + self.mask_filename = mask_filename + self.sparsity_ratio = sparsity_ratio + assert is_compatible_with(self.sparsity_ratio, + float), 'sparisity_ratio must be float type' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' + + def __call__(self): + return ParameterHook( + self.type, + mask_filename=self.mask_filename, + sparsity_ratio=self.sparsity_ratio) + + class ParameterAttribute(object): """ Parameter Attributes object. To fine-tuning network training process, user @@ -109,7 +146,8 @@ class ParameterAttribute(object): learning_rate=None, momentum=None, gradient_clipping_threshold=None, - sparse_update=False): + sparse_update=False, + update_hooks=None): self.attr = {} if is_static: @@ -162,6 +200,9 @@ class ParameterAttribute(object): self.attr['gradient_clipping_threshold'] = \ gradient_clipping_threshold + if update_hooks: + self.attr['update_hooks'] = update_hooks + def set_default_parameter_name(self, name): """ Set default parameter name. If parameter not set, then will use default @@ -237,5 +278,6 @@ class ExtraLayerAttribute(object): return attr.attr +HookAttr = HookAttribute ParamAttr = ParameterAttribute ExtraAttr = ExtraLayerAttribute diff --git a/python/paddle/v2/attr.py b/python/paddle/v2/attr.py index 32f78614e7f..5d23894d735 100644 --- a/python/paddle/v2/attr.py +++ b/python/paddle/v2/attr.py @@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs __all__ = [ "Param", "Extra", + "Hook", ] Param = paddle.trainer_config_helpers.attrs.ParameterAttribute Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute +Hook = paddle.trainer_config_helpers.attrs.HookAttribute for each in paddle.trainer_config_helpers.attrs.__all__: globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each) -- GitLab