提交 d99faf31 编写于 作者: H hedaoyuan

Add the calculation implementation of GemmConvGradInputFunction.

上级 90326198
...@@ -78,12 +78,10 @@ public: ...@@ -78,12 +78,10 @@ public:
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run(); test.run();
} else if (type == BACKWARD_INPUT_TEST) { } else if (type == BACKWARD_INPUT_TEST) {
#if 0
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input)); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.run(); test.run();
#endif
} else if (type == BACKWARD_FILTER_TEST) { } else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
...@@ -111,6 +109,11 @@ TEST(Forward, GEMM2) { ...@@ -111,6 +109,11 @@ TEST(Forward, GEMM2) {
"GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST); "GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
} }
// Input-gradient test for the GEMM-based convolution: builds a ConvolutionTest
// over the CPU and GPU "GemmConvGradInput" functions in BACKWARD_INPUT_TEST
// mode. The object is not used after construction, so all of the work is
// triggered by constructing it.
TEST(BackwardInput, GEMM) {
  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> gradInputCheck(
      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", BACKWARD_INPUT_TEST);
}
TEST(BackwardFilter, GEMM) { TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST); "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
......
...@@ -44,22 +44,62 @@ public: ...@@ -44,22 +44,62 @@ public:
for (int c = 0; c < channelsCol; ++c) { for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth; int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight; int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterHeight / filterWidth; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
// no c_im*height to Exclude the channel number int imRowIdx = h * strideHeight + hOffset;
int imgRowIdx = h * strideHeight + hOffset; int imColIdx = w * strideWidth + wOffset;
int imgColIdx = w * strideWidth + wOffset; if ((imRowIdx - paddingHeight) < 0 ||
if ((imgRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight ||
(imgRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 ||
(imgColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) >= inputWidth) {
(imgColIdx - paddingWidth) >= inputWidth) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0); colData[(c * outputHeight + h) * outputWidth + w] = T(0);
} else { } else {
imgRowIdx += c_im * inputHeight - paddingHeight; imRowIdx += c_im * inputHeight - paddingHeight;
imgColIdx -= paddingWidth; imColIdx -= paddingWidth;
colData[(c * outputHeight + h) * outputWidth + w] = colData[(c * outputHeight + h) * outputWidth + w] =
imData[imgRowIdx * inputWidth + imgColIdx]; imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
}
};
template <class T>
class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
(imColIdx - paddingWidth) < inputWidth) {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
imData[imRowIdx * inputWidth + imColIdx] +=
colData[(c * outputHeight + h) * outputWidth + w];
} }
} }
} }
...@@ -171,10 +211,74 @@ public: ...@@ -171,10 +211,74 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& outputGrad = inputs[0].shape(); // CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& inputGrad = outputs[0].shape(); const TensorShape& input = outputs[0].shape();
check(inputGrad, filter, outputGrad); check(input, filter, output);
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = filter[2];
size_t filterWidth = filter[3];
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Col2ImFunctor<Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
int K = outputChannels / groups_;
int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
M,
outputGrad + g * outputOffset,
N,
0.0f,
colData,
N);
col2im(colData,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
inputGrad + g * inputOffset);
}
inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
} }
}; };
...@@ -191,12 +295,18 @@ public: ...@@ -191,12 +295,18 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
const TensorShape& output = inputs[0].shape(); const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape(); const TensorShape& input = inputs[1].shape();
const TensorShape& filter = outputs[0].shape(); const TensorShape& filter = outputs[0].shape();
check(input, filter, output); check(input, filter, output);
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = input[0]; size_t batchSize = input[0];
size_t inputChannels = input[1]; size_t inputChannels = input[1];
size_t inputHeight = input[2]; size_t inputHeight = input[2];
...@@ -251,7 +361,7 @@ public: ...@@ -251,7 +361,7 @@ public:
K, K,
colData, colData,
K, K,
1.0f, i == 0 ? beta : 1.0f,
filterGrad + g * filterOffset, filterGrad + g * filterOffset,
N); N);
} }
......
...@@ -41,4 +41,22 @@ public: ...@@ -41,4 +41,22 @@ public:
T* colData); T* colData);
}; };
/**
 * Primary template of the column-buffer-to-image functor.
 *
 * Only the call operator is declared here; its definition is supplied by the
 * per-device specializations (DEVICE_TYPE_CPU / DEVICE_TYPE_GPU).
 *
 * @param colData        read-only column buffer.
 * @param inputChannels  number of image channels handled by the call.
 * @param inputHeight    image height, without padding.
 * @param inputWidth     image width, without padding.
 * @param filterHeight   filter (kernel window) height.
 * @param filterWidth    filter (kernel window) width.
 * @param strideHeight   vertical stride of the window.
 * @param strideWidth    horizontal stride of the window.
 * @param paddingHeight  vertical padding applied on the image.
 * @param paddingWidth   horizontal padding applied on the image.
 * @param outputHeight   number of window positions along the height.
 * @param outputWidth    number of window positions along the width.
 * @param imData         image buffer that receives the result.
 */
template <DeviceType Device, typename T>
class Col2ImFunctor {
public:
  void operator()(const T* colData,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth,
                  int outputHeight,
                  int outputWidth,
                  T* imData);
};
} // namespace paddle } // namespace paddle
...@@ -87,7 +87,100 @@ public: ...@@ -87,7 +87,100 @@ public:
} }
}; };
/**
 * CUDA kernel that folds a column buffer back into an image buffer.
 *
 * Each thread with index < n handles one element of the padded image
 * (height x width already include the padding). Threads that land on the
 * padding border do nothing. Every other thread sums all column entries that
 * map to its pixel and adds that sum to the matching element of data_im,
 * which is indexed without padding.
 *
 * @param n           number of padded-image elements to process.
 * @param data_col    column buffer to read from.
 * @param height      padded image height.
 * @param width       padded image width.
 * @param channels    channel count (not read by the kernel body).
 * @param blockH      filter height.
 * @param blockW      filter width.
 * @param strideH     vertical stride.
 * @param strideW     horizontal stride.
 * @param paddingH    vertical padding.
 * @param paddingW    horizontal padding.
 * @param height_col  number of window positions along the height.
 * @param width_col   number of window positions along the width.
 * @param data_im     unpadded image buffer; results are accumulated (+=).
 */
template <class T>
__global__ void col2im(size_t n,
                       const T* data_col,
                       size_t height,
                       size_t width,
                       size_t channels,
                       size_t blockH,
                       size_t blockW,
                       size_t strideH,
                       size_t strideW,
                       size_t paddingH,
                       size_t paddingW,
                       size_t height_col,
                       size_t width_col,
                       T* data_im) {
  size_t index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= n) {
    return;
  }

  // Position of this thread inside the padded image.
  const int w = static_cast<int>(index % width);
  const int h = static_cast<int>((index / width) % height);
  const int c = static_cast<int>(index / (width * height));

  // Extents of the real (unpadded) image.
  const size_t imWidth = width - 2 * paddingW;
  const size_t imHeight = height - 2 * paddingH;

  // Same position expressed in unpadded coordinates; negative means the
  // thread sits on the top/left padding border.
  const int imCol = w - static_cast<int>(paddingW);
  const int imRow = h - static_cast<int>(paddingH);
  const bool insideImage =
      imCol >= 0 && static_cast<size_t>(imCol) < imWidth &&
      imRow >= 0 && static_cast<size_t>(imRow) < imHeight;
  if (!insideImage) {
    return;
  }

  const int filterH = static_cast<int>(blockH);
  const int filterW = static_cast<int>(blockW);
  const int stepH = static_cast<int>(strideH);
  const int stepW = static_cast<int>(strideW);

  // Half-open range of window positions whose footprint covers (h, w).
  const int wColBegin = (w < filterW) ? 0 : (w - filterW) / stepW + 1;
  const int wColEnd =
      min(static_cast<int>(w / stepW + 1), static_cast<int>(width_col));
  const int hColBegin = (h < filterH) ? 0 : (h - filterH) / stepH + 1;
  const int hColEnd =
      min(static_cast<int>(static_cast<size_t>(h) / strideH + 1),
          static_cast<int>(height_col));

  // First column-buffer row belonging to channel c.
  const int channelBase = static_cast<int>(c * blockH * blockW);

  T val = 0;
  for (int hCol = hColBegin; hCol < hColEnd; ++hCol) {
    for (int wCol = wColBegin; wCol < wColEnd; ++wCol) {
      // Row of the column buffer: channel base plus the offset of (h, w)
      // inside the window placed at (hCol, wCol).
      const int cCol = channelBase + (h - hCol * stepH) * filterW +
                       (w - wCol * stepW);
      val += data_col[(cCol * height_col + hCol) * width_col + wCol];
    }
  }

  data_im[c * (imWidth * imHeight) + imRow * imWidth + imCol] += val;
}
template <class T>
class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight)
* (inputWidth + 2*paddingWidth);
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels,
colData,
inputHeight + 2*paddingHeight,
inputWidth + 2*paddingWidth,
inputChannels,
filterHeight,
filterWidth,
strideHeight,
strideWidth,
paddingHeight,
paddingWidth,
outputHeight,
outputWidth,
imData);
CHECK_SYNC("Col2ImFunctor GPU failed");
}
};
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>; template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>; template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册