From 68183f92afc2dca63523dd6e9ebfcf7b66689c0f Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Thu, 19 Dec 2019 23:07:38 +0800 Subject: [PATCH] suite scale in nearest_interp_op ,test=mobile (#2633) * suite scale in nearest_interp_op ,test=mobile * suite scale in nearest_interp_op && fix grid_sampler_kernel.cl ,test=mobile * suite scale in nearest_interp_op && fix grid_sampler_kernel.cl && fix code style ,test=mobile --- .../cl/cl_kernel/grid_sampler_kernel.cl | 44 +++++++++---------- mobile/src/operators/nearest_interp_op.cpp | 21 +++++++-- mobile/src/operators/op_param.h | 12 ++++- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl index e366316e43..0512ce9bea 100644 --- a/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl +++ b/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl @@ -28,8 +28,8 @@ __kernel void grid_sampler(__private const int out_height, CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; int x_grid = out_h / 4 * 2; int y_grid = out_n * out_width + out_w; - float4 g1 = read_imagef(grid, sampler, int2(x_grid, y_grid)); - float4 g2 = read_imagef(grid, sampler, int2(x_grid + 1, y_grid)); + float4 g1 = read_imagef(grid, sampler, (int2)(x_grid, y_grid)); + float4 g2 = read_imagef(grid, sampler, (int2)(x_grid + 1, y_grid)); float x = (g1.x + 1) * (out_width - 1) / 2; float y = (g2.x + 1) * (out_height - 1) / 2; @@ -39,15 +39,15 @@ __kernel void grid_sampler(__private const int out_height, int y_p = out_n * out_height + y0; int x_out = out_c * out_width + out_w; int y_out = out_n * out_height + out_h; - float4 input0 = read_imagef(input, sampler, int2(x_p, y_p)); - float4 input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); - float4 input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); - float4 input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + float4 input0 = read_imagef(input, sampler, (int2)(x_p, y_p)); + float4 input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p)); + float4 input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1)); + float4 input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1)); float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) + input2 * (x0 + 1 - x) * (y - y0) + input3 * (x - x0) * (y - y0); - write_imageh(output, int2(x_out, y_out), convert_half4(out_val)); + write_imageh(output, (int2)(x_out, y_out), convert_half4(out_val)); x = (g1.y + 1) * (out_width - 1) / 2; y = (g2.y + 1) * (out_height - 1) / 2; @@ -55,15 +55,15 @@ __kernel void grid_sampler(__private const int out_height, y0 = floor(y); x_p = out_c * out_width + x0; y_p = out_n * out_height + y0; - input0 = read_imagef(input, sampler, int2(x_p, y_p)); - input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); - input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); - input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + input0 = read_imagef(input, sampler, (int2)(x_p, y_p)); + input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1)); out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) + input2 * (x0 + 1 - x) * (y - y0) + input3 * (x - x0) * (y - y0); - write_imageh(output, int2(x_out, y_out + 1), convert_half4(out_val)); + write_imageh(output, (int2)(x_out, y_out + 1), convert_half4(out_val)); x = (g1.z + 1) * (out_width - 1) / 2; y = (g2.z + 1) * (out_height - 1) / 2; @@ -71,15 +71,15 @@ __kernel void grid_sampler(__private const int out_height, y0 = floor(y); x_p = out_c * out_width + x0; y_p = out_n * out_height + y0; - input0 = read_imagef(input, sampler, int2(x_p, y_p)); - input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); - input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); - input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + input0 = read_imagef(input, sampler, (int2)(x_p, y_p)); + input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1)); out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) + input2 * (x0 + 1 - x) * (y - y0) + input3 * (x - x0) * (y - y0); - write_imageh(output, int2(x_out, y_out + 2), convert_half4(out_val)); + write_imageh(output, (int2)(x_out, y_out + 2), convert_half4(out_val)); x = (g1.w + 1) * (out_width - 1) / 2; y = (g2.w + 1) * (out_height - 1) / 2; @@ -87,13 +87,13 @@ __kernel void grid_sampler(__private const int out_height, y0 = floor(y); x_p = out_c * out_width + x0; y_p = out_n * out_height + y0; - input0 = read_imagef(input, sampler, int2(x_p, y_p)); - input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); - input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); - input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + input0 = read_imagef(input, sampler, (int2)(x_p, y_p)); + input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1)); out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) + input2 * (x0 + 1 - x) * (y - y0) + input3 * (x - x0) * (y - y0); - write_imageh(output, int2(x_out, y_out + 3), convert_half4(out_val)); + write_imageh(output, (int2)(x_out, y_out + 3), convert_half4(out_val)); } diff --git a/mobile/src/operators/nearest_interp_op.cpp b/mobile/src/operators/nearest_interp_op.cpp index 14e71b78f1..e885ea26ad 100644 --- a/mobile/src/operators/nearest_interp_op.cpp +++ b/mobile/src/operators/nearest_interp_op.cpp @@ -24,8 +24,9 @@ void NearestInterpolationOp::InferShape() const { "Input(X) of BilinearInterOp should not be null."); PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr, "Output(Out) of BilinearInterOp should not be null."); - auto dim_x = this->param_.InputX()->dims(); // NCHW format + DLOG << "dim_x :" << dim_x; + int out_h = this->param_.OutH(); int out_w = this->param_.OutW(); PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4"); @@ -37,8 +38,22 @@ void NearestInterpolationOp::InferShape() const { "OutSize's dimension size must be 1"); PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2"); } - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - this->param_.Out()->Resize(framework::make_ddim(dim_out)); + + DLOG << "this->param_.HasScale(): " << this->param_.HasScale(); + if (this->param_.HasScale()) { + const float scale = this->param_.Scale(); + DLOG << "scale_: " << scale; + std::vector dim_out({dim_x[0], dim_x[1], + static_cast(dim_x[2] * scale), + static_cast(dim_x[3] * scale)}); + this->param_.Out()->Resize(framework::make_ddim(dim_out)); + DLOG << "interp -- dim_out: " << dim_out; + + } else { + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + this->param_.Out()->Resize(framework::make_ddim(dim_out)); + DLOG << "interp -- dim_out: " << dim_out; + } } } // namespace operators diff --git a/mobile/src/operators/op_param.h b/mobile/src/operators/op_param.h index 1224ef0693..0415291a73 100644 --- a/mobile/src/operators/op_param.h +++ b/mobile/src/operators/op_param.h @@ -3042,7 +3042,7 @@ class SplitParam : public OpParam { int axis; int num; std::vector sections; - // std::vector out_ts_; +// std::vector out_ts_; #ifdef PADDLE_MOBILE_FPGA private: @@ -3103,12 +3103,20 @@ class NearestInterpolationParam : public OpParam { out_ = OutFrom(outputs, *scope); out_h_ = GetAttr("out_h", attrs); out_w_ = GetAttr("out_w", attrs); + if (HasAttr("scale", attrs)) { + has_scale_ = true; + scale_ = GetAttr("scale", attrs); + } + DLOG << "has_scale_: " << has_scale_; + DLOG << "scale_: " << scale_; } const GType *InputX() const { return input_x_; } const GType *InputOutPutSize() const { return input_outsize_; } GType *Out() const { return out_; } int OutH() const { return out_h_; } int OutW() const { return out_w_; } + float Scale() const { return scale_; } + bool HasScale() const { return has_scale_; } private: GType *input_x_; @@ -3116,6 +3124,8 @@ class NearestInterpolationParam : public OpParam { GType *out_; int out_h_; int out_w_; + float scale_; + bool has_scale_; }; #endif -- GitLab