提交 7a8d4c47 编写于 作者: B Bluebird 提交者: GitHub

[LITE][OPENCL] add op bilinear_interp; fix output size error of op...

[LITE][OPENCL] add op bilinear_interp; fix output size error of op nearest_interp when scale == 0 (#3034)

* [LITE][OPENCL] add op bilinear_interp; fix output size error of op nearest_interp when scale == 0

* [LITE][OPENCL]fix codestyle
上级 66943610
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef BILINEAR_INTERP_OP
#include "operators/bilinear_interp_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
......@@ -49,6 +50,10 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(bilinear_interp, ops::BilinearOp);
#endif
#if PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(bilinear_interp, ops::BilinearOp)
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef BILINEAR_INTERP_OP
#include <operators/kernel/bilinear_interp_kernel.h>
namespace paddle_mobile {
namespace operators {
template <>
bool BilinearInterpKernel<GPU_CL, float>::Init(
paddle_mobile::operators::BilinearInterpParam<paddle_mobile::GPU_CL>
*param) {
this->cl_helper_.AddKernel("bilinear_interp", "bilinear_interp_kernel.cl");
return true;
}
template <>
void BilinearInterpKernel<GPU_CL, float>::Compute(
const paddle_mobile::operators::BilinearInterpParam<paddle_mobile::GPU_CL>
&param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Out()));
auto input = param.InputX();
cl_mem input_image = input->GetCLImage();
auto output = param.Out();
cl_mem output_image = output->GetCLImage();
float scale_h, scale_w;
if (param.AlignCorners()) {
scale_h = (input->dims()[2] - 1.0f) / (output->dims()[2] - 1.0f);
scale_w = (input->dims()[3] - 1.0f) / (output->dims()[3] - 1.0f);
} else {
scale_h = input->dims()[2] / static_cast<float> output->dims()[2];
scale_w = input->dims()[3] / static_cast<float> output->dims()[3];
}
float align_delta = 0.0f;
if (!param.AlignCorners() && param.AlignMode() == 0) {
align_delta = 0.5f;
}
int in_dims_h = input->dims()[2];
int out_dims_h = output->dims()[2];
int in_dims_w = input->dims()[3];
int out_dims_w = output->dims()[3];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 2, sizeof(float), &scale_h);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 3, sizeof(float), &scale_w);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 4, sizeof(int), &in_dims_h);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 5, sizeof(int), &out_dims_h);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 6, sizeof(int), &in_dims_w);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 7, sizeof(int), &out_dims_w);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 8, sizeof(float), &align_delta);
CL_CHECK_ERRORS(status)
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status)
}
template class BilinearInterpKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void bilinear_interp(__read_only image2d_t input, __write_only image2d_t output,
__private const float scale_h, __private const float scale_w,
__private const int in_dims_h, __private const int out_dims_h,
__private const int in_dims_w, __private const int out_dims_w,
__private const float align_delta) {
const int c = get_global_id(0);
const int w = get_global_id(1);
const int nh = get_global_id(2);
int2 output_pos;
output_pos.x = c * out_dims_w + w;
output_pos.y = nh;
// calculate center pixel's pos
int out_n = nh / out_dims_h;
int out_h = nh % out_dims_h;
float center_w = (w + align_delta) * scale_w - align_delta;
float center_h = (out_h + align_delta) * scale_h - align_delta;
int floor_w = (int)center_w;
int floor_h = (int)center_h;
int ceil_w = floor_w + 1;
int ceil_h = floor_h + 1;
if (ceil_w > in_dims_w) {
ceil_w = floor_w;
}
if (ceil_h > in_dims_h) {
ceil_h = floor_h;
}
float wight0_w = center_w - floor_w;
float wight0_h = center_h - floor_h;
float wight1_w = 1.0 - wight0_w;
float wight1_h = 1.0 - wight0_h;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// get left up pixel data
int2 left_up;
left_up.x = c * in_dims_w + floor_w;
left_up.y = out_n * in_dims_h + ceil_h;
half4 left_up_data = read_imageh(input, sampler, left_up);
// get left down pixel data
int2 left_down;
left_down.x = c * in_dims_w + floor_w;
left_down.y = out_n * in_dims_h + floor_h;
half4 left_down_data = read_imageh(input, sampler, left_down);
// get right up pixel data
int2 right_up;
right_up.x = c * in_dims_w + ceil_w;
right_up.y = out_n * in_dims_h + ceil_h;
half4 right_up_data = read_imageh(input, sampler, right_up);
// get right down pixel's data
int2 right_down;
right_down.x = c * in_dims_w + ceil_w;
right_down.y = out_n * in_dims_h + floor_h;
half4 right_down_data = read_imageh(input, sampler, right_down);
// calculate output data
half4 data = (left_down_data * wight1_w + right_down_data * wight0_w) * wight1_h
+ (left_up_data * wight1_w + right_up_data * wight0_w) * wight0_h;
write_imageh(output, output_pos, data);
}
\ No newline at end of file
......@@ -27,8 +27,12 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
auto dim_x = this->param_.InputX()->dims(); // NCHW format
DLOG << "dim_x :" << dim_x;
bool ignore_scale = false;
int out_h = this->param_.OutH();
int out_w = this->param_.OutW();
if (out_h > 0 && out_w > 0) {
ignore_scale = true;
}
PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4");
if (this->param_.InputOutPutSize() != nullptr) {
......@@ -40,7 +44,7 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
}
DLOG << "this->param_.HasScale(): " << this->param_.HasScale();
if (this->param_.HasScale()) {
if (this->param_.HasScale() && !ignore_scale) {
const float scale = this->param_.Scale();
DLOG << "scale_: " << scale;
std::vector<int64_t> dim_out({dim_x[0], dim_x[1],
......
......@@ -3081,12 +3081,16 @@ class BilinearInterpParam : public OpParam {
out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs);
align_corners = GetAttr<bool>("align_corners", attrs);
align_mode = GetAttr<int>("align_mode", attrs);
}
const GType *InputX() const { return input_x_; }
const GType *InputOutPutSize() const { return input_outsize_; }
GType *Out() const { return out_; }
int OutH() const { return out_h_; }
int OutW() const { return out_w_; }
bool AlignCorners() const { return align_corners; }
int AlignMode() const { return align_mode; }
private:
GType *input_x_;
......@@ -3094,6 +3098,8 @@ class BilinearInterpParam : public OpParam {
GType *out_;
int out_h_;
int out_w_;
bool align_corners;
int align_mode;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册