diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index e75ac5ba4647a8267b7bc189893bd7adb5c3053f..0d063a89cc52a69604512850c9138904a49f7896 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -138,8 +138,11 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) { } } -void MKLDNNLayer::reshapeInput(int& batchsize, int& height, int& width) { - const Argument& input = inputLayers_[0]->getOutput(); +void MKLDNNLayer::reshapeInput(int& batchsize, + int& height, + int& width, + size_t inputIdx) { + const Argument& input = inputLayers_[inputIdx]->getOutput(); batchsize = input.getBatchSize(); int h = input.getFrameHeight(); int w = input.getFrameWidth(); diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 7479c34c92b5231b2521493bc631474d4efd4224..4c42df1bee75fa7b28c2001c30797cc0df7c5554 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -178,7 +178,10 @@ protected: /** * reshape the input image sizes and input batchsize */ - void reshapeInput(int& batchsize, int& height, int& width); + void reshapeInput(int& batchsize, + int& height, + int& width, + size_t inputIdx = 0); /** * reshape output image sizes