diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index f577616230be65e9581cf8f3ed5f63a77c7c3e21..9b0ae20f089e34a719883bc65e88e33ab9334e39 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -216,17 +216,13 @@ void MKLDNNBatchNormLayer::resetFwdPD( } auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_); pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_)); - // TODO(TJ): use check macro - CHECK(out); - CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc()); if (wgt) { - CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(wgt, pd->weights_primitive_desc()); } if (passType_ != PASS_TEST || useGlobalStats_) { - CHECK(mean_); - CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); - CHECK(var_); - CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc()); } } @@ -283,19 +279,14 @@ void MKLDNNBatchNormLayer::resetBwdPD( if (in == nullptr) { return; } - CHECK(out); - CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc()); + CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc()); auto md = in->getMemoryDesc(); auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_); pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); - // TODO(TJ): use check macro - CHECK(wgt); - CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc()); CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc()); - CHECK(mean_); - CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); - CHECK(var_); - CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc()); } void MKLDNNBatchNormLayer::resetBwdPipeline( diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 83f4e4e6151d727b3e6cf367bb7ecae55dd7df73..b8120eda1e2dadab943869a05546351a369af6fd 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -262,12 +262,15 @@ void MKLDNNConvLayer::resetBwdWgtPD( padR, padKind); pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); - CHECK(pd->src_primitive_desc() == inVal_->getPrimitiveDesc()) - << "primitive desc of in value should equal"; - CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) - << "primitive desc of out grad should equal the out value"; - CHECK(pd->diff_weights_primitive_desc() == wgtVal_->getPrimitiveDesc()) - << "primitive desc of weight grad should equal the weight value"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, pd->src_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ( + outVal_, + pd->diff_dst_primitive_desc(), + "primitive desc of out value and grad should be equal"); + CHECK_PRIMITIVE_DESC_EQ( + wgtVal_, + pd->diff_weights_primitive_desc(), + "primitive desc of weight value and grad should be equal"); } void MKLDNNConvLayer::resetBwdDataPD( @@ -292,10 +295,14 @@ void MKLDNNConvLayer::resetBwdDataPD( padR, padding_kind::zero); pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); - CHECK(pd->diff_src_primitive_desc() == inVal_->getPrimitiveDesc()) - << "primitive desc of in grad should equal the in value"; - CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) - << "primitive desc of out grad should equal"; + CHECK_PRIMITIVE_DESC_EQ( + inVal_, + pd->diff_src_primitive_desc(), + "primitive desc of in value and grad should be equal"); + CHECK_PRIMITIVE_DESC_EQ( + outVal_, + pd->diff_dst_primitive_desc(), + "primitive desc of out value and grad should be equal"); } void MKLDNNConvLayer::resetBwdBuffers( @@ -310,17 +317,20 @@ void MKLDNNConvLayer::resetBwdBuffers( resetWithMatrix( wgt, weight_->getWGrad(), wgtPD->diff_weights_primitive_desc()); - CHECK(wgtVal_ != nullptr && - wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc()) - << "primitive desc of weight grad and value should be equal"; + CHECK_PRIMITIVE_DESC_EQ( + wgtVal_, + wgt->getPrimitiveDesc(), + "primitive desc of weight grad and value should be equal"); bias = nullptr; if (biases_ && biases_->getWGrad()) { resetWithMatrix( bias, biases_->getWGrad(), wgtPD->diff_bias_primitive_desc()); - CHECK(bias && biasVal_ && - bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc()) - << "primitive desc of bias grad should equal the bias value"; + CHECK(bias); + CHECK_PRIMITIVE_DESC_EQ( + biasVal_, + bias->getPrimitiveDesc(), + "primitive desc of bias grad and value should be equal"); } if (dataPD == nullptr) { diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 6bb19976b5552fcd2e420f03de45c77a90ffb9d2..663a10509857ec9fb487c1cda1621bdfac1250ac 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -235,8 +235,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, in = MKLDNNMatrix::create(intPD, inMat); Argument& arg = input->getOutput(this->getName()); arg.grad = std::dynamic_pointer_cast(in); - CHECK(inVal_); - CHECK(inVal_->getPrimitiveDesc() == intPD) << "the primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); if (inputIsOnlyMKLDNN()) { return; } @@ -250,8 +249,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat())) << "should have external input value and the format must be nchw(nc)"; extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat); - CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD) - << "should have internal input value and primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); in = MKLDNNMatrix::create(intPD); cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_); CHECK(cvtInGrad_); @@ -277,8 +275,7 @@ void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out, CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat())) << "should have external output value and the format must be nchw(nc)"; extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat); - CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD) - << "should have internal output value and primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(outVal_, intPD); out = MKLDNNMatrix::create(intPD); cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out); CHECK(cvtOutGrad_); diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 2b62d4e11ac7276924947ab47360ffca84240aea..5f5b819017b83579ce58522198b3f13311297d42 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -24,6 +24,12 @@ namespace paddle { class MKLDNNMatrix; typedef std::shared_ptr MKLDNNMatrixPtr; +#define CHECK_PRIMITIVE_DESC_EQ(MAT, PD, ...) \ + CHECK(MAT) << " can not be empty."; \ + CHECK(MAT->getPrimitiveDesc() == PD) \ + << #MAT "->getPrimitiveDesc() and " #PD " should be equal.\n " \ + << "" __VA_ARGS__; + /** * @brief MKLDNN Matrix. *