提交 489a18b1 编写于 作者: N NazgulLee 提交者: xiebaiyuan

add grid_sampler op, test=develop (#2632)

上级 0aa8881a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRID_SAMPLER_OP
#include "operators/grid_sampler_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void GridSamplerOp<Dtype, T>::InferShape() const {
auto x_dim = this->param_.InputX()->dims();
this->param_.Output()->Resize(x_dim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(grid_sampler, ops::GridSamplerOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRID_SAMPLER_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/grid_sampler_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef GRID_SAMPLER_OP
DECLARE_OPERATOR(GridSampler, GridSamplerParam, GridSamplerKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cl_common.h"
__kernel void grid_sampler(__private const int out_height,
__private const int out_width,
__read_only image2d_t input,
__read_only image2d_t grid,
__write_only image2d_t output) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2) * 4;
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x_grid = out_h / 4 * 2;
int y_grid = out_n * out_width + out_w;
float4 g1 = read_imagef(grid, sampler, int2(x_grid, y_grid));
float4 g2 = read_imagef(grid, sampler, int2(x_grid + 1, y_grid));
float x = (g1.x + 1) * (out_width - 1) / 2;
float y = (g2.x + 1) * (out_height - 1) / 2;
float x0 = floor(x);
float y0 = floor(y);
int x_p = out_c * out_width + x0;
int y_p = out_n * out_height + y0;
int x_out = out_c * out_width + out_w;
int y_out = out_n * out_height + out_h;
float4 input0 = read_imagef(input, sampler, int2(x_p, y_p));
float4 input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
float4 input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
float4 input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out), convert_half4(out_val));
x = (g1.y + 1) * (out_width - 1) / 2;
y = (g2.y + 1) * (out_height - 1) / 2;
x0 = floor(x);
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 1), convert_half4(out_val));
x = (g1.z + 1) * (out_width - 1) / 2;
y = (g2.z + 1) * (out_height - 1) / 2;
x0 = floor(x);
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 2), convert_half4(out_val));
x = (g1.w + 1) * (out_width - 1) / 2;
y = (g2.w + 1) * (out_height - 1) / 2;
x0 = floor(x);
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 3), convert_half4(out_val));
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRID_SAMPLER_OP
#include "operators/kernel/grid_sampler_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool GridSamplerKernel<GPU_CL, float>::Init(GridSamplerParam<GPU_CL>* param) {
this->cl_helper_.AddKernel("grid_sampler", "grid_sampler_kernel.cl");
return true;
}
template <>
void GridSamplerKernel<GPU_CL, float>::Compute(
const GridSamplerParam<GPU_CL>& param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Output()));
cl_int status;
auto output = param.Output();
auto input = param.InputX();
auto grid = param.Grid();
auto output_image = output->GetCLImage();
auto input_image = input->GetCLImage();
auto grid_image = grid->GetCLImage();
const int out_H = output->dims()[2];
const int out_W = output->dims()[3];
status = clSetKernelArg(kernel, 0, sizeof(cl_int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &grid_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
const size_t work_size[3] = {default_work_size[0], default_work_size[1],
default_work_size[2] / 4};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3,
NULL, work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class GridSamplerKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef GRID_SAMPLER_OP
DECLARE_KERNEL(GridSampler, GridSamplerParam);
#endif // GRID_SAMPLER_OP
} // namespace operators
} // namespace paddle_mobile
...@@ -337,6 +337,11 @@ class OpParam { ...@@ -337,6 +337,11 @@ class OpParam {
return GetVarValue<T>("Filter", inputs, scope); return GetVarValue<T>("Filter", inputs, scope);
} }
template <typename T>
static T *GridFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Grid", inputs, scope);
}
template <typename T> template <typename T>
static const T GetAttr(const string &key, const AttributeMap &map) { static const T GetAttr(const string &key, const AttributeMap &map) {
return ((Attribute)map.at(key)).Get<T>(); return ((Attribute)map.at(key)).Get<T>();
...@@ -3687,7 +3692,6 @@ class PixelShuffleParam : public OpParam { ...@@ -3687,7 +3692,6 @@ class PixelShuffleParam : public OpParam {
}; };
#endif #endif
#ifdef GRID_SAMPLER_OP #ifdef GRID_SAMPLER_OP
template <typename Dtype> template <typename Dtype>
class GridSamplerParam : public OpParam { class GridSamplerParam : public OpParam {
...@@ -3700,15 +3704,18 @@ class GridSamplerParam : public OpParam { ...@@ -3700,15 +3704,18 @@ class GridSamplerParam : public OpParam {
Scope *scope) Scope *scope)
: OpParam(inputs, outputs, attrs, scope) { : OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope); input_x_ = InputXFrom<GType>(inputs, *scope);
grid_ = GridFrom<GType>(inputs, *scope);
output_ = OutputFrom<GType>(outputs, *scope); output_ = OutputFrom<GType>(outputs, *scope);
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
const GType *Grid() const { return grid_; }
GType *Output() const { return output_; } GType *Output() const { return output_; }
private: private:
GType *input_x_; GType *input_x_;
GType *grid_;
GType *output_; GType *output_;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册