提交 aaf11fa6 编写于 作者: Z zlx

modify the format

上级 1eab8cce
...@@ -38,29 +38,28 @@ namespace paddle { ...@@ -38,29 +38,28 @@ namespace paddle {
class StaticPruningHook : public IParameterUpdaterHook { class StaticPruningHook : public IParameterUpdaterHook {
public: public:
explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
: initCount_(0) { : initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio(); sparsityRatio_ = hookConfig.sparsity_ratio();
} }
static bool sortPairAscend(const std::pair<real, size_t>& pair1, static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t>& pair2) { const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first; return pair1.first > pair2.first;
} }
void update(Parameter* para) { void update(Parameter *para) {
updateThreadChecker_.check(); updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT); auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) { if (vec) {
vec->dotMul(*maskVec_); vec->dotMul(*maskVec_);
} }
} }
void generateMask(Parameter* para) { void generateMask(Parameter *para) {
VectorPtr maskTemp = Vector::create(para->getSize(), false); VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp->zeroMem(); maskTemp->zeroMem();
real* maskTempData = maskTemp->getData(); real *maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
...@@ -72,9 +71,10 @@ public: ...@@ -72,9 +71,10 @@ public:
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort( std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(),
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; for (size_t i = 0; i < nonZeroNum; i++)
maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
if (para->useGpu()) { if (para->useGpu()) {
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
} }
} }
void init(Parameter* para) { void init(Parameter *para) {
generateMask(para); generateMask(para);
size_t initCount = this->initCount_.fetch_add(1); size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
...@@ -93,7 +93,7 @@ public: ...@@ -93,7 +93,7 @@ public:
VLOG(3) << "Initialize Parameter " << para; VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId()); SetDevice device(para->getDeviceId());
auto& paraVec = para->getBuf(PARAMETER_VALUE); auto &paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*maskVec_); paraVec->dotMul(*maskVec_);
} }
...@@ -118,7 +118,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {} ...@@ -118,7 +118,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/ */
class StringIntPairHasher { class StringIntPairHasher {
public: public:
size_t operator()(const std::pair<std::string, int>& k) const { size_t operator()(const std::pair<std::string, int> &k) const {
return intHasher_(strHasher_(k.first) + k.second); return intHasher_(strHasher_(k.first) + k.second);
} }
...@@ -127,17 +127,15 @@ private: ...@@ -127,17 +127,15 @@ private:
std::hash<int> intHasher_; std::hash<int> intHasher_;
}; };
static WeakKVCache<std::pair<std::string, int>, static WeakKVCache<std::pair<std::string, int>, IParameterUpdaterHook,
IParameterUpdaterHook, StringIntPairHasher> g_hookCache_;
StringIntPairHasher>
g_hookCache_;
/** /**
* ParameterUpdaterHook actually factory method. * ParameterUpdaterHook actually factory method.
*/ */
static IParameterUpdaterHook* createImpl( static IParameterUpdaterHook *
const ParameterUpdaterHookConfig& config) { createImpl(const ParameterUpdaterHookConfig &config) {
auto& type = config.type(); auto &type = config.type();
if (type == "pruning") { if (type == "pruning") {
return new StaticPruningHook(config); return new StaticPruningHook(config);
} }
...@@ -146,11 +144,11 @@ static IParameterUpdaterHook* createImpl( ...@@ -146,11 +144,11 @@ static IParameterUpdaterHook* createImpl(
return nullptr; return nullptr;
} }
std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create( std::shared_ptr<IParameterUpdaterHook>
const ParameterConfig& paramConfig, int idx) { IParameterUpdaterHook::create(const ParameterConfig &paramConfig, int idx) {
std::pair<std::string, int> key = {paramConfig.name(), idx}; std::pair<std::string, int> key = {paramConfig.name(), idx};
return g_hookCache_.get( return g_hookCache_.get(
key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
} }
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册