提交 07cde439 编写于 作者: H hedaoyuan

Reconstruction of GemmConv based on the new im2col.

上级 eb0c7e5e
...@@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "GemmConvOp.h" #include "ConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "Im2Col.h"
#include "paddle/math/MemoryHandle.h" #include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {
/*
 * CPU im2col: unfolds an image into a column buffer.
 *
 * imData  = [input_channels, input_height, input_width]
 * colData = [input_channels, filter_height, filter_width,
 *            output_height, output_width]
 *
 * Every element of colData is written: filter taps that land in the padding
 * border are set to T(0), all other taps are copied from imData.
 */
template <class T>
class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
public:
  void operator()(const T* imData,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth,
                  int outputHeight,
                  int outputWidth,
                  T* colData) {
    // One row of the column buffer per (channel, filter row, filter column).
    const int colRows = inputChannels * filterHeight * filterWidth;
    for (int row = 0; row < colRows; ++row) {
      // Split the flat row index back into its three components.
      const int kw = row % filterWidth;
      const int kh = (row / filterWidth) % filterHeight;
      const int channel = row / filterWidth / filterHeight;

      for (int oh = 0; oh < outputHeight; ++oh) {
        // Source row expressed in unpadded image coordinates.
        const int srcRow = oh * strideHeight + kh - paddingHeight;
        const bool rowInside = srcRow >= 0 && srcRow < inputHeight;

        for (int ow = 0; ow < outputWidth; ++ow) {
          const int srcCol = ow * strideWidth + kw - paddingWidth;
          const int colIdx = (row * outputHeight + oh) * outputWidth + ow;

          if (rowInside && srcCol >= 0 && srcCol < inputWidth) {
            const int imIdx =
                (channel * inputHeight + srcRow) * inputWidth + srcCol;
            colData[colIdx] = imData[imIdx];
          } else {
            // The tap falls into the zero-padding region.
            colData[colIdx] = T(0);
          }
        }
      }
    }
  }
};
/*
 * CPU col2im: folds a column buffer back into an image.
 *
 * colData = [input_channels, filter_height, filter_width,
 *            output_height, output_width]
 * imData  = [input_channels, input_height, input_width]
 *
 * Values are accumulated into imData with +=; imData is not cleared here.
 * Column entries whose filter tap lies in the padding border are ignored.
 */
template <class T>
class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
public:
  void operator()(const T* colData,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth,
                  int outputHeight,
                  int outputWidth,
                  T* imData) {
    // One row of the column buffer per (channel, filter row, filter column).
    const int colRows = inputChannels * filterHeight * filterWidth;
    for (int row = 0; row < colRows; ++row) {
      // Split the flat row index back into its three components.
      const int kw = row % filterWidth;
      const int kh = (row / filterWidth) % filterHeight;
      const int channel = row / filterWidth / filterHeight;

      for (int oh = 0; oh < outputHeight; ++oh) {
        // Destination row expressed in unpadded image coordinates.
        const int dstRow = oh * strideHeight + kh - paddingHeight;
        if (dstRow < 0 || dstRow >= inputHeight) {
          continue;  // the whole output row maps into vertical padding
        }

        for (int ow = 0; ow < outputWidth; ++ow) {
          const int dstCol = ow * strideWidth + kw - paddingWidth;
          if (dstCol < 0 || dstCol >= inputWidth) {
            continue;  // this tap maps into horizontal padding
          }

          const int imIdx =
              (channel * inputHeight + dstRow) * inputWidth + dstCol;
          const int colIdx = (row * outputHeight + oh) * outputWidth + ow;
          imData[imIdx] += colData[colIdx];
        }
      }
    }
  }
};
/* /*
* \brief Forward calculation of convolution. * \brief Forward calculation of convolution.
*/ */
...@@ -155,15 +67,20 @@ public: ...@@ -155,15 +67,20 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
TensorShape imShape =
size_t size = inputChannels / groups_ * filterHeight * filterWidth * TensorShape({inputChannels / groups_, inputHeight, inputWidth});
outputHeight * outputWidth; TensorShape colShape = TensorShape({inputChannels / groups_,
resizeBuffer<Device>(size); filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
...@@ -171,18 +88,13 @@ public: ...@@ -171,18 +88,13 @@ public:
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
inputChannels / groups_, imShape,
inputHeight, colData,
inputWidth, colShape,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
...@@ -249,15 +161,20 @@ public: ...@@ -249,15 +161,20 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>(); real* inputGrad = outputs[0].data<real>();
TensorShape imShape =
size_t size = inputChannels / groups_ * filterHeight * filterWidth * TensorShape({inputChannels / groups_, inputHeight, inputWidth});
outputHeight * outputWidth; TensorShape colShape = TensorShape({inputChannels / groups_,
resizeBuffer<Device>(size); filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Col2ImFunctor<Device, real> col2im; Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
...@@ -280,20 +197,14 @@ public: ...@@ -280,20 +197,14 @@ public:
0.0f, 0.0f,
colData, colData,
N); N);
col2im(inputGrad + g * inputOffset,
col2im(colData, imShape,
inputChannels / groups_, colData,
inputHeight, colShape,
inputWidth,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
inputGrad + g * inputOffset);
} }
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth; outputGrad += outputChannels * outputHeight * outputWidth;
...@@ -347,33 +258,33 @@ public: ...@@ -347,33 +258,33 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>(); real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
TensorShape imShape =
size_t size = inputChannels / groups_ * filterHeight * filterWidth * TensorShape({inputChannels / groups_, inputHeight, inputWidth});
outputHeight * outputWidth; TensorShape colShape = TensorShape({inputChannels / groups_,
resizeBuffer<Device>(size); filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
inputChannels / groups_, imShape,
inputHeight, colData,
inputWidth, colShape,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册