From 997cef2e63ef4d7c99c58710289f7581d2af08c6 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 14 Jun 2017 17:26:08 +0800 Subject: [PATCH] tiny modify --- paddle/parameter/ParameterUpdaterHook.cpp | 33 +++++++++---------- python/paddle/trainer/config_parser.py | 4 +-- python/paddle/trainer_config_helpers/attrs.py | 8 +++-- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 5e8c77ced03..a581cc047df 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -20,6 +20,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" @@ -60,6 +61,7 @@ public: maskTemp_ = Vector::create(para->getSize(), false); maskTemp_->zeroMem(); real* dataPtr = maskTemp_->getData(); + size_t sparsityNum = para->getSize() * (1 - sparsityRatio_); VectorPtr vecCpu = Vector::create(para->getSize(), false); vecCpu->copyFrom(*vec); @@ -67,10 +69,20 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); - std::sort(param.begin(), param.end(), sortPairAscend); - for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++) - dataPtr[param[i].second] = 1.0; + std::partial_sort(param.begin(), + param.begin() + sparsityNum, + param.end(), + sortPairAscend); + for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0; + + // Currently just use a mask vector for hack. + if (para->useGpu()) { + maskVec_ = Vector::create(para->getSize(), para->useGpu()); + maskVec_->copyFrom(*maskTemp_); + } else { + maskVec_ = maskTemp_; + } } void init(Parameter* para) { @@ -81,15 +93,6 @@ public: VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); - // Currently just use a mask vector for hack. - // @TODO(yuyang18): Implemented the mask operation in vector. 
- if (para->useGpu()) { - maskVec_ = Vector::create(para->getSize(), para->useGpu()); - maskVec_->copyFrom(*maskTemp_); - } else { - maskVec_ = maskTemp_; - } - auto& vec = para->getBuf(PARAMETER_VALUE); vec->dotMul(*maskVec_); } @@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); if (type == "pruning") { - if (config.has_sparsity_ratio()) - return new StaticPruningHook(config); - else - LOG(FATAL) << "There must be sparsity_ratio parameter for " << type - << " Hook"; + return new StaticPruningHook(config); } LOG(FATAL) << "Unknown Hook type: " << type; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index e0147b1b37c..3a29c91807f 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3175,8 +3175,8 @@ def ParameterHook(type, **kwargs): hook = ParameterUpdaterHookConfig() hook.type = type sparsity_ratio = kwargs.get('sparsity_ratio', None) - assert sparsity_ratio is not None - hook.sparsity_ratio = sparsity_ratio + if sparsity_ratio is not None: + hook.sparsity_ratio = sparsity_ratio return hook else: return None diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 556701ca7a8..27b54ffdea7 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -73,9 +73,11 @@ class HookAttribute(object): def __init__(self, type, sparsity_ratio=None): self.type = type self.sparsity_ratio = sparsity_ratio - assert is_compatible_with(self.sparsity_ratio, - float), 'sparisity_ratio must be float type' - assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' + if self.sparsity_ratio is not None: + assert is_compatible_with( + self.sparsity_ratio, + float), 'sparisity_ratio must be float type' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio 
>= 0, 'sparsity must be a float between [0, 1] ' def __call__(self): return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) -- GitLab