diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 66e554a70d9efaab43ba46fa30b9c128514b75d8..ba2cb37fa2cecf9f04a1e52819d4e09ab6aacb19 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -38,29 +38,28 @@ namespace paddle { class StaticPruningHook : public IParameterUpdaterHook { public: - explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) + explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) : initCount_(0) { sparsityRatio_ = hookConfig.sparsity_ratio(); } - static bool sortPairAscend(const std::pair& pair1, - const std::pair& pair2) { + static bool sortPairAscend(const std::pair &pair1, + const std::pair &pair2) { return pair1.first > pair2.first; } - void update(Parameter* para) { + void update(Parameter *para) { updateThreadChecker_.check(); - auto& vec = para->getBuf(PARAMETER_GRADIENT); + auto &vec = para->getBuf(PARAMETER_GRADIENT); if (vec) { vec->dotMul(*maskVec_); } } - void generateMask(Parameter* para) { - + void generateMask(Parameter *para) { VectorPtr maskTemp = Vector::create(para->getSize(), false); maskTemp->zeroMem(); - real* maskTempData = maskTemp->getData(); + real *maskTempData = maskTemp->getData(); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); @@ -72,9 +71,10 @@ public: for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); - std::partial_sort( - param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); - for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; + std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(), + sortPairAscend); + for (size_t i = 0; i < nonZeroNum; i++) + maskTempData[param[i].second] = 1.0; // Currently just use a mask vector for hack. if (para->useGpu()) { @@ -85,7 +85,7 @@ public: } } - void init(Parameter* para) { + void init(Parameter *para) { generateMask(para); size_t initCount = this->initCount_.fetch_add(1); CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " @@ -93,7 +93,7 @@ public: VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); - auto& paraVec = para->getBuf(PARAMETER_VALUE); + auto ¶Vec = para->getBuf(PARAMETER_VALUE); paraVec->dotMul(*maskVec_); } @@ -118,7 +118,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {} */ class StringIntPairHasher { public: - size_t operator()(const std::pair& k) const { + size_t operator()(const std::pair &k) const { return intHasher_(strHasher_(k.first) + k.second); } @@ -127,17 +127,15 @@ private: std::hash intHasher_; }; -static WeakKVCache, - IParameterUpdaterHook, - StringIntPairHasher> - g_hookCache_; +static WeakKVCache, IParameterUpdaterHook, + StringIntPairHasher> g_hookCache_; /** * ParameterUpdaterHook actually factory method. */ -static IParameterUpdaterHook* createImpl( - const ParameterUpdaterHookConfig& config) { - auto& type = config.type(); +static IParameterUpdaterHook * +createImpl(const ParameterUpdaterHookConfig &config) { + auto &type = config.type(); if (type == "pruning") { return new StaticPruningHook(config); } @@ -146,11 +144,11 @@ static IParameterUpdaterHook* createImpl( return nullptr; } -std::shared_ptr IParameterUpdaterHook::create( - const ParameterConfig& paramConfig, int idx) { +std::shared_ptr +IParameterUpdaterHook::create(const ParameterConfig ¶mConfig, int idx) { std::pair key = {paramConfig.name(), idx}; return g_hookCache_.get( key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); } -} // namespace paddle +} // namespace paddle