Commit f8ef8c17 authored by hedaoyuan

Add the GPU version implementation of ImageExpandGrad function.

Parent 152bd2f9
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "Im2Col.h"
#include "hl_device_functions.cuh"

namespace paddle {
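The new include pulls in hl_device_functions.cuh for paddle::paddleAtomicAdd, which the col2imOCF kernel below uses to accumulate gradients. Its implementation is not shown in this diff; for double precision such an atomic add is commonly emulated with an atomicCAS loop, roughly as in the following sketch (atomicAddDouble is illustrative only, not the header's actual code):

// Sketch only: the standard CAS-loop emulation of atomicAdd for double on
// devices without native double atomics. paddle::paddleAtomicAdd is assumed
// to provide equivalent behavior; this is not its actual source.
__device__ inline double atomicAddDouble(double* address, double val) {
  unsigned long long int* addr =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr, assumed;
  do {
    assumed = old;
    old = atomicCAS(addr, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}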
@@ -25,30 +26,29 @@ void im2colOCF(const T* imData, T* colData,
                int strideHeight, int strideWidth,
                int paddingHeight, int paddingWidth,
                int outputHeight, int outputWidth) {
-  int idx = threadIdx.x;
-  int idy = threadIdx.y;
   int swId = blockIdx.x;
   int shId = blockIdx.y;
   for (int channelId = threadIdx.z;
        channelId < inputChannels;
        channelId += blockDim.z) {
-    int widthOffset = idx + swId * strideWidth - paddingWidth;
-    int heightOffset = idy + shId * strideHeight - paddingHeight;
-    int imOffset = widthOffset + heightOffset * inputWidth
-        + channelId * inputHeight * inputWidth;
-
-    int colOffset = idx + idy * filterWidth
-        + channelId * filterHeight * filterWidth
-        + (shId * outputWidth + swId)
-           * (inputChannels * filterHeight * filterWidth);
-
-    if (idx < filterWidth && idy < filterHeight) {
-      if (heightOffset >= inputHeight || heightOffset < 0 ||
-          widthOffset >= inputWidth || widthOffset < 0) {
-        colData[colOffset] = T(0);
-      } else {
-        colData[colOffset] = imData[imOffset];
+    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
+      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
+        int widthOffset = idx + swId * strideWidth - paddingWidth;
+        int heightOffset = idy + shId * strideHeight - paddingHeight;
+        int imOffset = widthOffset + heightOffset * inputWidth
+            + channelId * inputHeight * inputWidth;
+
+        int colOffset = idx + idy * filterWidth
+            + channelId * filterHeight * filterWidth
+            + (shId * outputWidth + swId)
+               * (inputChannels * filterHeight * filterWidth);
+
+        if (heightOffset >= inputHeight || heightOffset < 0 ||
+            widthOffset >= inputWidth || widthOffset < 0) {
+          colData[colOffset] = T(0);
+        } else {
+          colData[colOffset] = imData[imOffset];
+        }
       }
     }
   }
 }
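The restructured kernel above moves the filter indices idx and idy into strided loops, so a thread block no longer has to span at least filterWidth x filterHeight threads; each thread now sweeps the part of the window assigned to it. The index math itself is unchanged. As a plain C++ reference (a sketch for clarity, not code from this commit), the OCF column layout it produces is [outputHeight][outputWidth][inputChannels][filterHeight][filterWidth]:

// Reference CPU im2col in OCF layout, mirroring the kernel's index math.
template <class T>
void im2colOCFReference(const T* imData, T* colData,
                        int inputChannels, int inputHeight, int inputWidth,
                        int filterHeight, int filterWidth,
                        int strideHeight, int strideWidth,
                        int paddingHeight, int paddingWidth,
                        int outputHeight, int outputWidth) {
  for (int shId = 0; shId < outputHeight; ++shId) {
    for (int swId = 0; swId < outputWidth; ++swId) {
      for (int channelId = 0; channelId < inputChannels; ++channelId) {
        for (int idy = 0; idy < filterHeight; ++idy) {
          for (int idx = 0; idx < filterWidth; ++idx) {
            int widthOffset = idx + swId * strideWidth - paddingWidth;
            int heightOffset = idy + shId * strideHeight - paddingHeight;
            int imOffset = widthOffset + heightOffset * inputWidth +
                           channelId * inputHeight * inputWidth;
            int colOffset = idx + idy * filterWidth +
                            channelId * filterHeight * filterWidth +
                            (shId * outputWidth + swId) *
                                (inputChannels * filterHeight * filterWidth);
            // Padding region reads as zero, exactly as in the GPU kernel.
            bool inImage = heightOffset >= 0 && heightOffset < inputHeight &&
                           widthOffset >= 0 && widthOffset < inputWidth;
            colData[colOffset] = inImage ? imData[imOffset] : T(0);
          }
        }
      }
    }
  }
}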
@@ -105,6 +105,41 @@ public:
  }
};

template<class T>
__global__
void col2imOCF(T* imData, const T* colData,
int inputChannels,
int inputHeight, int inputWidth,
int filterHeight, int filterWidth,
int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth,
int outputHeight, int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z;
channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (heightOffset >= 0 && heightOffset < inputHeight &&
widthOffset >= 0 && widthOffset < inputWidth) {
paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
}
}
}
}
}
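col2imOCF is the scatter counterpart used by ImageExpandGrad: it visits the same (imOffset, colOffset) pairs as im2colOCF but writes back into imData. Whenever the stride is smaller than the filter, neighboring output positions cover overlapping image elements, so concurrent threads may target the same address; that is why the kernel accumulates with paddle::paddleAtomicAdd instead of a plain store. A small standalone helper (hypothetical, for illustration only) makes the overlap explicit:

// How many sliding windows cover image row h (ignoring padding)?  When
// strideHeight < filterHeight this exceeds 1, so col2im must sum several
// colData entries into one imData cell.
inline int windowsCoveringRow(int h, int filterHeight, int strideHeight,
                              int outputHeight) {
  int count = 0;
  for (int shId = 0; shId < outputHeight; ++shId) {
    int top = shId * strideHeight;
    if (h >= top && h < top + filterHeight) ++count;
  }
  return count;
  // e.g. filterHeight = 3, strideHeight = 1: every interior row is covered
  // by three windows, so three column entries are accumulated per element
  // along that axis.
}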
/*
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
@@ -121,10 +156,44 @@ public:
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
} else {
blockDimX = 32;
blockDimY = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
col2imOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, colData, inputChannels, inputHeight, inputWidth,
filterHeight, filterWidth, strideHeight, strideWidth,
paddingHeight, paddingWidth, outputHeight, outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed");
  }
};
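The launch configuration in Col2ImFunctor picks the smallest square block (4x4 up to 32x32) that covers the filter, then fills the remaining budget of a 1024-thread block with channels along z, capped by inputChannels. The same selection, factored into a standalone helper for illustration (pickBlockDims is hypothetical, not part of the source):

#include <algorithm>

// Mirrors the inline block-size selection in the GPU functor: square x/y
// dimensions large enough for the filter, z filling up to 1024 threads.
inline void pickBlockDims(int filterHeight, int filterWidth, int inputChannels,
                          int* x, int* y, int* z) {
  int side;
  if (filterHeight <= 4 && filterWidth <= 4) {
    side = 4;
  } else if (filterHeight <= 8 && filterWidth <= 8) {
    side = 8;
  } else if (filterHeight <= 16 && filterWidth <= 16) {
    side = 16;
  } else {
    side = 32;
  }
  *x = side;
  *y = side;
  *z = std::min(1024 / side / side, inputChannels);
}
// Example: a 5x7 filter with 3 input channels gives an 8x8x3 block;
// with 64 channels it gives 8x8x16, i.e. the full 1024 threads.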
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
}  // namespace paddle
@@ -293,6 +293,7 @@ REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward);
REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(ImageExpand, GPU, ImageExpandForward);
REGISTER_TYPED_FUNC(ImageExpandGrad, GPU, ImageExpandBackward);
#endif
}  // namespace paddle
@@ -46,14 +46,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
                     .set("strides", strides)
                     .set("paddings", paddings)
                     .set("blocks", blocks));
-  if (!useGpu_) {
-    createFunction(backward_,
-                   "ImageExpandGrad",
-                   FuncConfig()
-                       .set("strides", strides)
-                       .set("paddings", paddings)
-                       .set("blocks", blocks));
-  }
+  createFunction(backward_,
+                 "ImageExpandGrad",
+                 FuncConfig()
+                     .set("strides", strides)
+                     .set("paddings", paddings)
+                     .set("blocks", blocks));
 
   return true;
 }
@@ -110,14 +108,16 @@ void BlockExpandLayer::forward(PassType passType) {
 }
 
 void BlockExpandLayer::backward(const UpdateCallback& callback) {
-  size_t blockNum = outputH_ * outputW_;
-  size_t blockSize = blockH_ * blockW_ * channels_;
   /* Calculate the input layers error */
-  MatrixPtr preGrad = inputLayers_[0]->getOutputGrad();
-  if (!preGrad) {
-    return;
+  if (getInputGrad(0)) {
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*getOutputGrad(), outputShape_);
+    outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
+    backward_[0]->calc(inputs, outputs);
   }
 
+#if 0
   if (useGpu_) {
     MatrixPtr grad = getOutputGrad();
     MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
@@ -155,13 +155,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
                   1.0,
                   1.0);
     }
-  } else {
-    BufferArgs inputs;
-    BufferArgs outputs;
-    inputs.addArg(*getOutputGrad(), outputShape_);
-    outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
-    backward_[0]->calc(inputs, outputs);
   }
+#endif
 }
 
 }  // namespace paddle