提交 f8ef8c17 编写于 作者: H hedaoyuan

Add the GPU version implementation of ImageExpandGrad function.

上级 152bd2f9
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "Im2Col.h"
#include "hl_device_functions.cuh"
namespace paddle {
......@@ -25,30 +26,29 @@ void im2colOCF(const T* imData, T* colData,
int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth,
int outputHeight, int outputWidth) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z;
channelId < inputChannels;
channelId += blockDim.z) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (idx < filterWidth && idy < filterHeight) {
if (heightOffset >= inputHeight || heightOffset < 0 ||
widthOffset >= inputWidth || widthOffset < 0) {
colData[colOffset] = T(0);
} else {
colData[colOffset] = imData[imOffset];
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (heightOffset >= inputHeight || heightOffset < 0 ||
widthOffset >= inputWidth || widthOffset < 0) {
colData[colOffset] = T(0);
} else {
colData[colOffset] = imData[imOffset];
}
}
}
}
......@@ -105,6 +105,41 @@ public:
}
};
template<class T>
__global__
void col2imOCF(T* imData, const T* colData,
int inputChannels,
int inputHeight, int inputWidth,
int filterHeight, int filterWidth,
int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth,
int outputHeight, int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z;
channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (heightOffset >= 0 && heightOffset < inputHeight &&
widthOffset >= 0 && widthOffset < inputWidth) {
paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
}
}
}
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
......@@ -121,10 +156,44 @@ public:
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
} else {
blockDimX = 32;
blockDimY = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
col2imOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, colData, inputChannels, inputHeight, inputWidth,
filterHeight, filterWidth, strideHeight, strideWidth,
paddingHeight, paddingWidth, outputHeight, outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed");
}
};
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
} // namespace paddle
......@@ -293,6 +293,7 @@ REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward);
REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(ImageExpand, GPU, ImageExpandForward);
REGISTER_TYPED_FUNC(ImageExpandGrad, GPU, ImageExpandBackward);
#endif
} // namespace paddle
......@@ -46,14 +46,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
if (!useGpu_) {
createFunction(backward_,
"ImageExpandGrad",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
}
createFunction(backward_,
"ImageExpandGrad",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
return true;
}
......@@ -110,14 +108,16 @@ void BlockExpandLayer::forward(PassType passType) {
}
void BlockExpandLayer::backward(const UpdateCallback& callback) {
size_t blockNum = outputH_ * outputW_;
size_t blockSize = blockH_ * blockW_ * channels_;
/* Calculate the input layers error */
MatrixPtr preGrad = inputLayers_[0]->getOutputGrad();
if (!preGrad) {
return;
if (getInputGrad(0)) {
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_);
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
#if 0
if (useGpu_) {
MatrixPtr grad = getOutputGrad();
MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
......@@ -155,13 +155,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
1.0,
1.0);
}
} else {
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_);
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
#endif
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册