提交 1eab8cce 编写于 作者: Z zlx

modify the annotations of HookAttribute, Variable declaration

上级 15bf6e05
...@@ -31,9 +31,9 @@ namespace paddle { ...@@ -31,9 +31,9 @@ namespace paddle {
/** /**
* The static pruning hook * The static pruning hook
* Static means user specific a sparsity_ratio before training start, and the * Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More deatils * network will prune the parameters based on the sparsity_ratio. More deatils
* can see https://arxiv.org/pdf/1506.02626.pdf. * can be found https://arxiv.org/pdf/1506.02626.pdf.
*/ */
class StaticPruningHook : public IParameterUpdaterHook { class StaticPruningHook : public IParameterUpdaterHook {
...@@ -57,29 +57,31 @@ public: ...@@ -57,29 +57,31 @@ public:
} }
void generateMask(Parameter* para) { void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false); VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp_->zeroMem(); maskTemp->zeroMem();
real* dataPtr = maskTemp_->getData(); real* maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
vecCpu->copyFrom(*vec); VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param; std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort( std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0; for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
if (para->useGpu()) { if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu()); maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_); maskVec_->copyFrom(*maskTemp);
} else { } else {
maskVec_ = maskTemp_; maskVec_ = maskTemp;
} }
} }
...@@ -91,15 +93,14 @@ public: ...@@ -91,15 +93,14 @@ public:
VLOG(3) << "Initialize Parameter " << para; VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId()); SetDevice device(para->getDeviceId());
auto& vec = para->getBuf(PARAMETER_VALUE); auto& paraVec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_); paraVec->dotMul(*maskVec_);
} }
private: private:
SameThreadChecker updateThreadChecker_; SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_; std::atomic<size_t> initCount_;
VectorPtr maskVec_; VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_; real sparsityRatio_;
}; };
......
...@@ -58,11 +58,17 @@ def is_compatible_with(x, Type): ...@@ -58,11 +58,17 @@ def is_compatible_with(x, Type):
class HookAttribute(object): class HookAttribute(object):
""" """
Hook Attribute object. The hook is an auxiliary operation that occurs Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during network propagation. during training process of a layer with parameters, such as img_conv layer, fc layer.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, currently supported types:
:param type: Hook type, eg: 'pruning' 'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string :type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning', :param sparsity_ratio: Must be specified if hook type is 'pruning',
...@@ -78,7 +84,7 @@ class HookAttribute(object): ...@@ -78,7 +84,7 @@ class HookAttribute(object):
assert is_compatible_with( assert is_compatible_with(
self.sparsity_ratio, self.sparsity_ratio,
float), 'sparisity_ratio must be float type' float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] '
def __call__(self): def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册