提交 997cef2e 编写于 作者: X xzl

tiny modify

上级 97a2fde9
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <algorithm>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
...@@ -60,6 +61,7 @@ public: ...@@ -60,6 +61,7 @@ public:
maskTemp_ = Vector::create(para->getSize(), false); maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem(); maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData(); real* dataPtr = maskTemp_->getData();
size_t sparsityNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false); VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec); vecCpu->copyFrom(*vec);
...@@ -67,10 +69,20 @@ public: ...@@ -67,10 +69,20 @@ public:
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
std::sort(param.begin(), param.end(), sortPairAscend);
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++) std::partial_sort(param.begin(),
dataPtr[param[i].second] = 1.0; param.begin() + sparsityNum,
param.end(),
sortPairAscend);
for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0;
// Currently just use a mask vector for hack.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
} }
void init(Parameter* para) { void init(Parameter* para) {
...@@ -81,15 +93,6 @@ public: ...@@ -81,15 +93,6 @@ public:
VLOG(3) << "Initialize Parameter " << para; VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId()); SetDevice device(para->getDeviceId());
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
auto& vec = para->getBuf(PARAMETER_VALUE); auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_); vec->dotMul(*maskVec_);
} }
...@@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl( ...@@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) { const ParameterUpdaterHookConfig& config) {
auto& type = config.type(); auto& type = config.type();
if (type == "pruning") { if (type == "pruning") {
if (config.has_sparsity_ratio()) return new StaticPruningHook(config);
return new StaticPruningHook(config);
else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook";
} }
LOG(FATAL) << "Unknown Hook type: " << type; LOG(FATAL) << "Unknown Hook type: " << type;
......
...@@ -3175,8 +3175,8 @@ def ParameterHook(type, **kwargs): ...@@ -3175,8 +3175,8 @@ def ParameterHook(type, **kwargs):
hook = ParameterUpdaterHookConfig() hook = ParameterUpdaterHookConfig()
hook.type = type hook.type = type
sparsity_ratio = kwargs.get('sparsity_ratio', None) sparsity_ratio = kwargs.get('sparsity_ratio', None)
assert sparsity_ratio is not None if sparsity_ratio is not None:
hook.sparsity_ratio = sparsity_ratio hook.sparsity_ratio = sparsity_ratio
return hook return hook
else: else:
return None return None
......
...@@ -73,9 +73,11 @@ class HookAttribute(object): ...@@ -73,9 +73,11 @@ class HookAttribute(object):
def __init__(self, type, sparsity_ratio=None): def __init__(self, type, sparsity_ratio=None):
self.type = type self.type = type
self.sparsity_ratio = sparsity_ratio self.sparsity_ratio = sparsity_ratio
assert is_compatible_with(self.sparsity_ratio, if self.sparsity_ratio is not None:
float), 'sparisity_ratio must be float type' assert is_compatible_with(
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
def __call__(self): def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册