提交 9a14ca91 编写于 作者: J jerrywgz

test=develop

上级 4c9884e7
...@@ -114,7 +114,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var ...@@ -114,7 +114,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None)) paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None)) paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1)) paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
......
...@@ -94,7 +94,7 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -94,7 +94,7 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"(Tensor), " "(Tensor), "
"the input of ROIAlignOp. " "The input of ROIAlignOp. "
"The format of input tensor is NCHW. Where N is batch size, " "The format of input tensor is NCHW. Where N is batch size, "
"C is the number of input channels, " "C is the number of input channels, "
"H is the height of the feature, and " "H is the height of the feature, and "
...@@ -104,7 +104,6 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -104,7 +104,6 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"ROIs (Regions of Interest) to pool over. " "ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4)" "should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], …]. " "given as [[x1, y1, x2, y2], …]. "
"Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and " "(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates."); "(x2, y2) is the bottom right coordinates.");
AddOutput("Out", AddOutput("Out",
......
...@@ -34,17 +34,13 @@ static inline int NumBlocks(const int N) { ...@@ -34,17 +34,13 @@ static inline int NumBlocks(const int N) {
i += blockDim.x * gridDim.x) i += blockDim.x * gridDim.x)
template <class T> template <class T>
__device__ T bilinear_interpolate(const T* input_data, const int height, __device__ T BilinearInterpolate(const T* input_data, const int height,
const int width, T y, T x) { const int width, T y, T x) {
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
return 0; return 0;
} }
if (y <= 0) { y = y <= 0 ? 0 : y;
y = 0; x = x <= 0 ? 0 : x;
}
if (x <= 0) {
x = 0;
}
int y_low = static_cast<int>(y); int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x); int x_low = static_cast<int>(x);
int y_high; int y_high;
...@@ -75,20 +71,16 @@ __device__ T bilinear_interpolate(const T* input_data, const int height, ...@@ -75,20 +71,16 @@ __device__ T bilinear_interpolate(const T* input_data, const int height,
} }
template <class T> template <class T>
__device__ void bilinear_interpolate_gradient(const int height, const int width, __device__ void BilinearInterpolateGradient(const int height, const int width,
T y, T x, T* w1, T* w2, T* w3, T y, T x, T* w1, T* w2, T* w3,
T* w4, int* x_low, int* x_high, T* w4, int* x_low, int* x_high,
int* y_low, int* y_high) { int* y_low, int* y_high) {
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
return; return;
} }
if (y <= 0) { y = y <= 0 ? 0 : y;
y = 0; x = x <= 0 ? 0 : x;
}
if (x <= 0) {
x = 0;
}
*y_low = static_cast<int>(y); *y_low = static_cast<int>(y);
*x_low = static_cast<int>(x); *x_low = static_cast<int>(x);
if (*y_low >= height - 1) { if (*y_low >= height - 1) {
...@@ -153,7 +145,7 @@ __global__ void GPUROIAlignForward( ...@@ -153,7 +145,7 @@ __global__ void GPUROIAlignForward(
const T x = roi_xmin + pw * bin_size_w + const T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w); static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input_data, height, width, y, x); T val = BilinearInterpolate(offset_input_data, height, width, y, x);
output_val += val; output_val += val;
} }
} }
...@@ -213,8 +205,8 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois, ...@@ -213,8 +205,8 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois,
static_cast<T>(roi_bin_grid_w); static_cast<T>(roi_bin_grid_w);
T w1 = 0, w2 = 0, w3 = 0, w4 = 0; T w1 = 0, w2 = 0, w3 = 0, w4 = 0;
int x_low = -1, x_high = -1, y_low = -1, y_high = -1; int x_low = -1, x_high = -1, y_low = -1, y_high = -1;
bilinear_interpolate_gradient(height, width, y, x, &w1, &w2, &w3, &w4, BilinearInterpolateGradient(height, width, y, x, &w1, &w2, &w3, &w4,
&x_low, &x_high, &y_low, &y_high); &x_low, &x_high, &y_low, &y_high);
T diff1 = out_grad_this_bin * w1 / count; T diff1 = out_grad_this_bin * w1 / count;
T diff2 = out_grad_this_bin * w2 / count; T diff2 = out_grad_this_bin * w2 / count;
T diff3 = out_grad_this_bin * w3 / count; T diff3 = out_grad_this_bin * w3 / count;
...@@ -279,8 +271,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -279,8 +271,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
} }
} }
Tensor roi_batch_id_list_gpu; Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(), framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu); &roi_batch_id_list_gpu);
GPUROIAlignForward< GPUROIAlignForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels, output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
...@@ -310,39 +302,40 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -310,39 +302,40 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
int height = in->dims()[2]; int height = in->dims()[2];
int width = in->dims()[3]; int width = in->dims()[3];
if (in_grad) { if (!in_grad) {
Tensor roi_batch_id_list; return;
roi_batch_id_list.Resize({rois_num}); }
int* roi_batch_id_data = Tensor roi_batch_id_list;
roi_batch_id_list.mutable_data<int>(platform::CPUPlace()); roi_batch_id_list.Resize({rois_num});
auto rois_lod = rois->lod().back(); int* roi_batch_id_data =
int rois_batch_size = rois_lod.size() - 1; roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < rois_batch_size; ++n) { auto rois_lod = rois->lod().back();
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { int rois_batch_size = rois_lod.size() - 1;
roi_batch_id_data[i] = n; for (int n = 0; n < rois_batch_size; ++n) {
} for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
} roi_batch_id_data[i] = n;
Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIAlignBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
spatial_scale, channels, height, width, pooled_height, pooled_width,
sampling_ratio, roi_batch_id_list_gpu.data<int>(),
in_grad->mutable_data<T>(ctx.GetPlace()));
} }
} }
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
&roi_batch_id_list_gpu);
in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIAlignBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
spatial_scale, channels, height, width, pooled_height, pooled_width,
sampling_ratio, roi_batch_id_list_gpu.data<int>(),
in_grad->mutable_data<T>(ctx.GetPlace()));
}
} }
}; };
......
...@@ -24,7 +24,7 @@ using LoDTensor = framework::LoDTensor; ...@@ -24,7 +24,7 @@ using LoDTensor = framework::LoDTensor;
static constexpr int kROISize = 4; static constexpr int kROISize = 4;
template <class T> template <class T>
void pre_calc_for_bilinear_interpolate( void PreCalcForBilinearInterpolate(
const platform::DeviceContext& ctx, const int height, const int width, const platform::DeviceContext& ctx, const int height, const int width,
const int pooled_height, const int pooled_width, const int iy_upper, const int pooled_height, const int pooled_width, const int iy_upper,
const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w, const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w,
...@@ -53,12 +53,8 @@ void pre_calc_for_bilinear_interpolate( ...@@ -53,12 +53,8 @@ void pre_calc_for_bilinear_interpolate(
pre_calc_index += 1; pre_calc_index += 1;
continue; continue;
} }
if (y <= 0) { y = y <= 0 ? 0 : y;
y = 0; x = x <= 0 ? 0 : x;
}
if (x <= 0) {
x = 0;
}
int y_low = static_cast<int>(y); int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x); int x_low = static_cast<int>(x);
...@@ -104,12 +100,8 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x, ...@@ -104,12 +100,8 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
x_low = x_high = y_low = y_high = -1; x_low = x_high = y_low = y_high = -1;
return; return;
} }
if (y <= 0) { y = y <= 0 ? 0 : y;
y = 0; x = x <= 0 ? 0 : x;
}
if (x <= 0) {
x = 0;
}
y_low = static_cast<int>(y); y_low = static_cast<int>(y);
x_low = static_cast<int>(x); x_low = static_cast<int>(x);
if (y_low >= height - 1) { if (y_low >= height - 1) {
...@@ -139,7 +131,6 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x, ...@@ -139,7 +131,6 @@ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
*(batch_grad_data + y_high * width + x_low) += diff3; *(batch_grad_data + y_high * width + x_low) += diff3;
*(batch_grad_data + y_high * width + x_high) += diff4; *(batch_grad_data + y_high * width + x_high) += diff4;
} }
return;
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -214,7 +205,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -214,7 +205,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
pre_pos.Resize({pre_size, kROISize}); pre_pos.Resize({pre_size, kROISize});
pre_w.Resize({pre_size, kROISize}); pre_w.Resize({pre_size, kROISize});
pre_calc_for_bilinear_interpolate( PreCalcForBilinearInterpolate(
dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h, dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h,
roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w, roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w,
roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w); roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w);
...@@ -245,7 +236,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -245,7 +236,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
} }
rois_data += roi_stride[0]; rois_data += roi_stride[0];
} }
return;
} }
}; };
...@@ -264,79 +254,78 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -264,79 +254,78 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio"); auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims(); auto in_dims = in->dims();
if (in_grad) { if (!in_grad) {
int channels = in_dims[1]; return;
int height = in_dims[2]; }
int width = in_dims[3]; int channels = in_dims[1];
int rois_num = rois->dims()[0]; int height = in_dims[2];
Tensor roi_batch_id_list; int width = in_dims[3];
roi_batch_id_list.Resize({rois_num}); int rois_num = rois->dims()[0];
int* roi_batch_id_data = Tensor roi_batch_id_list;
roi_batch_id_list.mutable_data<int>(ctx.GetPlace()); roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
auto rois_lod = rois->lod().back(); auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) { for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
}
} }
}
const T* rois_data = rois->data<T>(); const T* rois_data = rois->data<T>();
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
auto in_stride = framework::stride(in->dims()); auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims()); auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims()); auto out_stride = framework::stride(out_grad->dims());
for (int n = 0; n < rois_num; ++n) { for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = roi_batch_id_data[n]; int roi_batch_idx = roi_batch_id_data[n];
T roi_xmin = rois_data[0] * spatial_scale; T roi_xmin = rois_data[0] * spatial_scale;
T roi_ymin = rois_data[1] * spatial_scale; T roi_ymin = rois_data[1] * spatial_scale;
T roi_xmax = rois_data[2] * spatial_scale; T roi_xmax = rois_data[2] * spatial_scale;
T roi_ymax = rois_data[3] * spatial_scale; T roi_ymax = rois_data[3] * spatial_scale;
T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.)); T roi_width = std::max(roi_xmax - roi_xmin, static_cast<T>(1.));
T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.)); T roi_height = std::max(roi_ymax - roi_ymin, static_cast<T>(1.));
T bin_size_h = T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
static_cast<T>(roi_height) / static_cast<T>(pooled_height); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); for (int c = 0; c < channels; ++c) {
for (int c = 0; c < channels; ++c) { T* batch_grad_data =
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1];
in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1]; const T* batch_out_grad_data =
const T* batch_out_grad_data = out_grad_data + n * out_stride[0] + c * out_stride[1];
out_grad_data + n * out_stride[0] + c * out_stride[1]; for (int ph = 0; ph < pooled_height; ++ph) {
for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) {
for (int pw = 0; pw < pooled_width; ++pw) { int pool_index = ph * pooled_width + pw;
int pool_index = ph * pooled_width + pw; T out_grad_this_bin = batch_out_grad_data[pool_index];
T out_grad_this_bin = batch_out_grad_data[pool_index]; int roi_bin_grid_h = (sampling_ratio > 0)
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio
? sampling_ratio : ceil(roi_height / pooled_height);
: ceil(roi_height / pooled_height); int roi_bin_grid_w = (sampling_ratio > 0)
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio
? sampling_ratio : ceil(roi_width / pooled_width);
: ceil(roi_width / pooled_width); T count = roi_bin_grid_h * roi_bin_grid_w;
T count = roi_bin_grid_h * roi_bin_grid_w; for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int iy = 0; iy < roi_bin_grid_h; iy++) { const T y = roi_ymin + ph * bin_size_h +
const T y = roi_ymin + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
static_cast<T>(roi_bin_grid_h); for (int ix = 0; ix < roi_bin_grid_w; ix++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) { const T x = roi_xmin + pw * bin_size_w +
const T x = roi_xmin + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
static_cast<T>(roi_bin_grid_w); bilinear_interpolate_gradient(height, width, y, x,
bilinear_interpolate_gradient(height, width, y, x, out_grad_this_bin, count,
out_grad_this_bin, count, batch_grad_data);
batch_grad_data);
}
} }
} }
} }
} }
rois_data += roi_stride[0];
} }
rois_data += roi_stride[0];
} }
return;
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -5184,7 +5184,8 @@ def roi_align(input, ...@@ -5184,7 +5184,8 @@ def roi_align(input,
pooled_height=1, pooled_height=1,
pooled_width=1, pooled_width=1,
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=-1): sampling_ratio=-1,
name=None):
""" """
${comment} ${comment}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册