diff --git a/paddle/gserver/layers/MKLDNNConcatLayer.cpp b/paddle/gserver/layers/MKLDNNConcatLayer.cpp index 7906e1808551a38555ea49da2be24ea30e26e84b..a3106b0c06cca4db3e46899fb347c7b8cb9639ae 100644 --- a/paddle/gserver/layers/MKLDNNConcatLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConcatLayer.cpp @@ -84,10 +84,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector& inputs, inputs.resize(inputLayers_.size()); bool has8c = false, has16c = false, hasnc = false; for (size_t i = 0; i < inputs.size(); i++) { - // resetInValue will use ic_ so temporary change as current input's channel - // TODO(TJ): change ic_ as vector then can remove channels_ - ic_ = channels_[i]; - resetInValue(inputs[i], nullptr, i); + resetInValue(inputs[i], nullptr, i, channels_[i]); CHECK(inputs[i]); auto dm = inputs[i]->getDims(); // inputs format can be different, but ndims must equal @@ -108,8 +105,6 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector& inputs, has16c = true; } } - // change back, ic_ always save the input 0 size - ic_ = channels_[0]; format outFmt; if (has16c && oc_ % 16 == 0) { diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index e22345354997032a6e06634fda41e898f41e6a01..02170ea8160d7f7ed6208d3a5144399791878ffb 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -176,13 +176,15 @@ void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn, void MKLDNNLayer::resetInValue( MKLDNNMatrixPtr& in, const std::shared_ptr& intPD, - size_t inputIdx) { + size_t inputIdx, + int inputChannel) { cvtInVal_ = nullptr; extInVal_ = nullptr; in = nullptr; - CHECK_GT(bs_ * ic_ * ih_ * iw_, 0); + inputChannel = inputChannel == 0 ? ic_ : inputChannel; + CHECK_GT(bs_ * inputChannel * ih_ * iw_, 0); auto extPD = MKLDNNMatrix::createPrimitiveDesc( - {bs_, ic_, ih_, iw_}, format::nchw, engine_); + {bs_, inputChannel, ih_, iw_}, format::nchw, engine_); const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue(); extInVal_ = std::dynamic_pointer_cast(inMat); CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr); diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index d9542bfca2c7aac918624dd3df8c27331d2080eb..0e271908099b0d0e513233a7130f1b199281dfde 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -38,6 +38,7 @@ protected: size_t inputElemenCnt_; // batch size int bs_; + // they sizes are always from the first input layer // input image channel, height and width int ic_, ih_, iw_; // output image channel, height and width @@ -196,11 +197,13 @@ protected: /** * reset input value from input MKLDNNMatrix and internal primitive desc. * reset both internal and external buffer and create reorder if necessary. + * input channel may be different in concat. */ void resetInValue( MKLDNNMatrixPtr& in, const std::shared_ptr& intPD = nullptr, - size_t inputIdx = 0); + size_t inputIdx = 0, + int inputChannel = 0); /** * reset output value from internal primitive desc.