diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp
index 5e8c77ced03f5547f3d1145b1c7c4900a5223087..a581cc047dfe11105774e4b81d781100606d718f 100644
--- a/paddle/parameter/ParameterUpdaterHook.cpp
+++ b/paddle/parameter/ParameterUpdaterHook.cpp
@@ -20,6 +20,7 @@ limitations under the License. */
 #include
 #include
 #include
+#include <algorithm>
 
 #include "paddle/math/Vector.h"
 #include "paddle/parameter/Parameter.h"
@@ -60,6 +61,7 @@ public:
     maskTemp_ = Vector::create(para->getSize(), false);
     maskTemp_->zeroMem();
     real* dataPtr = maskTemp_->getData();
+    size_t sparsityNum = para->getSize() * (1 - sparsityRatio_);
 
     VectorPtr vecCpu = Vector::create(para->getSize(), false);
     vecCpu->copyFrom(*vec);
@@ -67,10 +69,20 @@ public:
     for (size_t i = 0; i < para->getSize(); i++)
       param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
 
-    std::sort(param.begin(), param.end(), sortPairAscend);
-    for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
-      dataPtr[param[i].second] = 1.0;
+    std::partial_sort(param.begin(),
+                      param.begin() + sparsityNum,
+                      param.end(),
+                      sortPairAscend);
+    for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0;
+
+    // Currently just use a mask vector as a hack.
+    if (para->useGpu()) {
+      maskVec_ = Vector::create(para->getSize(), para->useGpu());
+      maskVec_->copyFrom(*maskTemp_);
+    } else {
+      maskVec_ = maskTemp_;
+    }
   }
 
   void init(Parameter* para) {
@@ -81,15 +93,6 @@ public:
     VLOG(3) << "Initialize Parameter " << para;
     SetDevice device(para->getDeviceId());
 
-    // Currently just use a mask vector for hack.
-    // @TODO(yuyang18): Implemented the mask operation in vector.
-    if (para->useGpu()) {
-      maskVec_ = Vector::create(para->getSize(), para->useGpu());
-      maskVec_->copyFrom(*maskTemp_);
-    } else {
-      maskVec_ = maskTemp_;
-    }
-
     auto& vec = para->getBuf(PARAMETER_VALUE);
     vec->dotMul(*maskVec_);
   }
@@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl(
     const ParameterUpdaterHookConfig& config) {
   auto& type = config.type();
   if (type == "pruning") {
-    if (config.has_sparsity_ratio())
-      return new StaticPruningHook(config);
-    else
-      LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
-                 << " Hook";
+    return new StaticPruningHook(config);
   }
 
   LOG(FATAL) << "Unknown Hook type: " << type;
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index e0147b1b37c6574c65ce53e58eccaf6cede91a67..3a29c91807f023d9cae509ee729b2eeeb1038805 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3175,8 +3175,8 @@ def ParameterHook(type, **kwargs):
         hook = ParameterUpdaterHookConfig()
         hook.type = type
         sparsity_ratio = kwargs.get('sparsity_ratio', None)
-        assert sparsity_ratio is not None
-        hook.sparsity_ratio = sparsity_ratio
+        if sparsity_ratio is not None:
+            hook.sparsity_ratio = sparsity_ratio
         return hook
     else:
         return None
diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py
index 556701ca7a88c595c4c16fe5b24ccd5b72ae887d..27b54ffdea74e85d803e6cb5d497e5af7ee108a7 100644
--- a/python/paddle/trainer_config_helpers/attrs.py
+++ b/python/paddle/trainer_config_helpers/attrs.py
@@ -73,9 +73,11 @@ class HookAttribute(object):
     def __init__(self, type, sparsity_ratio=None):
         self.type = type
         self.sparsity_ratio = sparsity_ratio
-        assert is_compatible_with(self.sparsity_ratio,
-                                  float), 'sparisity_ratio must be float type'
-        assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
+        if self.sparsity_ratio is not None:
+            assert is_compatible_with(
+                self.sparsity_ratio,
+                float), 'sparsity_ratio must be float type'
+            assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparsity_ratio must be a float in [0, 1]'
 
     def __call__(self):
         return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
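For orientation: the core of the C++ change is that `generateMask` no longer fully sorts the `(|w|, index)` pairs; it only partitions them around the `size * (1 - sparsity_ratio)` boundary before writing the 0/1 mask that `init` later applies with `dotMul`. A rough NumPy sketch of that logic (hypothetical names, not Paddle code; it assumes `sparsity_ratio` is the fraction of weights to prune and that the largest-magnitude weights survive, the usual magnitude-pruning convention; the exact direction in the patch depends on `sortPairAscend`, which the diff doesn't show):

```python
import numpy as np

def generate_mask(weights, sparsity_ratio):
    """Sketch of the patched StaticPruningHook::generateMask logic."""
    n_keep = int(weights.size * (1 - sparsity_ratio))  # sparsityNum in the patch
    mask = np.zeros_like(weights)                      # maskTemp_->zeroMem()
    if n_keep == 0:
        return mask
    # np.argpartition is the NumPy analogue of std::partial_sort: it places
    # the boundary element in its sorted position without ordering the rest.
    cut = weights.size - n_keep
    keep_idx = np.argpartition(np.abs(weights), cut)[cut:]
    mask[keep_idx] = 1.0
    return mask

w = np.array([0.05, -1.2, 0.3, -0.01, 0.8])
print(generate_mask(w, sparsity_ratio=0.4))  # -> [0. 1. 1. 0. 1.]
```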
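The `std::sort` to `std::partial_sort` swap itself is a pure cost optimization: the mask only needs the boundary between kept and pruned weights, so ordering everything past that boundary is wasted work (roughly O(n log n) versus O(n log k)). NumPy's `np.partition` makes the effect easy to measure; array size and ratio below are illustrative:

```python
import timeit

import numpy as np

w = np.abs(np.random.randn(10**6))
k = int(w.size * (1 - 0.8))  # boundary index for sparsity_ratio = 0.8

# np.sort orders all n elements; np.partition only places the k-th order
# statistic correctly, leaving both sides unordered, which is all the
# pruning mask needs.
t_sort = timeit.timeit(lambda: np.sort(w), number=20)
t_part = timeit.timeit(lambda: np.partition(w, k), number=20)
print("np.sort:      %.3fs" % t_sort)
print("np.partition: %.3fs" % t_part)
```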
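Taken together, the three files make `sparsity_ratio` optional end to end: `HookAttribute` no longer asserts on `None`, `ParameterHook` only sets the proto field when a value is supplied, and `createImpl` no longer rejects configs without `has_sparsity_ratio()`, so an unset field presumably falls through to the protobuf default. A sketch of the resulting user-facing behavior (the 0.75 value is illustrative, and running this needs a v1-era PaddlePaddle install):

```python
from paddle.trainer_config_helpers.attrs import HookAttribute

# Explicit ratio: validated in __init__, then forwarded to the proto config.
prune = HookAttribute('pruning', sparsity_ratio=0.75)
hook_cfg = prune()  # ParameterHook('pruning', sparsity_ratio=0.75)

# No ratio: used to trip the assert; now accepted, and ParameterHook
# simply leaves hook.sparsity_ratio unset.
default_cfg = HookAttribute('pruning')()

# Out-of-range values still fail fast at configuration time.
try:
    HookAttribute('pruning', sparsity_ratio=1.5)
except AssertionError as err:
    print(err)  # sparsity_ratio must be a float in [0, 1]
```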