Commit 013d0a26 authored by: W wanghaoshuang

add crop layer

Parent fa216165
......
@@ -37,6 +37,7 @@ if(WITH_GPU)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(CropOpTest)
endif()
add_simple_unittest(ConvOpTest)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CropOp.h"
#include "paddle/math/Vector.h"
#include "paddle/function/TensorShape.h"
namespace paddle {
static inline CropConf castToCropConf(const FuncConfig& conf) {
return {conf.get<std::vector<uint32_t>>("crop_corner"),
conf.get<std::vector<uint32_t>>("crop_shape")};
}
template <>
void Crop<DEVICE_TYPE_CPU>(real* outputs,
const real* inputs,
const TensorShape inShape,
const CropConf& crop) {
int cCrop = crop.corner[0];
int hCrop = crop.corner[1];
int wCrop = crop.corner[2];
int num = inShape[0];
int inC = inShape[1];
int inH = inShape[2];
int inW = inShape[3];
int outC = crop.shape[0];
int outH = crop.shape[1];
int outW = crop.shape[2];
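// Each innermost iteration copies one contiguous output row (outW elements)
// out of the cropped window of the input.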
for (int n = 0; n < num; n++) {
for (int c = 0; c < outC; c++) {
for (int h = 0; h < outH; h++) {
int outoff = ((n * outC + c) * outH + h) * outW;
int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
}
}
}
}
template <>
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
real* outGrad,
const TensorShape outShape,
const CropConf& crop) {
int cCrop = crop.corner[0];
int hCrop = crop.corner[1];
int wCrop = crop.corner[2];
int num = outShape[0];
int outC = outShape[1];
int outH = outShape[2];
int outW = outShape[3];
int inC = crop.shape[0];
int inH = crop.shape[1];
int inW = crop.shape[2];
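// Note: here inC/inH/inW are the (smaller) incoming-gradient shape and
// outC/outH/outW the full input shape; each row of the incoming gradient is
// accumulated into the corresponding window of outGrad.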
for (int n = 0; n < num; n++) {
for (int c = 0; c < inC; c++) {
for (int h = 0; h < inH; h++) {
int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
int inoff = ((n * inC + c) * inH + h) * inW;
CpuVector inG = CpuVector(inW, const_cast<real*>(inGrad + inoff));
CpuVector outG = CpuVector(inW, outGrad + outoff);
outG += inG;
}
}
}
}
/**
 * \brief Crop the input according to the specified corner and shape.
 *        The input and output are 4D tensors, and CropFunc crops only
 *        the 2nd to 4th dimensions.
 *
 * Arguments in this Function:
 * \param crop_ A struct object containing the cropping corner and shape.
 * \param inputs A 4D tensor, only one input.
 * \param outputs A 4D tensor, the output value after cropping.
 *
 * For example,
 * Input(2,2,2,3) = [
 *   [ [[1,2,3], [3,4,5]],
 *     [[2,3,5], [1,6,7]] ],
 *   [ [[4,3,1], [1,8,7]],
 *     [[3,8,9], [2,3,5]] ]
 * ]  # the input shape is (2,2,2,3)
 *
 * crop_: if crop_corner = (0,1,1) and crop_shape = (2,1,2)
 * Output(2,2,1,2) = [
 *   [ [[4,5]],
 *     [[6,7]] ],
 *   [ [[8,7]],
 *     [[3,5]] ]
 * ]  # the output shape is (2,2,1,2)
 */
template <DeviceType Device>
class CropFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
crop_ = castToCropConf(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK_EQ(outputs[0].shape()[1], crop_.shape[0]);
CHECK_EQ(outputs[0].shape()[2], crop_.shape[1]);
CHECK_EQ(outputs[0].shape()[3], crop_.shape[2]);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
TensorShape inShape = inputs[0].shape();
Crop<Device>(
outputs[0].data<real>(), inputs[0].data<real>(), inShape, crop_);
}
private:
CropConf crop_;
};
/**
 * \brief The backward propagation of the cropping Function.
 *
 * Arguments in this Function:
 * \param crop_ The same meaning as in CropFunc.
 * \param inputs The gradient with respect to the output value of CropFunc.
 * \param outputs The gradient with respect to the input value of CropFunc.
 */
template <DeviceType Device>
class CropGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
crop_ = castToCropConf(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
CHECK_EQ(inputs[0].shape()[1], crop_.shape[0]);
CHECK_EQ(inputs[0].shape()[2], crop_.shape[1]);
CHECK_EQ(inputs[0].shape()[3], crop_.shape[2]);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
TensorShape outShape = outputs[0].shape();
CropGrad<Device>(
inputs[0].data<real>(), outputs[0].data<real>(), outShape, crop_);
}
private:
CropConf crop_;
};
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif
} // namespace paddle
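The doc comment above walks through a concrete crop of a (2,2,2,3) input with crop_corner (0,1,1) and crop_shape (2,1,2). The standalone sketch below replays the same index arithmetic as the CPU Crop and CropGrad specializations outside of Paddle (plain float arrays instead of real, no TensorShape); it is an illustration of the offset formulas only, not part of the patch, and all names in it are local to the sketch.

#include <cstdio>
#include <cstring>

int main() {
  // Input (2,2,2,3) from the CropFunc doc comment, flattened in NCHW order.
  const int num = 2, inC = 2, inH = 2, inW = 3;
  const float in[] = {1, 2, 3, 3, 4, 5, 2, 3, 5, 1, 6, 7,
                      4, 3, 1, 1, 8, 7, 3, 8, 9, 2, 3, 5};
  // crop_corner = (0,1,1), crop_shape = (2,1,2) -> output (2,2,1,2).
  const int cCrop = 0, hCrop = 1, wCrop = 1;
  const int outC = 2, outH = 1, outW = 2;
  float out[num * outC * outH * outW];

  // Forward: same loop structure as Crop<DEVICE_TYPE_CPU>.
  for (int n = 0; n < num; n++)
    for (int c = 0; c < outC; c++)
      for (int h = 0; h < outH; h++) {
        int outoff = ((n * outC + c) * outH + h) * outW;
        int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
        std::memcpy(out + outoff, in + inoff, outW * sizeof(float));
      }
  // Prints: 4 5 6 7 8 7 3 5, matching the documented Output(2,2,1,2).
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");

  // Backward: scatter-add a gradient of the cropped output back into a zeroed
  // buffer of the original input shape, mirroring CropGrad<DEVICE_TYPE_CPU>
  // (which accumulates with +=). The forward output doubles as a dummy grad.
  float gradWrtInput[num * inC * inH * inW] = {0};
  for (int n = 0; n < num; n++)
    for (int c = 0; c < outC; c++)
      for (int h = 0; h < outH; h++)
        for (int w = 0; w < outW; w++) {
          int srcoff = ((n * outC + c) * outH + h) * outW + w;
          int dstoff =
              ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop + w;
          gradWrtInput[dstoff] += out[srcoff];
        }
  return 0;
}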
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
struct CropConf {
  /// The start corner (in channel, height, width order) of the crop window
  /// within the input tensor
  std::vector<uint32_t> corner;
  /// The shape (in channel, height, width order) of the cropped result
  std::vector<uint32_t> shape;
};
/**
 * \brief This function crops the input according to the specified start
 *        corner and shape.
 *
 * \param[out] outputs the cropped output data.
 * \param[in]  inputs  the input data.
 * \param[in]  inShape the shape of the input tensor.
 * \param[in]  crop    the cropping config.
 */
template <DeviceType Device>
void Crop(real* outputs,
const real* inputs,
const TensorShape inShape,
const CropConf& crop);
/**
 * \brief Backward propagation of the cropping operation.
 *
 * \param[in]  inGrad   the gradient with respect to the cropped output.
 * \param[out] outGrad  the gradient with respect to the original input; the
 *                      cropped window is accumulated (added) into it.
 * \param[in]  outShape the shape of the outGrad tensor.
 * \param[in]  crop     the cropping config.
 */
template <DeviceType Device>
void CropGrad(const real* inGrad,
              real* outGrad,
              const TensorShape outShape,
              const CropConf& crop);
} // namespace paddle
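For orientation, a minimal sketch of the shape bookkeeping implied by CropConf: the output keeps the batch dimension of the input and takes crop.shape for the remaining three dimensions, and a crop is only valid when corner + shape fits inside the input. The helper below (cropOutputShape) is illustrative and not part of this header.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: given NCHW input dims, a (C,H,W) corner and a (C,H,W)
// crop shape, return the NCHW output dims and check the window stays in bounds.
std::vector<size_t> cropOutputShape(const std::vector<size_t>& inDims,
                                    const std::vector<uint32_t>& corner,
                                    const std::vector<uint32_t>& shape) {
  assert(inDims.size() == 4 && corner.size() == 3 && shape.size() == 3);
  for (int i = 0; i < 3; ++i) {
    assert(corner[i] + shape[i] <= inDims[i + 1]);  // crop window must fit
  }
  return {inDims[0], shape[0], shape[1], shape[2]};
}

With the configuration used in CropOpTest.cpp (crop_corner {1,1,1}, crop_shape {2,3,3}), an input of {numSamples, channels, imgSizeH, imgSizeW} yields {numSamples, 2, 3, 3}, which is exactly the outDims the test compares against.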
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "CropOp.h"
namespace paddle {
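// One thread per output element: the flat index idx is decomposed into
// (n, c, h, w) of the cropped output and mapped back to the input offset.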
__global__ void KeCrop(real* outputs, const real* inputs,
int inC, int inH, int inW,
int cropC, int cropH, int cropW,
int outC, int outH, int outW, int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % outW;
const int h = (idx / outW) % outH;
const int c = (idx / outW / outH) % outC;
const int n = idx / outW / outH / outC;
const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w;
outputs[idx] = inputs[off];
}
}
template <>
void Crop<DEVICE_TYPE_GPU>(real* outputs,
const real* inputs,
const TensorShape inShape,
const CropConf& crop) {
int cropC = crop.corner[0];
int cropH = crop.corner[1];
int cropW = crop.corner[2];
int num = inShape[0];
int inC = inShape[1];
int inH = inShape[2];
int inW = inShape[3];
int outC = crop.shape[0];
int outH = crop.shape[1];
int outW = crop.shape[2];
size_t nth = num * outC * outH * outW;
int blockSize = 1024;
int gridSize = (nth + blockSize - 1) / blockSize;
KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
outC, outH, outW, nth);
CHECK_SYNC("Crop");
}
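// Backward kernel: one thread per element of the incoming gradient; each
// thread adds its value into the corresponding position of the full-sized
// gradient tensor (a scatter-add, matching the ADD_TO argument type used by
// CropLayer::backward).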
__global__ void KeCropDiff(const real* inGrad, real* outGrad,
int inC, int inH, int inW,
int cropC, int cropH, int cropW,
int outC, int outH, int outW, int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
const int h = (idx / inW) % inH;
const int c = (idx / inW / inH) % inC;
const int n = idx / inW / inH / inC;
const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w;
outGrad[off] += inGrad[idx];
}
}
template <>
void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
real* outGrad,
const TensorShape outShape,
const CropConf& crop) {
int cropC = crop.corner[0];
int cropH = crop.corner[1];
int cropW = crop.corner[2];
int num = outShape[0];
int outC = outShape[1];
int outH = outShape[2];
int outW = outShape[3];
int inC = crop.shape[0];
int inH = crop.shape[1];
int inW = crop.shape[2];
size_t nth = num * inC * inH * inW;
int blockSize = 1024;
int gridSize = (nth + blockSize - 1) / blockSize;
KeCropDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
outC, outH, outW, nth);
CHECK_SYNC("CropGrad");
}
} // namespace paddle
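The kernels above flatten the tensors, so the only non-trivial step is recovering (n, c, h, w) from the flat thread index. A CPU-side sketch of the mapping used by KeCrop, assuming the same NCHW layout (the function name is illustrative, not part of the patch):

// Mirror of the index math in KeCrop: map a flat output index to the flat
// input offset, given output dims (outC,outH,outW), input dims (inC,inH,inW)
// and the crop corner (cropC,cropH,cropW).
inline int cropInputOffset(int idx,
                           int inC, int inH, int inW,
                           int cropC, int cropH, int cropW,
                           int outC, int outH, int outW) {
  const int w = idx % outW;
  const int h = (idx / outW) % outH;
  const int c = (idx / outW / outH) % outC;
  const int n = idx / outW / outH / outC;
  return ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w;
}

KeCropDiff uses the same decomposition with the roles of the two tensors swapped, and accumulates into the larger tensor instead of assigning.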
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
TEST(Crop, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
for (bool test_grad : {false, true}) {
FunctionCompare compare(
test_grad ? "CropGrad" : "Crop",
FuncConfig()
.set<std::vector<uint32_t>>("crop_corner", {1, 1, 1})
.set<std::vector<uint32_t>>("crop_shape", {2, 3, 3}));
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{numSamples, 2, 3, 3};
compare.addInputs(
    BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
                             test_grad ? inDims : outDims,
                             test_grad ? ADD_TO : ASSIGN_TO));
compare.run();
}
}
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CropLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(crop, CropLayer);
bool CropLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(config_.inputs_size(), 1);
auto& crop_conf = config_.inputs(0).crop_conf();
auto& img_conf = crop_conf.image_conf();
inDims_ = TensorShape(
    {0,
     img_conf.channels(),
     img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(),
     img_conf.img_size()});
crop_corner_ = {crop_conf.crop_corner(0),
crop_conf.crop_corner(1),
crop_conf.crop_corner(2)};
crop_shape_ = {crop_conf.crop_shape(0),
crop_conf.crop_shape(1),
crop_conf.crop_shape(2)};
outDims_ = TensorShape(4);
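// The batch dimension is not known at init time; setOutDims(0) fills in a
// placeholder that setTensorDim() overwrites on every forward pass.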
setOutDims(0);
createFunction(forward_,
"Crop",
FuncConfig()
.set("crop_corner", crop_corner_)
.set("crop_shape", crop_shape_));
createFunction(backward_,
"CropGrad",
FuncConfig()
.set("crop_corner", crop_corner_)
.set("crop_shape", crop_shape_));
return true;
}
void CropLayer::setOutDims(const size_t batchSize) {
outDims_.reshape({batchSize, crop_shape_[0], crop_shape_[1], crop_shape_[2]});
}
void CropLayer::setTensorDim(const size_t batchSize) {
CHECK_EQ(static_cast<int>(inputLayers_.size()), 1);
inDims_.setDim(0, batchSize);
int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
setOutDims(batchSize);
}
void CropLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
setTensorDim(batchSize);
int size = outDims_[1] * outDims_[2] * outDims_[3];
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
REGISTER_TIMER_INFO("CropForward", getName().c_str());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inDims_);
outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
void CropLayer::backward(const UpdateCallback& callback) {
(void)callback;
REGISTER_TIMER_INFO("CropBackward", getName().c_str());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outDims_);
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
} // namespace paddle
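CropLayer::forward resizes the output matrix to one row per sample and one column per cropped element before handing the buffers to the Crop function. A minimal sketch of that sizing, assuming a batch of 8 and the crop_shape used by the test; the concrete numbers are illustrative only:

#include <cstddef>
#include <cstdio>

int main() {
  // Shapes as CropLayer::forward would see them: crop_shape = {2, 3, 3},
  // assumed batch of 8 samples.
  const size_t batchSize = 8;
  const size_t outDims[4] = {batchSize, 2, 3, 3};            // setOutDims(batchSize)
  const size_t size = outDims[1] * outDims[2] * outDims[3];  // 2 * 3 * 3 = 18
  // resetOutput(batchSize, size): the output matrix becomes batchSize x 18,
  // and the 4D outDims_ is what the Crop function uses to index into it.
  std::printf("output matrix: %zu x %zu\n", batchSize, size);
  return 0;
}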
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
 * \brief This layer crops the input according to the specified corner and
 *        shape. The input and output are 4D tensors; cropping is applied to
 *        the 2nd through 4th dimensions.
 */
class CropLayer : public Layer {
public:
explicit CropLayer(const LayerConfig& config) : Layer(config) {}
~CropLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
void setOutDims(const size_t batchSize);
void setTensorDim(const size_t batchSize);
std::vector<uint32_t> crop_corner_;
std::vector<uint32_t> crop_shape_;
TensorShape inDims_;
TensorShape outDims_;
};
} // namespace paddle