提交 a266292a 编写于 作者: Z zlx

modify format

上级 aaf11fa6
...@@ -71,10 +71,9 @@ public: ...@@ -71,10 +71,9 @@ public:
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort(param.begin(), param.begin() + nonZeroNum, param.end(), std::partial_sort(
sortPairAscend); param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
if (para->useGpu()) { if (para->useGpu()) {
...@@ -127,14 +126,16 @@ private: ...@@ -127,14 +126,16 @@ private:
std::hash<int> intHasher_; std::hash<int> intHasher_;
}; };
static WeakKVCache<std::pair<std::string, int>, IParameterUpdaterHook, static WeakKVCache<std::pair<std::string, int>,
StringIntPairHasher> g_hookCache_; IParameterUpdaterHook,
StringIntPairHasher>
g_hookCache_;
/** /**
* ParameterUpdaterHook actually factory method. * ParameterUpdaterHook actually factory method.
*/ */
static IParameterUpdaterHook * static IParameterUpdaterHook *createImpl(
createImpl(const ParameterUpdaterHookConfig &config) { const ParameterUpdaterHookConfig &config) {
auto &type = config.type(); auto &type = config.type();
if (type == "pruning") { if (type == "pruning") {
return new StaticPruningHook(config); return new StaticPruningHook(config);
...@@ -144,11 +145,11 @@ createImpl(const ParameterUpdaterHookConfig &config) { ...@@ -144,11 +145,11 @@ createImpl(const ParameterUpdaterHookConfig &config) {
return nullptr; return nullptr;
} }
std::shared_ptr<IParameterUpdaterHook> std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
IParameterUpdaterHook::create(const ParameterConfig &paramConfig, int idx) { const ParameterConfig &paramConfig, int idx) {
std::pair<std::string, int> key = {paramConfig.name(), idx}; std::pair<std::string, int> key = {paramConfig.name(), idx};
return g_hookCache_.get( return g_hookCache_.get(
key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
} }
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册