提交 2acb84fe 编写于 作者: H hedaoyuan

Add ImageExpandGrad Function.

上级 61aa1098
......@@ -44,6 +44,7 @@ enum ColFormat { kCFO = 0, kOCF = 1 };
* input_channels,
* filter_height,
* filter_width]
* TODO(hedaoyuan): Refactor the arguments of the interface with TensorShape.
*/
template <ColFormat Format, DeviceType Device, class T>
class Im2ColFunctor {
......
......@@ -70,16 +70,67 @@ public:
}
};
template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
for (int outputH = 0; outputH < outputHeight; ++outputH) {
for (int outputW = 0; outputW < outputWidth; ++outputW) {
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels +
channel) *
filterHeight +
filterH) *
filterWidth +
filterW;
if (imRowOffset >= 0 && imRowOffset < inputHeight &&
imColOffset >= 0 && imColOffset < inputWidth) {
int imDataOffset =
(channel * inputHeight + imRowOffset) * inputWidth +
imColOffset;
imData[imDataOffset] += colData[colDataOffset];
}
}
}
}
}
}
}
};
/*
* \brief Converts the image data of four dimensions(NCHW) into
* a sequence data of three dimensions(NST). Where N is batch size,
* S is the length of the sequence after each image is expanded,
* T is the size of each time step in the sequence.
* a sequence data of three dimensions(NST) in the forward calculation,
* which is reversed in the backward calculation.
* Where N is batch size, S is the length of the sequence after each
* image is expanded, T is the size of each time step in the sequence.
*
* Arguments in forward function:
* \param inputs[0] Image data of NCHW format.
* \param outputs[0] Sequence data of NST format.
*
* Arguments in backward function:
* \param inputs[0] Sequence data of NST format.
* \param outputs[0] Image data of NCHW format.
*/
template <DeviceType Device>
class ImageExpandFunction : public FunctionBase {
public:
void init(const FuncConfig& config) override {
......@@ -93,25 +144,27 @@ public:
numOutputs_ = 1;
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& input = inputs[0].shape();
const TensorShape& output = outputs[0].shape();
// input argument should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4);
// output argument should be 3-dimensional.
CHECK_EQ(output.ndims(), (size_t)3);
// The batchSize of the input needs to be equal to
// the batchSize of the output.
CHECK_EQ(input[0], output[0]);
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t seqLength = output[1];
size_t stepSize = output[2];
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
void check(const TensorShape& image, const TensorShape& sequence) {
// image shape should be 4-dimensional.
CHECK_EQ(image.ndims(), (size_t)4);
// sequence shape should be 3-dimensional.
CHECK_EQ(sequence.ndims(), (size_t)3);
// The batchSize of the image needs to be equal to
// the batchSize of the sequence.
CHECK_EQ(image[0], sequence[0]);
}
// Calculate the shape of colData based on the shape of the image
// and the shape of the sequence.
TensorShape getColShape(const TensorShape& image,
const TensorShape& sequence) {
size_t inputChannels = image[1];
size_t inputHeight = image[2];
size_t inputWidth = image[3];
size_t seqLength = sequence[1];
size_t stepSize = sequence[2];
size_t outputHeight =
1 +
(inputHeight + 2 * paddingH() - blockH() + strideH() - 1) / strideH();
......@@ -121,8 +174,59 @@ public:
CHECK_EQ(seqLength, outputHeight * outputWidth);
CHECK_EQ(stepSize, inputChannels * blockH() * blockW());
real* inputData = inputs[0].data<real>();
real* outputData = outputs[0].data<real>();
// [output_height, output_width,
// input_channels, filter_height, filter_width]
return TensorShape({outputHeight,
outputWidth,
inputChannels,
(size_t)blockH(),
(size_t)blockW()});
}
protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
std::vector<size_t> blocks_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
inline int blockH() const { return blocks_[0]; }
inline int blockW() const { return blocks_[1]; }
};
template <DeviceType Device>
class ImageExpandForward : public ImageExpandFunction {
public:
void init(const FuncConfig& config) override {
ImageExpandFunction::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
check(image, sequence);
TensorShape colShape = getColShape(image, sequence);
size_t batchSize = image[0];
size_t inputChannels = image[1];
size_t inputHeight = image[2];
size_t inputWidth = image[3];
size_t seqLength = sequence[1];
size_t stepSize = sequence[2];
size_t outputHeight = colShape[0];
size_t outputWidth = colShape[1];
real* imageData = inputs[0].data<real>();
real* seqData = outputs[0].data<real>();
Im2ColFunctor<kOCF, Device, real> im2col;
for (size_t i = 0; i < batchSize; i++) {
// The result of im2col is [output_height, output_width,
......@@ -130,7 +234,7 @@ public:
// reshape into [seqLength, stepSize], where seqLength is equal
// output_height * output_width, stepSize is equal
// input_channels * filter_height * filter_width
im2col(inputData,
im2col(imageData,
inputChannels,
inputHeight,
inputWidth,
......@@ -142,30 +246,64 @@ public:
paddingW(),
outputHeight,
outputWidth,
outputData);
inputData += inputChannels * inputHeight * inputWidth;
outputData += seqLength * stepSize;
seqData);
imageData += inputChannels * inputHeight * inputWidth;
seqData += seqLength * stepSize;
}
}
};
protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
std::vector<size_t> blocks_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
template <DeviceType Device>
class ImageExpandBackward : public ImageExpandFunction {
public:
void init(const FuncConfig& config) override {
ImageExpandFunction::init(config);
}
inline int paddingW() const { return paddings_[1]; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
// Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
check(image, sequence);
inline int blockH() const { return blocks_[0]; }
TensorShape colShape = getColShape(image, sequence);
size_t batchSize = image[0];
size_t inputChannels = image[1];
size_t inputHeight = image[2];
size_t inputWidth = image[3];
size_t seqLength = sequence[1];
size_t stepSize = sequence[2];
size_t outputHeight = colShape[0];
size_t outputWidth = colShape[1];
inline int blockW() const { return blocks_[1]; }
real* imageData = outputs[0].data<real>();
real* seqData = inputs[0].data<real>();
Col2ImFunctor<kOCF, Device, real> col2im;
for (size_t i = 0; i < batchSize; i++) {
col2im(seqData,
inputChannels,
inputHeight,
inputWidth,
blockH(),
blockW(),
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
imageData);
imageData += inputChannels * inputHeight * inputWidth;
seqData += seqLength * stepSize;
}
}
};
REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandFunction);
REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward);
REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward);
} // namespace paddle
......@@ -47,6 +47,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
createFunction(backward_,
"ImageExpandGrad",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
}
return true;
......@@ -126,12 +132,12 @@ void BlockExpandLayer::forward(PassType passType) {
}
start[batchSize] = batchSize * blockNum;
if (!useGpu_) {
TensorShape inputShape({batchSize, channels_, imgSizeH_, imgSizeW_});
TensorShape outputShape({batchSize, blockNum, blockSize});
inputShape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
outputShape_ = TensorShape({batchSize, blockNum, blockSize});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inputShape);
outputs.addArg(*getOutputValue(), outputShape, ASSIGN_TO);
inputs.addArg(*getInputValue(0), inputShape_);
outputs.addArg(*getOutputValue(), outputShape_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
}
......@@ -144,6 +150,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
if (!preGrad) {
return;
}
if (useGpu_) {
MatrixPtr grad = getOutputGrad();
MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
size_t batchSize = preGrad->getHeight();
......@@ -180,6 +188,13 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
1.0,
1.0);
}
} else {
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_);
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
} // namespace paddle
......@@ -53,6 +53,9 @@ protected:
/// auxiliary variable, which saves the transposed output value.
MatrixPtr outVTrans_;
TensorShape inputShape_;
TensorShape outputShape_;
public:
explicit BlockExpandLayer(const LayerConfig& config) : Layer(config) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册