未验证 提交 3044a62f 编写于 作者: W wopeizl 提交者: GitHub

fix the precise roi poop op test=develop (#20126)

* fix the precise roi poop op test=develop
add roi backward implementation, fix the output-channel
上级 cff9970b
......@@ -290,7 +290,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg
paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'd5945431cdcae3cda21914db5bbf383e'))
paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '231f91231430f5dae2b757df22317c67'))
paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9bf0cc6b0717010b8ceec5dc2541d566'))
paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303'))
paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '466be691ac4c1cd7b88fccb40846afce'))
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', 'b0e07aa41caae04b07a8e8217cc96020'))
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '9d93ee81f7a3e526d68bb280bc695d6c'))
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '45f3ebbcb766fca84cb2fe6307086573'))
......
......@@ -43,12 +43,6 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), "
"the output of PRROIPoolOp is a 4-D Tensor with shape "
"(num_rois, output_channels, pooled_h, pooled_w).");
AddAttr<int>(
"output_channels",
"(int), "
"the number of channels of the output feature map. "
"For a task of C classes of objects, output_channels should be "
"(C + 1) for classification only.");
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
......@@ -100,28 +94,18 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_EQ(
input_dims[1], output_channels * pooled_height * pooled_width,
"the channel of X(%d) should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width);
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must be greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must be greater than 0");
PADDLE_ENFORCE_GT(output_channels, 1,
"The pooled output channels must greater than 1");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0.");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[1] = input_dims[1];
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
......@@ -145,6 +129,7 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"The gradient of X should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("ROIs"), ctx->GetInputDim("ROIs"));
}
protected:
......@@ -164,9 +149,11 @@ class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("prroi_pool_grad");
op->SetInput("X", Input("X"));
op->SetInput("Out", Output("Out"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("ROIs"), InputGrad("ROIs"));
op->SetAttrMap(Attrs());
return op;
}
......
......@@ -40,6 +40,11 @@ DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff,
}
}
template <typename T>
DEVICE void GPUAccumulateRois(T* offset, T data) {
paddle::platform::CudaAtomicAdd(offset, data);
}
template <typename T>
__global__ void GPUPRROIPoolForward(
const int nthreads, const T* input_data, const T* input_rois,
......@@ -78,7 +83,7 @@ __global__ void GPUPRROIPoolForward(
T win_end_h = win_start_h + bin_size_h;
T win_size = max(static_cast<T>(0.0), bin_size_w * bin_size_h);
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_channel = c;
const T* offset_input_data =
input_data +
(roi_batch_id * input_channels + input_channel) * height * width;
......@@ -110,10 +115,12 @@ __global__ void GPUPRROIPoolForward(
template <typename T>
__global__ void GPUPRROIPoolBackward(
const int nthreads, const T* input_rois, const T* output_grad_data,
const float spatial_scale, const int input_channels, const int height,
const int width, const int output_channels, const int pooled_height,
const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) {
const int nthreads, const T* in_data, const T* input_rois,
const T* output_grad_data, const float spatial_scale,
const int input_channels, const int height, const int width,
const int output_channels, const int pooled_height, const int pooled_width,
const int* rois_batch_id_data, T* input_grad_data, const T* out_data,
T* input_roi_grad_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
......@@ -125,7 +132,7 @@ __global__ void GPUPRROIPoolBackward(
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_channel = c;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
......@@ -137,6 +144,7 @@ __global__ void GPUPRROIPoolBackward(
T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;
T* offset_input_roi_grad_data = input_roi_grad_data + n * 4;
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(0.0));
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(0.0));
......@@ -171,6 +179,16 @@ __global__ void GPUPRROIPoolBackward(
height, width, PrRoIPoolingDistributeDiffCUDA<T>);
}
}
const T* offset_out_data = out_data + i;
const T* offset_in_data = in_data + input_offset;
PrRoIPoolingCoorBackward(
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w,
win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale,
offset_in_data, offset_out_data, offset_input_grad_data,
offset_input_roi_grad_data, GPUAccumulateRois<T>,
[](const T x, const T y) { return max(x, y); },
[](const T x, const T y) { return min(x, y); });
}
}
......@@ -184,20 +202,15 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel<T> {
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
auto output_channels = input_channels;
int height = in_dims[2];
int width = in_dims[3];
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
......@@ -245,17 +258,20 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* input_roi_grad =
ctx.Output<LoDTensor>(framework::GradVarName("ROIs"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
int rois_num = rois->dims()[0];
int input_channels = in->dims()[1];
auto output_channels = input_channels;
int height = in->dims()[2];
int width = in->dims()[3];
......@@ -280,6 +296,8 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
input_roi_grad->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.cuda_device_context(), input_roi_grad, static_cast<T>(0));
int output_grad_size = output_grad->numel();
int blocks = NumBlocks(output_grad_size);
......@@ -288,10 +306,12 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
if (output_grad_size > 0) {
GPUPRROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), output_grad->data<T>(),
spatial_scale, input_channels, height, width, output_channels,
pooled_height, pooled_width, rois_batch_id_list_gpu.data<int>(),
input_grad->mutable_data<T>(ctx.GetPlace()));
output_grad_size, in->data<T>(), rois->data<T>(),
output_grad->data<T>(), spatial_scale, input_channels, height,
width, output_channels, pooled_height, pooled_width,
rois_batch_id_list_gpu.data<int>(),
input_grad->mutable_data<T>(ctx.GetPlace()), out->data<T>(),
input_roi_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
......
......@@ -21,19 +21,20 @@ namespace paddle {
namespace operators {
template <typename T>
HOSTDEVICE T PrRoIPoolingGetData(const T* data, const int h, const int w,
const int height, const int width) {
inline HOSTDEVICE T PrRoIPoolingGetData(const T* data, const int h, const int w,
const int height, const int width) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
T retVal = overflow ? 0.0f : data[h * width + w];
return retVal;
}
template <typename T>
HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, const int s_h,
const int s_w, const int e_h,
const int e_w, const T y0, const T x0,
const T y1, const T x1, const int h0,
const int w0) {
inline HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data,
const int s_h, const int s_w,
const int e_h, const int e_w,
const T y0, const T x0,
const T y1, const T x1,
const int h0, const int w0) {
T alpha, beta, lim_alpha, lim_beta, tmp;
T sum_out = 0;
......@@ -73,10 +74,11 @@ HOSTDEVICE T PrRoIPoolingMatCalculation(const T* this_data, const int s_h,
}
template <typename T>
HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff,
const int h, const int w,
const int height, const int width,
const T coeff) {
inline HOSTDEVICE void PrRoIPoolingDistributeDiff(T* diff, const T top_diff,
const int h, const int w,
const int height,
const int width,
const T coeff) {
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow) {
*(diff + h * width + w) = top_diff * coeff;
......@@ -123,6 +125,132 @@ HOSTDEVICE void PrRoIPoolingMatDistributeDiff(
functor(diff, top_diff, e_h, e_w, h0, w0, tmp);
}
template <typename T>
inline HOSTDEVICE void CPUAccumulateRois(T* offset, T data) {
*offset += data;
}
template <typename T>
inline HOSTDEVICE static T PrRoIPoolingGetCoeff(T dh, T dw) {
dw = dw > 0 ? dw : -dw;
dh = dh > 0 ? dh : -dh;
return (1.0f - dh) * (1.0f - dw);
}
template <typename T, typename H, typename W>
inline HOSTDEVICE static T PrRoIPoolingInterpolation(const T* data, const H h,
const W w,
const int height,
const int width) {
T retVal = 0.0f;
int h1 = floorf(h);
int w1 = floorf(w);
retVal +=
PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - static_cast<T>(h1), w - static_cast<T>(w1));
h1 = floorf(h) + 1;
w1 = floorf(w);
retVal +=
PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - static_cast<T>(h1), w - static_cast<T>(w1));
h1 = floorf(h);
w1 = floorf(w) + 1;
retVal +=
PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - static_cast<T>(h1), w - static_cast<T>(w1));
h1 = floorf(h) + 1;
w1 = floorf(w) + 1;
retVal +=
PrRoIPoolingGetData(data, h1, w1, height, width) *
PrRoIPoolingGetCoeff(h - static_cast<T>(h1), w - static_cast<T>(w1));
return retVal;
}
template <typename T>
inline HOSTDEVICE T PrRoIPoolingSingleCoorIntegral(T s, T t, T c1, T c2) {
return 0.5f * (t * t - s * s) * c2 +
(t - 0.5f * t * t - s + 0.5f * s * s) * c1;
}
template <typename T, typename Functor, typename MaxFunctor,
typename MinFunctor>
inline HOSTDEVICE void PrRoIPoolingCoorBackward(
int s_w, int e_w, int s_h, int e_h, int width, int height, T win_start_w,
T win_start_h, T win_end_w, T win_end_h, int pw, int ph,
const int pooled_width, const int pooled_height, T win_size,
const float spatial_scale, const T* this_bottom_data,
const T* this_top_data, T* this_data_grad, T* this_out_grad,
Functor functor, MaxFunctor maxFunctor, MinFunctor minFunctor) {
T g_x1_y = 0.f;
T g_x2_y = 0.f;
T g_x_y1 = 0.f;
T g_x_y2 = 0.f;
for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
g_x1_y += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_h, static_cast<T>(h_iter)) - h_iter,
minFunctor(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w,
height, width));
g_x2_y += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_h, static_cast<T>(h_iter)) - h_iter,
minFunctor(win_end_h, static_cast<T>(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w,
height, width));
}
for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
g_x_y1 += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_w, static_cast<T>(w_iter)) - w_iter,
minFunctor(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1,
height, width));
g_x_y2 += PrRoIPoolingSingleCoorIntegral(
maxFunctor(win_start_w, static_cast<T>(w_iter)) - w_iter,
minFunctor(win_end_w, static_cast<T>(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height,
width),
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1,
height, width));
}
float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data);
float partial_y1 = -g_x_y1 + (win_end_w - win_start_w) * (*this_top_data);
float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data);
float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data);
partial_x1 = partial_x1 / win_size * spatial_scale;
partial_x2 = partial_x2 / win_size * spatial_scale;
partial_y1 = partial_y1 / win_size * spatial_scale;
partial_y2 = partial_y2 / win_size * spatial_scale;
this_data_grad[0] = 0;
functor(this_data_grad + 1,
(partial_x1 * (1.0 - static_cast<T>(pw) / pooled_width) +
partial_x2 * (1.0 - static_cast<T>(pw + 1) / pooled_width)) *
(*this_out_grad));
functor(this_data_grad + 2,
(partial_y1 * (1.0 - static_cast<T>(ph) / pooled_height) +
partial_y2 * (1.0 - static_cast<T>(ph + 1) / pooled_height)) *
(*this_out_grad));
functor(this_data_grad + 3,
(partial_x2 * static_cast<T>(pw + 1) / pooled_width +
partial_x1 * static_cast<T>(pw) / pooled_width) *
(*this_out_grad));
functor(this_data_grad + 4,
(partial_y2 * static_cast<T>(ph + 1) / pooled_height +
partial_y1 * static_cast<T>(ph) / pooled_height) *
(*this_out_grad));
}
template <typename DeviceContext, typename T>
class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
public:
......@@ -134,11 +262,11 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto output_channels = ctx.Attr<int>("output_channels");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int input_channels = in_dims[1];
auto output_channels = input_channels;
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
......@@ -162,11 +290,6 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num,
"the rois_num from input and lod must be the same");
PADDLE_ENFORCE_EQ(input_channels,
output_channels * pooled_height * pooled_width,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width");
// calculate batch id index for each roi according to LoD
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
......@@ -217,7 +340,7 @@ class CPUPRROIPoolOpKernel : public framework::OpKernel<T> {
int e_h = std::ceil(win_end_h);
int output_index = out_row_offset + pw;
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_channel = c;
int input_plane_offset =
roi_batch_id * in_stride[0] + input_channel * in_stride[1];
const T* offset_input_data = input_data + input_plane_offset;
......@@ -254,20 +377,26 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* input_roi_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("ROIs"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto output_channels = ctx.Attr<int>("output_channels");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
if (input_grad) {
if (input_grad && input_roi_grad) {
auto in_dims = in->dims();
auto* in_data = in->data<T>();
auto* out_data = out->data<T>();
int input_channels = in_dims[1];
auto output_channels = input_channels;
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
......@@ -289,6 +418,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
const T* input_rois = rois->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
T* input_roi_grad_data = input_roi_grad->mutable_data<T>(ctx.GetPlace());
// set gradient of X to be 0. before backpropagate.
math::SetConstant<DeviceContext, T> set_zero;
......@@ -306,11 +436,12 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
// set roi_batch_id
int roi_batch_id = rois_batch_id_data[n];
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
int input_channel = c;
int input_offset =
(roi_batch_id * input_channels + input_channel) * height * width;
T* offset_input_grad_data = input_grad_data + input_offset;
const T* offset_output_grad_data = output_grad_data + i;
const T* offset_out_data = out_data + i;
// [start, end) interval for spatial sampling
const T* offset_input_rois = input_rois + n * 4;
......@@ -318,6 +449,7 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;
T* offset_input_roi_grad_data = input_roi_grad_data + n * 4;
T roi_width = std::max(roi_end_w - roi_start_w, static_cast<T>(0.0));
T roi_height = std::max(roi_end_h - roi_start_h, static_cast<T>(0.0));
......@@ -355,6 +487,16 @@ class CPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
height, width, PrRoIPoolingDistributeDiff<T>);
}
}
const T* offset_in_data = in_data + input_offset;
PrRoIPoolingCoorBackward(
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h,
win_end_w, win_end_h, pw, ph, pooled_width, pooled_height, win_size,
spatial_scale, offset_in_data, offset_out_data,
offset_input_grad_data, offset_input_roi_grad_data,
CPUAccumulateRois<T>,
[](const T x, const T y) { return std::max(x, y); },
[](const T x, const T y) { return std::min(x, y); });
}
}
}
......
......@@ -15442,7 +15442,6 @@ def psroi_pool(input,
@templatedoc()
def prroi_pool(input,
rois,
output_channels,
spatial_scale=1.0,
pooled_height=1,
pooled_width=1,
......@@ -15459,7 +15458,6 @@ def prroi_pool(input,
is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is
the top left coordinates, and (x2, y2) is the bottom
right coordinates.
output_channels (integer): The output's channel.
spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width).
Equals the reciprocal of total stride in convolutional layers, Default: 1.0.
pooled_height (integer): The pooled output height. Default: 1.
......@@ -15475,12 +15473,10 @@ def prroi_pool(input,
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32')
rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32')
pool_out = fluid.layers.prroi_pool(x, rois, 10, 1.0, 7, 7)
pool_out = fluid.layers.prroi_pool(x, rois, 1.0, 7, 7)
"""
helper = LayerHelper('prroi_pool', **locals())
# check attrs
if not isinstance(output_channels, int):
raise TypeError("output_channels must be int type")
if not isinstance(spatial_scale, float):
raise TypeError("spatial_scale must be float type")
if not isinstance(pooled_height, int):
......@@ -15495,7 +15491,6 @@ def prroi_pool(input,
'ROIs': rois},
outputs={'Out': out},
attrs={
'output_channels': output_channels,
'spatial_scale': spatial_scale,
'pooled_height': pooled_height,
'pooled_width': pooled_width
......
......@@ -133,8 +133,7 @@ class PyPrRoIPool(object):
s_h = math.floor(win_start_h)
e_h = math.ceil(win_end_h)
c_in = (c * pooled_height + ph) * pooled_width + pw
c_in = c
for w_iter in range(int(s_w), int(e_w)):
for h_iter in range(int(s_h), int(e_h)):
sum_out += self._PrRoIPoolingMatCalculation(
......
......@@ -48,7 +48,7 @@ class TestPRROIPoolOp(OpTest):
self.x_dim = [self.batch_size, self.channels, self.height, self.width]
self.spatial_scale = 1.0 / 4.0
self.output_channels = 3
self.output_channels = self.channels
self.pooled_height = 2
self.pooled_width = 2
......@@ -60,15 +60,15 @@ class TestPRROIPoolOp(OpTest):
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x1 = np.random.random_integers(
x1 = np.random.uniform(
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
y1 = np.random.uniform(
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height // self.spatial_scale)
x2 = np.random.uniform(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.uniform(y1 + self.pooled_height,
self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
......@@ -93,8 +93,7 @@ class TestPRROIPoolOp(OpTest):
dtype="float32")
rois = fluid.layers.data(
name="ROIs", shape=[4], dtype="float32", lod_level=1)
output = fluid.layers.prroi_pool(x, rois, self.output_channels,
0.25, 2, 2)
output = fluid.layers.prroi_pool(x, rois, 0.25, 2, 2)
loss = fluid.layers.mean(output)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
......@@ -120,18 +119,15 @@ class TestPRROIPoolOp(OpTest):
name="x", shape=[245, 30, 30], dtype="float32")
rois = fluid.layers.data(
name="rois", shape=[4], dtype="float32", lod_level=1)
# channel must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.5,
0.25, 7, 7)
# spatial_scale must be float type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, 2,
7, 7)
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 2, 7,
7)
# pooled_height must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
0.25, 0.7, 7)
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
0.7, 7)
# pooled_width must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
0.25, 7, 0.7)
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.25,
7, 0.7)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册