diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 8ee8310cb53ddeb3371fdacc592f6ec5571a241f..256693528960a9747f89df2c553102322ff1cbeb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -73,7 +73,6 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)) paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)) -paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)) paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)) @@ -296,6 +295,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', ' paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) +paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)) paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index f4983c65432991a45f226d97f0fb05b08a30ca89..5a058ddbc59c6135bacf7c2dc4b5c8b687f9b2b1 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,5 +31,6 @@ polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(generate_proposals_op SRCS generate_proposals_op.cc) +detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu) #Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b98190d40a2afa684cfd29cc52fc29fac851cca7 --- /dev/null +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -0,0 +1,587 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kROISize = 4; + +template +bool GT_E(T a, T b) { + return (a > b) || fabs(a - b) < 1e-4; +} + +template +bool LT_E(T a, T b) { + return (a < b) || fabs(a - b) < 1e-4; +} + +template +bool GT(T a, T b) { + return (a - b) > 1e-4; +} + +/* +*check if (x, y) is in the boundary of roi +*/ +template +bool in_quad(T x, T y, T roi_x[], T roi_y[]) { + for (int i = 0; i < 4; i++) { + T xs = roi_x[i]; + T ys = roi_y[i]; + T xe = roi_x[(i + 1) % 4]; + T ye = roi_y[(i + 1) % 4]; + if (fabs(ys - ye) < 1e-4) { + if (fabs(y - ys) < 1e-4 && fabs(y - ye) < 1e-4 && + GT_E(x, std::min(xs, xe)) && LT_E(x, std::max(xs, xe))) { + return true; + } + } else { + T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs; + if (fabs(intersec_x - x) < 1e-4 && GT_E(y, std::min(ys, ye)) && + LT_E(y, std::max(ys, ye))) { + return true; + } + } + } + + int n_cross = 0; + for (int i = 0; i < 4; i++) { + T xs = roi_x[i]; + T ys = roi_y[i]; + T xe = roi_x[(i + 1) % 4]; + T ye = roi_y[(i + 1) % 4]; + if (fabs(ys - ye) < 1e-4) { + continue; + } + if (LT_E(y, std::min(ys, ye)) || GT(y, std::max(ys, ye))) { + continue; + } + T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs; + if (fabs(intersec_x - x) < 1e-4) { + return true; + } + if (GT(intersec_x, x)) { + n_cross++; + } + } + return (n_cross % 2 == 1); +} + +/** + * Get the matrix of perspective transform. + * + * dx1 = x1 - x2 + * dx2 = x3 - x2 + * dx3 = x0 - x1 + x2 - x3 + * dy1 = y1 - y2 + * dy2 = y3 - y2 + * dy3 = y0 - y1 + y2 - y3 + * + * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1) + * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1) + * a13 = x0 + * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1) + * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1) + * a23 = y0 + * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1) + * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1) + * a33 = 1 + * + */ +template +void get_transform_matrix(const int transformed_width, + const int transformed_height, T roi_x[], T roi_y[], + T matrix[]) { + T x0 = roi_x[0]; + T x1 = roi_x[1]; + T x2 = roi_x[2]; + T x3 = roi_x[3]; + T y0 = roi_y[0]; + T y1 = roi_y[1]; + T y2 = roi_y[2]; + T y3 = roi_y[3]; + + // Estimate the height and width of RoI + T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)); + T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)); + T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)); + T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)); + T estimated_height = (len2 + len4) / 2.0; + T estimated_width = (len1 + len3) / 2.0; + + // Get the normalized height and normalized width + int normalized_height = transformed_height; + int normalized_width = + std::round(estimated_width * (normalized_height - 1) / estimated_height) + + 1; + normalized_width = std::min(normalized_width, transformed_width); + + T dx1 = x1 - x2; + T dx2 = x3 - x2; + T dx3 = x0 - x1 + x2 - x3; + T dy1 = y1 - y2; + T dy2 = y3 - y2; + T dy3 = y0 - y1 + y2 - y3; + + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / + (normalized_width - 1); + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / + (normalized_height - 1); + matrix[8] = 1; + + matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) / + (normalized_width - 1); + matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) / + (normalized_height - 1); + matrix[5] = y0; + + matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) / + (normalized_width - 1); + matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) / + (normalized_height - 1); + matrix[2] = x0; +} + +/** + * Get the source coordinates in the input feature map. + * + * (u, v, w)^matrix = matrix * (out_w, out_h, 1)^matrix + * + * in_w = u / w + * in_h = v / w + * + */ +template +void get_source_coords(T matrix[], int out_w, int out_h, T* in_w, T* in_h) { + T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2]; + T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5]; + T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8]; + + in_w[0] = u / w; + in_h[0] = v / w; +} + +/** + * Perform bilinear interpolation in the input feature map. + */ +template +void bilinear_interpolate(const T* in_data, const int channels, const int width, + const int height, int in_n, int in_c, T in_w, T in_h, + T* val) { + // Deal with cases that source coords are out of feature map boundary + if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || + GT(in_h, height - 0.5)) { + // empty + val[0] = 0.0; + return; + } + + if (GT(0, in_w)) { + in_w = 0; + } + if (GT(0, in_h)) { + in_h = 0; + } + + int in_w_floor = floor(in_w); + int in_h_floor = floor(in_h); + int in_w_ceil; + int in_h_ceil; + + if (GT_E(in_w_floor, width - 1)) { + in_w_ceil = in_w_floor = width - 1; + in_w = static_cast(in_w_floor); + } else { + in_w_ceil = in_w_floor + 1; + } + + if (GT_E(in_h_floor, height - 1)) { + in_h_ceil = in_h_floor = height - 1; + in_h = static_cast(in_h_floor); + } else { + in_h_ceil = in_h_floor + 1; + } + T w_floor = in_w - in_w_floor; + T h_floor = in_h - in_h_floor; + T w_ceil = 1 - w_floor; + T h_ceil = 1 - h_floor; + const T* data = in_data + (in_n * channels + in_c) * height * width; + // Do bilinear interpolation + T v1 = data[in_h_floor * width + in_w_floor]; + T v2 = data[in_h_ceil * width + in_w_floor]; + T v3 = data[in_h_ceil * width + in_w_ceil]; + T v4 = data[in_h_floor * width + in_w_ceil]; + T w1 = w_ceil * h_ceil; + T w2 = w_ceil * h_floor; + T w3 = w_floor * h_floor; + T w4 = w_floor * h_ceil; + val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +template +class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + const T* input_data = in->data(); + + framework::Tensor roi2image; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(ctx.GetPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* rois_data = rois->data(); + + for (int n = 0; n < rois_num; ++n) { + const T* n_rois = rois_data + n * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = n_rois[2 * k] * spatial_scale; + roi_y[k] = n_rois[2 * k + 1] * spatial_scale; + } + int image_id = roi2image_data[n]; + // Get transform matrix + T transform_matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, + roi_y, transform_matrix); + + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + int out_index = + n * channels * transformed_height * transformed_width + + c * transformed_height * transformed_width + + out_h * transformed_width + out_w; + T in_w, in_h; + get_source_coords(transform_matrix, out_w, out_h, &in_w, &in_h); + if (in_quad(in_w, in_h, roi_x, roi_y)) { + if (GT(-0.5, in_w) || + GT(in_w, static_cast(in_width - 0.5)) || + GT(-0.5, in_h) || + GT(in_h, static_cast(in_height - 0.5))) { + output_data[out_index] = 0.0; + } else { + bilinear_interpolate(input_data, channels, in_width, in_height, + image_id, c, in_w, in_h, + output_data + out_index); + } + } else { + output_data[out_index] = 0.0; + } + } + } + } + } + } +}; + +template +T get_feature_gradient(T xs, T ys, int w, int h, const int width, + const int height) { + if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || + GT(ys, height - 0.5)) { + return 0; + } + + if (GT(0, xs)) { + xs = 0; + } + if (GT(0, ys)) { + ys = 0; + } + + int xs_floor = floor(xs); + int ys_floor = floor(ys); + int xs_ceil; + int ys_ceil; + + if (GT_E(xs_floor, width - 1)) { + xs_ceil = xs_floor = width - 1; + xs = static_cast(xs_floor); + } else { + xs_ceil = xs_floor + 1; + } + + if (GT_E(ys_floor, height - 1)) { + ys_ceil = ys_floor = height - 1; + ys = static_cast(ys_floor); + } else { + ys_ceil = ys_floor + 1; + } + + T weight = 0; + if (w == xs_floor) { + if (h == ys_floor) { + weight = (w + 1 - xs) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (w + 1 - xs) * (ys + 1 - h); + } + } else if (w == xs_ceil) { + if (h == ys_floor) { + weight = (xs + 1 - w) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (xs + 1 - w) * (ys + 1 - h); + } + } + return weight; +} + +template +class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const T* rois_data = rois->data(); + + framework::Tensor roi2image; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(ctx.GetPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + for (int in_h = 0; in_h < in_height; ++in_h) { + for (int in_w = 0; in_w < in_width; ++in_w) { + T gradient = 0.0; + for (int roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) { + const T* rois = rois_data + roi_idx * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = rois[2 * k] * spatial_scale; + roi_y[k] = rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, + roi_x, roi_y, matrix); + const T* out_grad_ptr = out_grad_data + + (roi_idx * channels + c) * + transformed_height * + transformed_width; + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + T src_w; + T src_h; + get_source_coords(matrix, out_w, out_h, &src_w, &src_h); + if (in_quad(src_w, src_h, roi_x, roi_y)) { + if (GT(-0.5, src_w) || + GT(src_w, static_cast(in_width - 0.5)) || + GT(-0.5, src_h) || + GT(src_h, static_cast(in_height - 0.5))) { + continue; + } + T weight = get_feature_gradient(src_w, src_h, in_w, in_h, + in_width, in_height); + gradient += + out_grad_ptr[out_h * transformed_width + out_w] * + weight; + } + } + } + } + int out_idx = (n * channels + c) * in_height * in_width + + in_h * in_width + in_w; + in_grad_data[out_idx] = gradient; + } + } + } + } + } +}; + +class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ROIPerspectiveTransformOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("ROIs"), + "Input(ROIs) of ROIPerspectiveTransformOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of ROIPerspectiveTransformOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE(input_dims.size() == 4, + "The format of input tensor is NCHW."); + PADDLE_ENFORCE(rois_dims.size() == 2, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]"); + PADDLE_ENFORCE(rois_dims[1] == 8, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]."); + + int transformed_height = ctx->Attrs().Get("transformed_height"); + int transformed_width = ctx->Attrs().Get("transformed_width"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE_GT(transformed_height, 0, + "The transformed output height must greater than 0"); + PADDLE_ENFORCE_GT(transformed_width, 0, + "The transformed output width must greater than 0"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0"); + std::vector out_dims_v({rois_dims[0], // num_rois + input_dims[1], // channels + static_cast(transformed_height), + static_cast(transformed_width)}); + auto out_dims = framework::make_ddim(out_dims_v); + + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class ROIPerspectiveTransformOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "the input of ROIPerspectiveTransformOp. " + "The format of input tensor is NCHW. Where N is batch size, " + "C is the number of input channels, " + "H is the height of the feature, and " + "W is the width of the feature."); + AddInput("ROIs", + "(LoDTensor), " + "ROIs (Regions of Interest) to be transformed. " + "should be a 2-D LoDTensor of shape (num_rois, 8)" + "given as [[x1, y1, x2, y2, x3, y3, x4, y4], ...]." + "(x1, y1) is the top left coordinates, and " + "(x2, y2) is the top right coordinates, and" + "(x3, y3) is the bottom right coordinates, and" + "(x4, y4) is the bottom left coordinates."); + AddOutput( + "Out", + "(Tensor), " + "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape " + "(num_rois, channels, transformed_h, transformed_w)."); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Spatial scale factor to scale ROI coords.") + .SetDefault(1.0); + AddAttr("transformed_height", + "(int, default 1), " + "The height of transformed output.") + .SetDefault(1); + AddAttr("transformed_width", + "(int, default 1), " + "The width of transformed output.") + .SetDefault(1); + AddComment(R"DOC( +**ROIPerspectiveTransform Operator** + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp, + ops::ROIPerspectiveTransformOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(roi_perspective_transform_grad, + ops::ROIPerspectiveTransformGradOp); +REGISTER_OP_CPU_KERNEL(roi_perspective_transform, + ops::CPUROIPerspectiveTransformOpKernel); +REGISTER_OP_CPU_KERNEL(roi_perspective_transform_grad, + ops::CPUROIPerspectiveTransformGradOpKernel); diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..b683b7573db747bde5f57e530ec53760db099843 --- /dev/null +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -0,0 +1,523 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +// CUDA: index helpers +#define idx4_4(index, d1, d2, d3, d4) (index % d4) +#define idx4_3(index, d1, d2, d3, d4) ((index / d4) % d3) +#define idx4_2(index, d1, d2, d3, d4) ((index / d4 / d3) % d2) +#define idx4_1(index, d1, d2, d3, d4) ((index / d4 / d3 / d2) % d1) + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ bool GT_E(T a, T b) { + return (a > b) || fabs(a - b) < 1e-4; +} + +template +__device__ bool LT_E(T a, T b) { + return (a < b) || fabs(a - b) < 1e-4; +} + +template +__device__ bool GT(T a, T b) { + return (a - b) > 1e-4; +} + +template +__device__ T max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ T min(T a, T b) { + return a < b ? a : b; +} + +/* +* check if (x, y) is in the boundary of roi +*/ +template +__device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) { + for (int i = 0; i < 4; i++) { + T start_w = roi_x[i]; + T start_h = roi_y[i]; + T end_w = roi_x[(i + 1) % 4]; + T end_h = roi_y[(i + 1) % 4]; + if (fabs(start_h - end_h) < 1e-4) { + if (fabs(y - start_h) < 1e-4 && fabs(y - end_h) < 1e-4 && + GT_E(x, min(start_w, end_w)) && + LT_E(x, max(start_w, end_w))) { + return true; + } + } else { + T intersec_x = + (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w; + if (fabs(intersec_x - x) < 1e-4 && GT_E(y, min(start_h, end_h)) && + LT_E(y, max(start_h, end_h))) { + return true; + } + } + } + + int n_cross = 0; + for (int i = 0; i < 4; i++) { + T start_w = roi_x[i]; + T start_h = roi_y[i]; + T end_w = roi_x[(i + 1) % 4]; + T end_h = roi_y[(i + 1) % 4]; + if (fabs(start_h - end_h) < 1e-4) { + continue; + } + if (LT_E(y, min(start_h, end_h)) || + GT(y, max(start_h, end_h))) { + continue; + } + T intersec_x = + (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w; + if (fabs(intersec_x - x) < 1e-4) { + return true; + } + if (GT(intersec_x, x)) { + n_cross++; + } + } + return (n_cross % 2 == 1); +} + +/** + * Perform bilinear interpolation in the input feature map. + */ +template +__device__ void bilinear_interpolate(const T* in_data, const int channels, + const int width, const int height, + int in_n, int in_c, T in_w, T in_h, + T* val) { + // Deal with cases that source coords are out of feature map boundary + if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || + GT(in_h, height - 0.5)) { + val[0] = 0.0; + return; + } + + if (GT(0, in_w)) { + in_w = 0; + } + if (GT(0, in_h)) { + in_h = 0; + } + + int in_w_floor = floor(in_w); + int in_h_floor = floor(in_h); + int in_w_ceil; + int in_h_ceil; + + if (GT_E(in_w_floor, width - 1)) { + in_w_ceil = in_w_floor = width - 1; + in_w = static_cast(in_w_floor); + } else { + in_w_ceil = in_w_floor + 1; + } + + if (GT_E(in_h_floor, height - 1)) { + in_h_ceil = in_h_floor = height - 1; + in_h = static_cast(in_h_floor); + } else { + in_h_ceil = in_h_floor + 1; + } + + T w_floor = in_w - in_w_floor; + T h_floor = in_h - in_h_floor; + T w_ceil = 1 - w_floor; + T h_ceil = 1 - h_floor; + const T* data = in_data + (in_n * channels + in_c) * height * width; + // Do bilinear interpolation + T v1 = data[in_h_floor * width + in_w_floor]; + T v2 = data[in_h_ceil * width + in_w_floor]; + T v3 = data[in_h_ceil * width + in_w_ceil]; + T v4 = data[in_h_floor * width + in_w_ceil]; + T w1 = w_ceil * h_ceil; + T w2 = w_ceil * h_floor; + T w3 = w_floor * h_floor; + T w4 = w_floor * h_ceil; + val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +/** + * Get the source coordinates in the input feature map. + * + * (u, v, w)^matrix = T * (out_w, out_h, 1)^matrix + * + * in_w = u / w + * in_h = v / w + * + */ +template +__device__ void get_source_coords(T matrix[], int out_w, int out_h, T* in_w, + T* in_h) { + T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2]; + T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5]; + T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8]; + + in_w[0] = u / w; + in_h[0] = v / w; +} + +/** + * Get the matrix of perspective transform. + * + * dx1 = x1 - x2 + * dx2 = x3 - x2 + * dx3 = x0 - x1 + x2 - x3 + * dy1 = y1 - y2 + * dy2 = y3 - y2 + * dy3 = y0 - y1 + y2 - y3 + * + * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1) + * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1) + * a13 = x0 + * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1) + * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1) + * a23 = y0 + * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1) + * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1) + * a33 = 1 + * + */ +template +__device__ void get_transform_matrix(const int transformed_width, + const int transformed_height, T roi_x[], + T roi_y[], T matrix[]) { + T x0 = roi_x[0]; + T x1 = roi_x[1]; + T x2 = roi_x[2]; + T x3 = roi_x[3]; + T y0 = roi_y[0]; + T y1 = roi_y[1]; + T y2 = roi_y[2]; + T y3 = roi_y[3]; + + // Estimate the height and width of RoI + T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)); + T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)); + T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)); + T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)); + T estimated_height = (len2 + len4) / 2.0; + T estimated_width = (len1 + len3) / 2.0; + + // Get the normalized height and normalized width + int normalized_height = transformed_height; + int normalized_width = + round(estimated_width * (normalized_height - 1) / estimated_height) + 1; + normalized_width = min(normalized_width, transformed_width); + + T dx1 = x1 - x2; + T dx2 = x3 - x2; + T dx3 = x0 - x1 + x2 - x3; + T dy1 = y1 - y2; + T dy2 = y3 - y2; + T dy3 = y0 - y1 + y2 - y3; + + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / + (normalized_width - 1); + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / + (normalized_height - 1); + matrix[8] = 1; + + matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) / + (normalized_width - 1); + matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) / + (normalized_height - 1); + matrix[5] = y0; + + matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) / + (normalized_width - 1); + matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) / + (normalized_height - 1); + matrix[2] = x0; +} + +template +__global__ void RoiTransformKernel(const float* input_data, + const float* rois_data, + const int* roi2image_data, int num_rois, + int in_height, int in_width, int channels, + int transformed_height, + int transformed_width, float spatial_scale, + T* output_data) { + int output_size = + num_rois * transformed_height * transformed_width * channels; + + CUDA_1D_KERNEL_LOOP(index, output_size) { + // (n, c, out_h, out_w) is an element in the transformed output + int out_w = idx4_4(index, num_rois, channels, transformed_height, + transformed_width); + int out_h = idx4_3(index, num_rois, channels, transformed_height, + transformed_width); + int c = idx4_2(index, num_rois, channels, transformed_height, + transformed_width); + int n = idx4_1(index, num_rois, channels, transformed_height, + transformed_width); + + auto bottom_rois = rois_data + n * 8; + int roi_batch_ind = bottom_rois[0]; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = bottom_rois[2 * k] * spatial_scale; + roi_y[k] = bottom_rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y, + matrix); + + // Get source coords + T in_w; + T in_h; + get_source_coords(matrix, out_w, out_h, &in_w, &in_h); + + if (in_quad(in_w, in_h, roi_x, roi_y)) { + if (GT(-0.5, in_w) || GT(in_w, static_cast(in_width - 0.5)) || + GT(-0.5, in_h) || GT(in_h, static_cast(in_height - 0.5))) { + // Skip if source coords is not in input image + output_data[index] = 0.0; + } else { + // Perform bilinear interpolation + int in_n = roi2image_data[n]; + bilinear_interpolate(input_data, channels, in_width, in_height, in_n, + c, in_w, in_h, output_data + index); + } + + } else { + // Skip if source coords is not in quad + output_data[index] = 0.0; + } + } +} + +template +class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + const T* input_data = in->data(); + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* rois_data = rois->data(); + + framework::Tensor roi2image; + framework::Tensor roi2image_dev; + roi2image.Resize({rois_num}); + int* roi2image_data = roi2image.mutable_data(platform::CPUPlace()); + auto lod = rois->lod().back(); + for (int i = 0; i < lod.size() - 1; ++i) { + for (int j = lod[i]; j < lod[i + 1]; ++j) { + roi2image_data[j] = i; + } + } + TensorCopySync(roi2image, ctx.GetPlace(), &roi2image_dev); + + int out_size = rois_num * transformed_height * transformed_width * channels; + auto stream = ctx.cuda_device_context().stream(); + int block = 512; + int grid = (out_size + block - 1) / block; + + RoiTransformKernel<<>>( + input_data, rois_data, roi2image_dev.data(), rois_num, in_height, + in_width, channels, transformed_height, transformed_width, + spatial_scale, output_data); + } +}; + +template +__device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width, + const int height) { + if (GT(-0.5, xs) || GT(xs, width - 0.5) || GT(-0.5, ys) || + GT(ys, height - 0.5)) { + return 0; + } + + if (GT(0, xs)) { + xs = 0; + } + if (GT(0, ys)) { + ys = 0; + } + + int xs_floor = floor(xs); + int ys_floor = floor(ys); + int xs_ceil; + int ys_ceil; + + if (GT_E(xs_floor, width - 1)) { + xs_ceil = xs_floor = width - 1; + xs = static_cast(xs_floor); + } else { + xs_ceil = xs_floor + 1; + } + + if (GT_E(ys_floor, height - 1)) { + ys_ceil = ys_floor = height - 1; + ys = static_cast(ys_floor); + } else { + ys_ceil = ys_floor + 1; + } + + T weight = 0; + if (w == xs_floor) { + if (h == ys_floor) { + weight = (w + 1 - xs) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (w + 1 - xs) * (ys + 1 - h); + } + } else if (w == xs_ceil) { + if (h == ys_floor) { + weight = (xs + 1 - w) * (h + 1 - ys); + } else if (h == ys_ceil) { + weight = (xs + 1 - w) * (ys + 1 - h); + } + } + return weight; +} + +template +__global__ void RoiTransformGradKernel( + const size_t* lod, const T* rois_data, int batch_size, int num_rois, + int in_height, int in_width, int channels, int transformed_height, + int transformed_width, float spatial_scale, const T* out_grad_data, + T* in_grad_data) { + int input_size = batch_size * in_height * in_width * channels; + + CUDA_1D_KERNEL_LOOP(index, input_size) { + // (n, c, h, w) coords in input + int in_w = idx4_4(index, batch_size, channels, in_height, in_width); + int in_h = idx4_3(index, batch_size, channels, in_height, in_width); + int c = idx4_2(index, batch_size, channels, in_height, in_width); + int n = idx4_1(index, batch_size, channels, in_height, in_width); + + T gradient = 0.0; + // Accumulate gradient over all RoIs that interpolated this element + for (int roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) { + const T* rois = rois_data + roi_idx * 8; + T roi_x[4]; + T roi_y[4]; + for (int k = 0; k < 4; ++k) { + roi_x[k] = rois[2 * k] * spatial_scale; + roi_y[k] = rois[2 * k + 1] * spatial_scale; + } + + // Get transform matrix + T matrix[9]; + get_transform_matrix(transformed_width, transformed_height, roi_x, + roi_y, matrix); + + const T* out_grad_ptr = + out_grad_data + + (roi_idx * channels + c) * transformed_height * transformed_width; + for (int out_h = 0; out_h < transformed_height; ++out_h) { + for (int out_w = 0; out_w < transformed_width; ++out_w) { + T src_w; + T src_h; + get_source_coords(matrix, out_w, out_h, &src_w, &src_h); + if (in_quad(src_w, src_h, roi_x, roi_y)) { + if (GT(-0.5, src_w) || + GT(src_w, static_cast(in_width - 0.5)) || + GT(-0.5, src_h) || + GT(src_h, static_cast(in_height - 0.5))) { + continue; + } + T weight = get_feature_gradient(src_w, src_h, in_w, in_h, + in_width, in_height); + gradient += + out_grad_ptr[out_h * transformed_width + out_w] * weight; + } + } + } + } + in_grad_data[index] = gradient; + } +} + +template +class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto transformed_height = ctx.Attr("transformed_height"); + auto transformed_width = ctx.Attr("transformed_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int channels = in_dims[1]; + int in_height = in_dims[2]; + int in_width = in_dims[3]; + int rois_num = rois->dims()[0]; + + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const T* rois_data = rois->data(); + + auto lod = rois->lod().back(); + auto lod_data = lod.CUDAData(ctx.GetPlace()); + + int in_size = in->numel(); + auto stream = ctx.cuda_device_context().stream(); + int block = 512; + int grid = (in_size + block - 1) / block; + + RoiTransformGradKernel<<>>( + lod_data, rois_data, batch_size, rois_num, in_height, in_width, + channels, transformed_height, transformed_width, spatial_scale, + out_grad_data, in_grad_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(roi_perspective_transform, + ops::CUDAROIPerspectiveTransformOpKernel); +REGISTER_OP_CUDA_KERNEL(roi_perspective_transform_grad, + ops::CUDAROIPerspectiveTransformGradOpKernel); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1ca2ac2ddc7daef3f4c0ea2004a62258ae4610ac..9e4a5ae8baaf7f2975c8060856f9eecab55f241c 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -46,7 +46,7 @@ from . import transpiler from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope -from .transpiler import DistributeTranspiler, InferenceTranspiler, \ +from .transpiler import DistributeTranspiler, \ memory_optimize, release_memory, DistributeTranspilerConfig from .lod_tensor import create_lod_tensor, create_random_int_lodtensor from . import clip diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 78bb8a1a0a64631cbe2adc11b1494ceed6d14908..e703e5ac7943b006741f12886a14bf344a6b9b28 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -27,8 +27,7 @@ from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', - 'load_persistables', 'save_inference_model', 'load_inference_model', - 'get_inference_program' + 'load_persistables', 'save_inference_model', 'load_inference_model' ] @@ -504,23 +503,6 @@ def load_persistables(executor, dirname, main_program=None, filename=None): filename=filename) -def get_inference_program(target_vars, main_program=None): - if main_program is None: - main_program = default_main_program() - if not isinstance(target_vars, list): - target_vars = [target_vars] - vars = [] - for var in target_vars: - if isinstance(var, Evaluator): - vars.extend(var.states) - vars.extend(var.metrics) - else: - vars.append(var) - pruned_program = main_program._prune(targets=vars) - inference_program = pruned_program._inference_optimize() - return inference_program - - def prepend_feed_ops(inference_program, feed_target_names, feed_holder_name='feed'): diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8e86bec8609f17c973389047df66b4d725113e6e..574d0d727cba9fa9de0cffbe116f71b9e65a7092 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -39,6 +39,7 @@ __all__ = [ 'detection_map', 'rpn_target_assign', 'anchor_generator', + 'roi_perspective_transform', 'generate_proposal_labels', 'generate_proposals', ] @@ -1262,6 +1263,54 @@ def anchor_generator(input, return anchor, var +def roi_perspective_transform(input, + rois, + transformed_height, + transformed_width, + spatial_scale=1.0): + """ + ROI perspective transform op. + + Args: + input (Variable): The input of ROIPerspectiveTransformOp. The format of + input tensor is NCHW. Where N is batch size, C is the + number of input channels, H is the height of the feature, + and W is the width of the feature. + rois (Variable): ROIs (Regions of Interest) to be transformed. It should be + a 2-D LoDTensor of shape (num_rois, 8). Given as + [[x1, y1, x2, y2, x3, y3, x4, y4], ...], (x1, y1) is the + top left coordinates, and (x2, y2) is the top right + coordinates, and (x3, y3) is the bottom right coordinates, + and (x4, y4) is the bottom left coordinates. + transformed_height (integer): The height of transformed output. + transformed_height (integer): The width of transformed output. + spatial_scale (float): Spatial scale factor to scale ROI coords. Default: 1.0 + + Returns: + Variable: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape + (num_rois, channels, transformed_h, transformed_w). + + Examples: + .. code-block:: python + + out = fluid.layers.roi_perspective_transform(input, rois, 7, 7, 1.0) + """ + helper = LayerHelper('roi_perspective_transform', **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="roi_perspective_transform", + inputs={"X": input, + "ROIs": rois}, + outputs={"Out": out}, + attrs={ + "transformed_height": transformed_height, + "transformed_width": transformed_width, + "spatial_scale": spatial_scale + }) + return out + + def generate_proposal_labels(rpn_rois, gt_classes, is_crowd, diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index 3ec79f8ef6e6f70f1365eaa32352c284d294a1ea..175bd130e5a8324227953eeeb769474e78f94fd2 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -437,13 +437,8 @@ def split_data(data, num_part): ] -def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, +def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names, sum_cost, token_num): - # Context to do validation. - test_program = train_progm.clone() - with fluid.program_guard(test_program): - test_program = fluid.io.get_inference_program([avg_cost]) - val_data = DataReader( src_vocab_fpath=TrainTaskConfig.src_vocab_fpath, trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath, @@ -505,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, - token_num, predict): + token_num, predict, test_program): # Initialize the parameters. if TrainTaskConfig.ckpt_path: lr_scheduler.current_steps = TrainTaskConfig.start_step @@ -554,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, -1] + label_data_input_fields if TrainTaskConfig.val_file_pattern is not None: - test = test_context(train_progm, avg_cost, train_exe, dev_count, + test = test_context(test_program, avg_cost, train_exe, dev_count, data_input_names, sum_cost, token_num) # the best cross-entropy value with label smoothing @@ -1647,6 +1642,8 @@ def get_model(is_dist, is_async): local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model, TrainTaskConfig.warmup_steps, TrainTaskConfig.learning_rate) + # Context to do validation. + test_program = fluid.default_main_program().clone(for_test=True) if not is_dist: optimizer = fluid.optimizer.Adam( @@ -1671,7 +1668,7 @@ def get_model(is_dist, is_async): epsilon=TrainTaskConfig.eps) optimizer.minimize(sum_cost) - return sum_cost, avg_cost, predict, token_num, local_lr_scheduler + return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program def update_args(): @@ -1705,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase): def run_trainer(self, use_cuda, args): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() TrainTaskConfig.use_gpu = use_cuda - sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( + sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model( args.is_dist, not args.sync_mode) if args.is_dist: @@ -1726,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase): TrainTaskConfig.local = not args.is_dist train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost, - local_lr_scheduler, token_num, predict) + local_lr_scheduler, token_num, predict, test_program) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c084ea059e3265e4e5aef261f738885b5aca8d57..1fe7016924696b6e47d9cc35c137004f15a9b507 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -725,6 +725,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_roi_perspective_transform(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[8], dtype="float32", lod_level=1) + output = layers.roi_perspective_transform(x, rois, 7, 7, 0.6) + self.assertIsNotNone(output) + print(str(program)) + def test_sequence_enumerate(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py new file mode 100644 index 0000000000000000000000000000000000000000..de675131564db43926f97ff4e6dedcaa02ff15b0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -0,0 +1,306 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUWARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys +import paddle.compat as cpt +from op_test import OpTest +from math import sqrt +from math import floor + + +def gt_e(a, b): + return a > b or abs(a - b) < 1e-4 + + +def gt(a, b): + return (a - b) > 1e-4 + + +def lt_e(a, b): + return a < b or abs(a - b) < 1e-4 + + +def in_quad(x, y, roi_x, roi_y): + # check if (x, y) is in the boundary of roi + for i in range(4): + xs = roi_x[i] + ys = roi_y[i] + xe = roi_x[(i + 1) % 4] + ye = roi_y[(i + 1) % 4] + if abs(ys - ye) < 1e-4: + if abs(y - ys) < 1e-4 and abs(y - ye) < 1e-4 and gt_e( + x, min(xs, xe)) and lt_e(x, max(xs, xe)): + return True + else: + intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs + if abs(intersec_x - x) < 1e-4 and gt_e(y, min(ys, ye)) and lt_e( + y, max(ys, ye)): + return True + n_cross = 0 + for i in range(4): + xs = roi_x[i] + ys = roi_y[i] + xe = roi_x[(i + 1) % 4] + ye = roi_y[(i + 1) % 4] + if abs(ys - ye) < 1e-4: + continue + if lt_e(y, min(ys, ye)) or gt(y, max(ys, ye)): + continue + intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs + if abs(intersec_x - x) < 1e-4: + return True + if gt(intersec_x, x): + n_cross += 1 + return (n_cross % 2 == 1) + + +def get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y): + x0 = roi_x[0] + x1 = roi_x[1] + x2 = roi_x[2] + x3 = roi_x[3] + y0 = roi_y[0] + y1 = roi_y[1] + y2 = roi_y[2] + y3 = roi_y[3] + + len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1)) + len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)) + len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3)) + len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0)) + estimated_height = (len2 + len4) / 2.0 + estimated_width = (len1 + len3) / 2.0 + + normalized_height = transformed_height + normalized_width = round(estimated_width * + (normalized_height - 1) / estimated_height) + 1 + normalized_width = min(normalized_width, transformed_width) + + dx1 = x1 - x2 + dx2 = x3 - x2 + dx3 = x0 - x1 + x2 - x3 + dy1 = y1 - y2 + dy2 = y3 - y2 + dy3 = y0 - y1 + y2 - y3 + matrix = np.zeros([9]) + matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / ( + normalized_width - 1) + matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / ( + normalized_height - 1) + matrix[8] = 1 + + matrix[3] = (y1 - y0 + matrix[6] * + (normalized_width - 1) * y1) / (normalized_width - 1) + matrix[4] = (y3 - y0 + matrix[7] * + (normalized_height - 1) * y3) / (normalized_height - 1) + matrix[5] = y0 + + matrix[0] = (x1 - x0 + matrix[6] * + (normalized_width - 1) * x1) / (normalized_width - 1) + matrix[1] = (x3 - x0 + matrix[7] * + (normalized_height - 1) * x3) / (normalized_height - 1) + matrix[2] = x0 + return matrix + + +def get_source_coords(matrix, out_w, out_h): + u = matrix[0] * out_w + matrix[1] * out_h + matrix[2] + v = matrix[3] * out_w + matrix[4] * out_h + matrix[5] + w = matrix[6] * out_w + matrix[7] * out_h + matrix[8] + in_w = u / w + in_h = v / w + return in_w, in_h + + +def bilinear_interpolate(in_data, in_n, in_c, in_w, in_h): + + batch_size = in_data.shape[0] + channels = in_data.shape[1] + height = in_data.shape[2] + width = in_data.shape[3] + + if gt(-0.5, in_w) or gt(in_w, width - 0.5) or gt(-0.5, in_h) or gt( + in_h, height - 0.5): + return 0.0 + + if gt(0, in_w): + in_w = 0 + if gt(0, in_h): + in_h = 0 + + in_w_floor = floor(in_w) + in_h_floor = floor(in_h) + + if gt_e(in_w_floor, width - 1): + in_w_ceil = width - 1 + in_w_floor = width - 1 + in_w = in_w_floor + else: + in_w_ceil = in_w_floor + 1 + + if gt_e(in_h_floor, height - 1): + in_h_ceil = height - 1 + in_h_floor = height - 1 + in_h = in_h_floor + else: + in_h_ceil = in_h_floor + 1 + + w_floor = in_w - in_w_floor + h_floor = in_h - in_h_floor + w_ceil = 1 - w_floor + h_ceil = 1 - h_floor + v1 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_floor)] + v2 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_floor)] + v3 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_ceil)] + v4 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_ceil)] + w1 = w_ceil * h_ceil + w2 = w_ceil * h_floor + w3 = w_floor * h_floor + w4 = w_floor * h_ceil + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + return val + + +def lod_convert(lod): + ret = [0] + for count in lod: + ret.append(ret[-1] + count) + return ret + + +def roi_transform(in_data, rois, rois_lod, transformed_height, + transformed_width, spatial_scale): + channels = in_data.shape[1] + in_height = in_data.shape[2] + in_width = in_data.shape[3] + rois_num = rois.shape[0] + + roi2image = [0] * rois_num + rois_lod = lod_convert(rois_lod[0]) + for i in range(len(rois_lod) - 1): + for j in range(rois_lod[i], rois_lod[i + 1]): + roi2image[j] = i + + out = np.zeros([rois_num, channels, transformed_height, transformed_width]) + + for n in range(rois_num): + roi_x = [] + roi_y = [] + for k in range(4): + roi_x.append(rois[n][2 * k] * spatial_scale) + roi_y.append(rois[n][2 * k + 1] * spatial_scale) + image_id = roi2image[n] + transform_matrix = get_transform_matrix( + transformed_width, transformed_height, roi_x, roi_y) + + for c in range(channels): + for out_h in range(transformed_height): + for out_w in range(transformed_width): + in_w, in_h = get_source_coords(transform_matrix, out_w, + out_h) + if in_quad(in_w, in_h, roi_x, roi_y) and gt_e( + in_w, -0.5) and lt_e(in_w, in_width - 0.5) and gt_e( + in_h, -0.5) and lt_e(in_h, in_height - 0.5): + out[n][c][out_h][out_w] = bilinear_interpolate( + in_data, image_id, c, in_w, in_h) + else: + out[n][c][out_h][out_w] = 0.0 + return out.astype("float32") + + +class TestROIPoolOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + + self.inputs = {'X': self.x, 'ROIs': (self.rois, self.rois_lod)} + + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'transformed_height': self.transformed_height, + 'transformed_width': self.transformed_width + } + out = roi_transform(self.x, self.rois, self.rois_lod, + self.transformed_height, self.transformed_width, + self.spatial_scale) + self.outputs = {'Out': out} + + def init_test_case(self): + self.batch_size = 2 + self.channels = 2 + self.height = 8 + self.width = 8 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, self.height, self.width) + + self.spatial_scale = 1.0 / 2.0 + self.transformed_height = 2 + self.transformed_width = 3 + + self.x = np.random.random(self.x_dim).astype('float32') + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(bno + 1) + for i in range(bno + 1): + x1 = np.random.randint( + 0, + self.width // self.spatial_scale - self.transformed_width) + y1 = np.random.randint( + 0, + self.height // self.spatial_scale - self.transformed_height) + + x2 = np.random.randint(x1 + self.transformed_width, + self.width // self.spatial_scale) + y2 = np.random.randint( + 0, + self.height // self.spatial_scale - self.transformed_height) + + x3 = np.random.randint(x1 + self.transformed_width, + self.width // self.spatial_scale) + y3 = np.random.randint(y1 + self.transformed_height, + self.height // self.spatial_scale) + + x4 = np.random.randint( + 0, + self.width // self.spatial_scale - self.transformed_width) + y4 = np.random.randint(y1 + self.transformed_height, + self.height // self.spatial_scale) + + roi = [x1, y1, x2, y2, x3, y3, x4, y4] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype("float32") + + def setUp(self): + self.op_type = "roi_perspective_transform" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main()