From 9acfba82a37d06aeafaaacccc30b6e2df56354ed Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 16 Nov 2017 11:46:31 +0800 Subject: [PATCH] add input index choice for mkldnn_concat --- paddle/gserver/layers/MKLDNNLayer.cpp | 7 +++++-- paddle/gserver/layers/MKLDNNLayer.h | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index e75ac5ba464..0d063a89cc5 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -138,8 +138,11 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) { } } -void MKLDNNLayer::reshapeInput(int& batchsize, int& height, int& width) { - const Argument& input = inputLayers_[0]->getOutput(); +void MKLDNNLayer::reshapeInput(int& batchsize, + int& height, + int& width, + size_t inputIdx) { + const Argument& input = inputLayers_[inputIdx]->getOutput(); batchsize = input.getBatchSize(); int h = input.getFrameHeight(); int w = input.getFrameWidth(); diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 7479c34c92b..4c42df1bee7 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -178,7 +178,10 @@ protected: /** * reshape the input image sizes and input batchsize */ - void reshapeInput(int& batchsize, int& height, int& width); + void reshapeInput(int& batchsize, + int& height, + int& width, + size_t inputIdx = 0); /** * reshape output image sizes -- GitLab