未验证 提交 765085d2 编写于 作者: J jerrywgz 提交者: GitHub

Merge pull request #13904 from jerrywgz/roialign

Add RoI align operator.
......@@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class ROIAlignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ROIAlignOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
"Input(ROIs) of ROIAlignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ROIAlignOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …].");
PADDLE_ENFORCE(rois_dims[1] == 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must greater than 0");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class ROIAlignGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The GRAD@Out of ROIAlignGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The GRAD@X of ROIAlignGradOp should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
};
class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), "
"The input of ROIAlignOp. "
"The format of input tensor is NCHW. Where N is batch size, "
"C is the number of input channels, "
"H is the height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …]. "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddOutput("Out",
"(Tensor), "
"The output of ROIAlignOp is a 4-D tensor with shape "
"(num_rois, channels, pooled_h, pooled_w).");
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"The pooled output height.")
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"The pooled output width.")
.SetDefault(1);
AddAttr<int>("sampling_ratio",
"(int,default -1),"
"number of sampling points in the interpolation grid"
"If <=0, then grid points are adaptive to roi_width "
"and pooled_w, likewise for height")
.SetDefault(-1);
AddComment(R"DOC(
**RoIAlign Operator**
Region of interest align (also known as RoI align) is to perform
bilinear interpolation on inputs of nonuniform sizes to obtain
fixed-size feature maps (e.g. 7*7)
Dividing each region proposal into equal-sized sections with
the pooled_width and pooled_height. Location remains the origin
result.
In each ROI bin, the value of the four regularly sampled locations
are computed directly through bilinear interpolation. The output is
the mean of four locations.
Thus avoid the misaligned problem.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp);
REGISTER_OP_CPU_KERNEL(
roi_align,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
roi_align_grad,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <class T>
__device__ T BilinearInterpolate(const T* input_data, const int height,
const int width, T y, T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
return 0;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T v1 = input_data[y_low * width + x_low];
T v2 = input_data[y_low * width + x_high];
T v3 = input_data[y_high * width + x_low];
T v4 = input_data[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <class T>
__device__ void BilinearInterpolateGradient(const int height, const int width,
T y, T x, T* w1, T* w2, T* w3,
T* w4, int* x_low, int* x_high,
int* y_low, int* y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) {
return;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
*y_low = static_cast<int>(y);
*x_low = static_cast<int>(x);
if (*y_low >= height - 1) {
*y_high = *y_low = height - 1;
y = static_cast<T>(*y_low);
} else {
*y_high = *y_low + 1;
}
if (*x_low >= width - 1) {
*x_high = *x_low = width - 1;
x = static_cast<T>(*x_low);
} else {
*x_high = *x_low + 1;
}
T ly = y - *y_low, lx = x - *x_low;
T hy = 1. - ly, hx = 1. - lx;
*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
return;
}
template <class T>
__global__ void GPUROIAlignForward(
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int sampling_ratio, int* roi_batch_id_data, T* output_data) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
T roi_xmin = offset_input_rois[0] * spatial_scale;
T roi_ymin = offset_input_rois[1] * spatial_scale;
T roi_xmax = offset_input_rois[2] * spatial_scale;
T roi_ymax = offset_input_rois[3] * spatial_scale;
T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = BilinearInterpolate(offset_input_data, height, width, y, x);
output_val += val;
}
}
output_val /= count;
output_data[i] = output_val;
}
}
template <typename T>
__global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
const T* out_grad, const int num_rois,
const float spatial_scale,
const int channels, const int height,
const int width, const int pooled_height,
const int pooled_width,
const int sampling_ratio,
int* roi_batch_id_data, T* input_grad) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
T roi_xmin = offset_input_rois[0] * spatial_scale;
T roi_ymin = offset_input_rois[1] * spatial_scale;
T roi_xmax = offset_input_rois[2] * spatial_scale;
T roi_ymax = offset_input_rois[3] * spatial_scale;
T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T* offset_input_grad =
input_grad + (roi_batch_ind * channels + c) * height * width;
const T* offset_out_grad =
out_grad + (n * channels + c) * pooled_height * pooled_width;
const T out_grad_this_bin = offset_out_grad[ph * pooled_width + pw];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1 = 0, w2 = 0, w3 = 0, w4 = 0;
int x_low = -1, x_high = -1, y_low = -1, y_high = -1;
BilinearInterpolateGradient(height, width, y, x, &w1, &w2, &w3, &w4,
&x_low, &x_high, &y_low, &y_high);
T diff1 = out_grad_this_bin * w1 / count;
T diff2 = out_grad_this_bin * w2 / count;
T diff3 = out_grad_this_bin * w3 / count;
T diff4 = out_grad_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_low,
diff1);
platform::CudaAtomicAdd(offset_input_grad + y_low * width + x_high,
diff2);
platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_low,
diff3);
platform::CudaAtomicAdd(offset_input_grad + y_high * width + x_high,
diff4);
}
}
}
}
}
template <typename Place, typename T>
class GPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Output<Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same.");
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
&roi_batch_id_list_gpu);
GPUROIAlignForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width, sampling_ratio,
roi_batch_id_list_gpu.data<int>(),
out->mutable_data<T>(ctx.GetPlace()));
}
};
template <typename Place, typename T>
class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
int rois_num = rois->dims()[0];
int channels = in->dims()[1];
int height = in->dims()[2];
int width = in->dims()[3];
if (!in_grad) {
return;
}
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
&roi_batch_id_list_gpu);
in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIAlignBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
spatial_scale, channels, height, width, pooled_height, pooled_width,
sampling_ratio, roi_batch_id_list_gpu.data<int>(),
in_grad->mutable_data<T>(ctx.GetPlace()));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
roi_align,
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
roi_align_grad,
ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kROISize = 4;
template <class T>
void PreCalcForBilinearInterpolate(
const platform::DeviceContext& ctx, const int height, const int width,
const int pooled_height, const int pooled_width, const int iy_upper,
const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w,
int roi_bin_grid_h, int roi_bin_grid_w, Tensor* pre_pos, Tensor* pre_w) {
int pre_calc_index = 0;
int* pre_pos_data = pre_pos->mutable_data<int>(ctx.GetPlace());
T* pre_w_data = pre_w->mutable_data<T>(ctx.GetPlace());
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
// calculate y of sample points
T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
// calculate x of samle points
for (int ix = 0; ix < ix_upper; ix++) {
T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
// deal with elements out of map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
for (int i = 0; i < kROISize; ++i) {
pre_pos_data[i + pre_calc_index * kROISize] = 0;
pre_w_data[i + pre_calc_index * kROISize] = 0;
}
pre_calc_index += 1;
continue;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
pre_pos_data[pre_calc_index * kROISize] = y_low * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 1] = y_low * width + x_high;
pre_pos_data[pre_calc_index * kROISize + 2] = y_high * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 3] = y_high * width + x_high;
pre_w_data[pre_calc_index * kROISize] = hy * hx;
pre_w_data[pre_calc_index * kROISize + 1] = hy * lx;
pre_w_data[pre_calc_index * kROISize + 2] = ly * hx;
pre_w_data[pre_calc_index * kROISize + 3] = ly * lx;
pre_calc_index += 1;
}
}
}
}
}
template <class T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
const T out_grad_this_bin, const T count,
T* batch_grad_data) {
int x_low, y_low, x_high, y_high;
T w1, w2, w3, w4;
if (y < -1.0 || y > height || x < -1.0 || x > width) {
w1 = w2 = w3 = w4 = 0;
x_low = x_high = y_low = y_high = -1;
return;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
y_low = static_cast<int>(y);
x_low = static_cast<int>(x);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T diff1 = out_grad_this_bin * w1 / count;
T diff2 = out_grad_this_bin * w2 / count;
T diff3 = out_grad_this_bin * w3 / count;
T diff4 = out_grad_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
*(batch_grad_data + y_low * width + x_low) += diff1;
*(batch_grad_data + y_low * width + x_high) += diff2;
*(batch_grad_data + y_high * width + x_low) += diff3;
*(batch_grad_data + y_high * width + x_high) += diff4;
}
}
template <typename DeviceContext, typename T>
class CPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
const T* input_data = in->data<T>();
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same.");
int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
"The rois_num from input and lod must be the same.");
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* rois_data = rois->data<T>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
T roi_xmin = rois_data[0] * spatial_scale;
T roi_ymin = rois_data[1] * spatial_scale;
T roi_xmax = rois_data[2] * spatial_scale;
T roi_ymax = rois_data[3] * spatial_scale;
T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* batch_data = input_data + roi_batch_id * in_stride[0];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
Tensor pre_pos;
Tensor pre_w;
int pre_size = count * out_stride[1];
pre_pos.Resize({pre_size, kROISize});
pre_w.Resize({pre_size, kROISize});
PreCalcForBilinearInterpolate(
dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h,
roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w,
roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w);
const int* pre_pos_data = pre_pos.data<int>();
const T* pre_w_data = pre_w.data<T>();
for (int c = 0; c < channels; c++) {
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
const int pool_index = ph * pooled_width + pw;
T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
for (int i = 0; i < kROISize; i++) {
int pos = pre_pos_data[pre_calc_index * kROISize + i];
T w = pre_w_data[pre_calc_index * kROISize + i];
output_val += w * batch_data[pos];
}
pre_calc_index += 1;
}
}
output_val /= count;
output_data[pool_index] = output_val;
}
}
batch_data += in_stride[1];
output_data += out_stride[1];
}
rois_data += roi_stride[0];
}
}
};
template <typename DeviceContext, typename T>
class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims();
if (!in_grad) {
return;
}
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
const T* rois_data = rois->data<T>();
const T* out_grad_data = out_grad->data<T>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims());
for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = roi_batch_id_data[n];
T roi_xmin = rois_data[0] * spatial_scale;
T roi_ymin = rois_data[1] * spatial_scale;
T roi_xmax = rois_data[2] * spatial_scale;
T roi_ymax = rois_data[3] * spatial_scale;
T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int c = 0; c < channels; ++c) {
T* batch_grad_data =
in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1];
const T* batch_out_grad_data =
out_grad_data + n * out_stride[0] + c * out_stride[1];
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int pool_index = ph * pooled_width + pw;
T out_grad_this_bin = batch_out_grad_data[pool_index];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_width / pooled_width);
T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
bilinear_interpolate_gradient(height, width, y, x,
out_grad_this_bin, count,
batch_grad_data);
}
}
}
}
}
rois_data += roi_stride[0];
}
}
};
} // namespace operators
} // namespace paddle
......@@ -96,6 +96,7 @@ __all__ = [
'pad_constant_like',
'label_smooth',
'roi_pool',
'roi_align',
'dice_loss',
'image_resize',
'image_resize_short',
......@@ -5430,6 +5431,54 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
return pool_out
@templatedoc()
def roi_align(input,
rois,
pooled_height=1,
pooled_width=1,
spatial_scale=1.0,
sampling_ratio=-1,
name=None):
"""
${comment}
Args:
input (Variable): ${x_comment}
rois (Variable): ROIs (Regions of Interest) to pool over.
pooled_height (integer): ${pooled_height_comment} Default: 1
pooled_width (integer): ${pooled_width_comment} Default: 1
spatial_scale (float): ${spatial_scale_comment} Default: 1.0
sampling_ratio(intger): ${sampling_ratio_comment} Default: -1
Returns:
Variable: ${out_comment}.
Examples:
.. code-block:: python
align_out = fluid.layers.roi_align(input=x,
rois=rois,
pooled_height=7,
pooled_width=7,
spatial_scale=0.5,
sampling_ratio=-1)
"""
helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype()
align_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="roi_align",
inputs={"X": input,
"ROIs": rois},
outputs={"Out": align_out},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale,
"sampling_ratio": sampling_ratio
})
return align_out
def dice_loss(input, label, epsilon=0.00001):
"""
Dice loss for comparing the similarity of two batch of data,
......
......@@ -465,6 +465,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_roi_align(self):
program = Program()
with program_guard(program):
x = layers.data(name="x", shape=[256, 30, 30], dtype="float32")
rois = layers.data(
name="rois", shape=[4], dtype="float32", lod_level=1)
output = layers.roi_align(x, rois, 14, 14, 0.5, 2)
self.assertIsNotNone(output)
print(str(program))
def test_resize_bilinear(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import sys
from op_test import OpTest
class TestROIAlignOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_align()
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio
}
self.outputs = {'Out': self.out_data}
def init_test_case(self):
self.batch_size = 3
self.channels = 3
self.height = 8
self.width = 6
# n, c, h, w
self.x_dim = (self.batch_size, self.channels, self.height, self.width)
self.spatial_scale = 1.0 / 2.0
self.pooled_height = 2
self.pooled_width = 2
self.sampling_ratio = -1
self.x = np.random.random(self.x_dim).astype('float32')
def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
bin_size_h, bin_size_w):
count = roi_bin_grid_h * roi_bin_grid_w
bilinear_pos = np.zeros(
[self.channels, self.pooled_height, self.pooled_width, count, 4],
np.float32)
bilinear_w = np.zeros(
[self.pooled_height, self.pooled_width, count, 4], np.float32)
for ph in range(self.pooled_width):
for pw in range(self.pooled_height):
c = 0
for iy in range(roi_bin_grid_h):
y = roi_ymin + ph * bin_size_h + (iy + 0.5) * \
bin_size_h / roi_bin_grid_h
for ix in range(roi_bin_grid_w):
x = roi_xmin + pw * bin_size_w + (ix + 0.5) * \
bin_size_w / roi_bin_grid_w
if y < -1.0 or y > self.height or \
x < -1.0 or x > self.width:
continue
if y <= 0:
y = 0
if x <= 0:
x = 0
y_low = int(y)
x_low = int(x)
if y_low >= self.height - 1:
y = y_high = y_low = self.height - 1
else:
y_high = y_low + 1
if x_low >= self.width - 1:
x = x_high = x_low = self.width - 1
else:
x_high = x_low + 1
ly = y - y_low
lx = x - x_low
hy = 1 - ly
hx = 1 - lx
for ch in range(self.channels):
bilinear_pos[ch, ph, pw, c, 0] = x_i[ch, y_low,
x_low]
bilinear_pos[ch, ph, pw, c, 1] = x_i[ch, y_low,
x_high]
bilinear_pos[ch, ph, pw, c, 2] = x_i[ch, y_high,
x_low]
bilinear_pos[ch, ph, pw, c, 3] = x_i[ch, y_high,
x_high]
bilinear_w[ph, pw, c, 0] = hy * hx
bilinear_w[ph, pw, c, 1] = hy * lx
bilinear_w[ph, pw, c, 2] = ly * hx
bilinear_w[ph, pw, c, 3] = ly * lx
c = c + 1
return bilinear_pos, bilinear_w
def calc_roi_align(self):
self.out_data = np.zeros(
(self.rois_num, self.channels, self.pooled_height,
self.pooled_width)).astype('float32')
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = int(roi[0])
x_i = self.x[roi_batch_id]
roi_xmin = roi[1] * self.spatial_scale
roi_ymin = roi[2] * self.spatial_scale
roi_xmax = roi[3] * self.spatial_scale
roi_ymax = roi[4] * self.spatial_scale
roi_width = max(roi_xmax - roi_xmin, 1)
roi_height = max(roi_ymax - roi_ymin, 1)
bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width)
roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_height / self.pooled_height)
roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_width / self.pooled_width)
count = int(roi_bin_grid_h * roi_bin_grid_w)
pre_size = count * self.pooled_width * self.pooled_height
bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin,
int(roi_bin_grid_h),
int(roi_bin_grid_w),
bin_size_h, bin_size_w)
for ch in range(self.channels):
align_per_bin = (bilinear_pos[ch] * bilinear_w).sum(axis=-1)
output_val = align_per_bin.mean(axis=-1)
self.out_data[i, ch, :, :] = output_val
def make_rois(self):
rois = []
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x1 = np.random.random_integers(
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float32")
def setUp(self):
self.op_type = "roi_align"
self.set_data()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册