提交 997cef2e 编写于 作者: X xzl

tiny modify

上级 97a2fde9
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <thread>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
......@@ -60,6 +61,7 @@ public:
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
size_t sparsityNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
......@@ -67,10 +69,20 @@ public:
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
std::sort(param.begin(), param.end(), sortPairAscend);
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
dataPtr[param[i].second] = 1.0;
std::partial_sort(param.begin(),
param.begin() + sparsityNum,
param.end(),
sortPairAscend);
for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0;
// Currently just use a mask vector for hack.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
}
void init(Parameter* para) {
......@@ -81,15 +93,6 @@ public:
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
......@@ -136,11 +139,7 @@ static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) {
auto& type = config.type();
if (type == "pruning") {
if (config.has_sparsity_ratio())
return new StaticPruningHook(config);
else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook";
}
LOG(FATAL) << "Unknown Hook type: " << type;
......
......@@ -3175,7 +3175,7 @@ def ParameterHook(type, **kwargs):
hook = ParameterUpdaterHookConfig()
hook.type = type
sparsity_ratio = kwargs.get('sparsity_ratio', None)
assert sparsity_ratio is not None
if sparsity_ratio is not None:
hook.sparsity_ratio = sparsity_ratio
return hook
else:
......
......@@ -73,7 +73,9 @@ class HookAttribute(object):
def __init__(self, type, sparsity_ratio=None):
self.type = type
self.sparsity_ratio = sparsity_ratio
assert is_compatible_with(self.sparsity_ratio,
if self.sparsity_ratio is not None:
assert is_compatible_with(
self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册