Commit 5c892db6 authored by tensor-tang

remove unused code

refine comments and bias
fix typo and todo
Parent 4f41eaf7
@@ -210,11 +210,11 @@ void MKLDNNConvLayer::resetFwdBuffers(
   resetWithMatrix(wgt, weight_->getW(), pd->weights_primitive_desc());
-  bias = nullptr;
-  if (biases_ == nullptr || biases_->getW() == nullptr) {
-    return;
+  if (biases_ && biases_->getW()) {
+    resetWithMatrix(bias, biases_->getW(), pd->bias_primitive_desc());
+  } else {
+    bias = nullptr;
   }
-  resetWithMatrix(bias, biases_->getW(), pd->bias_primitive_desc());
 }

 void MKLDNNConvLayer::resetFwdPipeline(
......
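The same bias refinement recurs in the two MKLDNNFcLayer hunks below: instead of presetting bias = nullptr and returning early, each function now assigns bias exactly once on every path through a single if/else. A minimal self-contained C++ sketch of the pattern (the types here are stand-ins, not Paddle's):

    #include <memory>

    struct Buffer {};                          // stand-in for MKLDNNMatrix
    using BufferPtr = std::shared_ptr<Buffer>;

    // Before: "bias = nullptr; if (no weights) return; bind(bias);" split the
    // assignment across two exits. After: one exit, and bias is explicitly
    // set on both branches, so no path leaves it implicit.
    void resetBias(BufferPtr& bias, const BufferPtr& biasParam) {
      if (biasParam) {
        bias = std::make_shared<Buffer>();     // bind the real bias buffer here
      } else {
        bias = nullptr;                        // layer has no bias parameter
      }
    }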
@@ -134,10 +134,6 @@ void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
   CHECK(in);
   in->downSpatial();
-  // if (extInVal_) {
-  //   extInVal_->downSpatial();
-  // }
   auto outPD =
       MKLDNNMatrix::createPrimitiveDesc({bs_, oc_}, format::nc, engine_);
   resetOutValue(out, outPD);
@@ -153,11 +149,12 @@ void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
   resetWithMatrix(wgt, weight_->getW(), wgtPD);
   wgt->downSpatial();
-  if (biases_ == nullptr || biases_->getW() == nullptr) {
-    return;
+  if (biases_ && biases_->getW()) {
+    auto biasPD = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
+    resetWithMatrix(bias, biases_->getW(), biasPD);
+  } else {
+    bias = nullptr;
   }
-  auto biasPD = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
-  resetWithMatrix(bias, biases_->getW(), biasPD);
 }

 void MKLDNNFcLayer::resetFwdPD(std::shared_ptr<fc_fwd::primitive_desc>& pd,
@@ -207,11 +204,11 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
   CHECK(wgtVal_);
   resetWithMatrix(wgt, weight_->getWGrad(), wgtVal_->getPrimitiveDesc());
-  bias = nullptr;
-  if (biasVal_ == nullptr) {
-    return;
+  if (biasVal_) {
+    resetWithMatrix(bias, biases_->getWGrad(), biasVal_->getPrimitiveDesc());
+  } else {
+    bias = nullptr;
   }
-  resetWithMatrix(bias, biases_->getWGrad(), biasVal_->getPrimitiveDesc());
 }

 void MKLDNNFcLayer::resetBwdWgtPD(
......
@@ -60,7 +60,7 @@ void MKLDNNLayer::forward(PassType passType) {
     resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_);
     // MKLDNNLayer output value should be MKLDNNMatrix
     // so external output value is necessary.
-    // then external input value is not necessary,
+    // Then external input value is not necessary,
     // since input may be mkldnn internal buffer.
     CHECK(extOutVal_) << "external output value is necessary";
     output_.value = std::dynamic_pointer_cast<Matrix>(extOutVal_);
@@ -235,8 +235,8 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
   in = MKLDNNMatrix::create(intPD, inMat);
   Argument& arg = input->getOutput(this->getName());
   arg.grad = std::dynamic_pointer_cast<Matrix>(in);
-  CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
-      << "should have internal input value and primitive desc must equal";
+  CHECK(inVal_);
+  CHECK(inVal_->getPrimitiveDesc() == intPD) << "the primitive desc must equal";
   if (inputIsOnlyMKLDNN()) {
     return;
   }
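Splitting the combined CHECK is a diagnosability fix, not just style: with glog-style assertions (as Paddle uses), CHECK(a && b) reports one generic failure for two distinct causes, while separate CHECKs pinpoint which condition broke. Side by side:

    // Before: one message covers two failure modes (null pointer vs. desc mismatch).
    CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
        << "should have internal input value and primitive desc must equal";

    // After: the failing condition identifies itself in the crash log, and each
    // check can carry a message specific to its own failure mode.
    CHECK(inVal_);                              // fires only when inVal_ is null
    CHECK(inVal_->getPrimitiveDesc() == intPD)  // fires only when descs differ
        << "the primitive desc must equal";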
@@ -246,6 +246,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
     return;
   }
   // need create reorder
+  // TODO(TJ): add macro definition to simplify it
   CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat()))
       << "should have external input value and the format must be nchw(nc)";
   extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat);
......
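The new TODO(TJ) does not spell out the macro it has in mind; purely as a hypothetical illustration (this macro does not exist in the repo), it could fold the recurring "pointer exists and is in paddle format" assertion into one definition:

    // Hypothetical helper, illustrative only: collapses the repeated
    // "must exist and be in paddle (nchw/nc) format" check at reorder sites.
    #define CHECK_PADDLE_BUFFER(buf)                                  \
      CHECK((buf) != nullptr && isPaddleFormat((buf)->getFormat()))   \
          << #buf << " should exist and the format must be nchw(nc)"

    // The check above would then shrink to:
    //   CHECK_PADDLE_BUFFER(extInVal_);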
@@ -58,14 +58,15 @@ protected:
   std::vector<mkldnn::primitive> pipelineFwd_;
   std::vector<mkldnn::primitive> pipelineBwd_;
-  /// value and grad are seperated as internal and external buffers.
-  /// each MKLDNNLayer must init or reset internal buffer at least,
-  /// and the external buffer format is always nchw of nc(when h==w==1),
-  /// which is the same format as paddle.
-  /// The output_.value and output_.grad always save the external data,
-  /// when mixed with cpu device.
-  /// When all layers are mkldnn layers, they could save internal data.
-  /// below MKLDNNMatrix buffers are all internal buffers
+  /* Value and grad are separated as internal and external buffers.
+   * Each MKLDNNLayer must init or reset internal buffer at least,
+   * and the external buffer format is always nchw or nc (when h==w==1),
+   * which is the same format as paddle.
+   * The output_.value and output_.grad always save the external data,
+   * when mixed with cpu device.
+   * When all layers are mkldnn layers, they could save internal data.
+   */
+  // below MKLDNNMatrix buffers are all internal buffers
   MKLDNNMatrixPtr inVal_;
   MKLDNNMatrixPtr inGrad_;
   MKLDNNMatrixPtr outVal_;
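The rewritten comment block states the core MKLDNNLayer buffer design: external buffers keep paddle's nchw/nc layout, internal buffers use whatever format the mkldnn primitive prefers, and the two are bridged by a reorder only when their primitive descs differ. A minimal sketch of that idea against the mkldnn 0.x C++ API in use at the time (the helper name is illustrative, not the layer's actual method):

    #include <vector>
    #include "mkldnn.hpp"

    // Give the primitive a memory in its preferred (internal) layout. If the
    // external, paddle-layout memory already matches, share it; otherwise
    // allocate an internal buffer and queue a reorder into the pipeline.
    static mkldnn::memory toInternal(const mkldnn::memory& ext,
                                     const mkldnn::memory::primitive_desc& intPD,
                                     std::vector<mkldnn::primitive>& pipeline) {
      if (ext.get_primitive_desc() == intPD) {
        return ext;                    // same layout: no copy needed
      }
      mkldnn::memory internal(intPD);  // mkldnn-chosen layout, e.g. nChw8c
      pipeline.push_back(mkldnn::reorder(ext, internal));
      return internal;
    }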
@@ -120,8 +121,8 @@ public:
   ~MKLDNNLayer() {}
   virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
-  void forward(PassType passType) override;
-  void backward(const UpdateCallback& callback) override;
+  virtual void forward(PassType passType);
+  virtual void backward(const UpdateCallback& callback);

   /**
    * reshape the input image sizes
@@ -217,7 +218,7 @@ protected:
    * reset output grad from internal primitive desc.
    * merge grad if necessary.
    * reset both internal and external buffer and create reorder if necessary.
-   * note: about merge grad, when this layer has serval outputs,
+   * note: about merge grad, when this layer has several outputs,
    * it could not be mixed with cpu device,
    * since it can not get memory desc from cpu device.
    */
@@ -225,7 +226,7 @@ protected:
   /**
    * reset the merge grad primitive if necessary.
-   * note: do not support the grads are mixed with cpu device,
+   * note: do not support the grads mixed with cpu device,
    * since it can not get memory desc from cpu device.
    */
   void resetMergeGrad(MKLDNNMatrixPtr& out);
@@ -313,17 +314,17 @@ protected:
    * print the mkldnn memory format of grad
    */
   virtual void printGradFormat() {
-    if (extInGrad_) {
-      VLOG(MKLDNN_FMTS) << extInGrad_->getFormat() << " <<< ";
-    }
-    if (inGrad_) {
-      VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<<";
+    if (extOutGrad_) {
+      VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
     }
     if (outGrad_) {
       VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
     }
-    if (extOutGrad_) {
-      VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
+    if (inGrad_) {
+      VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<<";
+    }
+    if (extInGrad_) {
+      VLOG(MKLDNN_FMTS) << extInGrad_->getFormat() << " <<< ";
     }
     if (wgtGrad_) {
       VLOG(MKLDNN_FMTS) << "Weight grad format: " << wgtGrad_->getFormat();
......
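The printGradFormat hunk reorders the blocks rather than rewording them: grads are now logged from the output side back to the input side, matching the direction gradients flow in the backward pass. A hypothetical reading (formats invented for illustration) shows the chain:

    // Conceptual output of the reordered VLOGs during backward, output -> input:
    //   nchw          (extOutGrad_: external output grad, paddle format)
    //   nChw8c <<<    (outGrad_:    internal output grad)
    //   nChw8c <<<    (inGrad_:     internal input grad)
    //   nchw <<<      (extInGrad_:  external input grad, paddle format)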