未验证 提交 8ae48e71 编写于 作者: H HappyAngel 提交者: GitHub

[OpenCL]add box coder op (#3347)


* add boxcoder opencl kernel, test=develop

* fix format, test=develop

* fix , test=develop
上级 77734ce7
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void decode_center_size(__read_only image2d_t prior_box_image,
__read_only image2d_t prior_box_var_image,
__read_only image2d_t target_box_image,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_H){
const int out_c = get_global_id(0);
const int out_nh = get_global_id(1);
const int out_h = out_nh % out_H;
const int out_n = 1;
const int prior_box_n = 1;
const int prior_box_c = 0;
const int prior_box_h = out_h;
const int prior_box_var_n = 1;
const int prior_box_var_c = 0;
const int prior_box_var_h = out_h;
const int target_box_n = 1;
const int target_box_c = out_c;
const int target_box_h = out_h;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 prior_box_pos;
int2 prior_box_var_pos;
int2 target_box_pos;
int2 output_pos;
prior_box_pos.x = prior_box_c * 4;
prior_box_pos.y = prior_box_n * prior_box_h;
prior_box_var_pos.x = prior_box_var_c * 4;
prior_box_var_pos.y = prior_box_var_n * prior_box_var_h;
target_box_pos.x = target_box_c * 4;
target_box_pos.y = target_box_n * target_box_h;
output_pos.x = out_c * 4;
output_pos.y = out_n * out_h;
CL_DTYPE4 prior_box_input[4];
CL_DTYPE4 prior_box_var_input[4];
CL_DTYPE4 target_box_input[4];
prior_box_input[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_image, sampler,
(int2)(prior_box_pos.x + 0, prior_box_pos.y));
prior_box_input[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_image, sampler,
(int2)(prior_box_pos.x + 1, prior_box_pos.y));
prior_box_input[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_image, sampler,
(int2)(prior_box_pos.x + 2, prior_box_pos.y));
prior_box_input[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_image, sampler,
(int2)(prior_box_pos.x + 3, prior_box_pos.y));
prior_box_var_input[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_var_image, sampler,
(int2)(prior_box_var_pos.x + 0, prior_box_var_pos.y));
prior_box_var_input[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_var_image, sampler,
(int2)(prior_box_var_pos.x + 1, prior_box_var_pos.y));
prior_box_var_input[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_var_image, sampler,
(int2)(prior_box_var_pos.x + 2, prior_box_var_pos.y));
prior_box_var_input[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, prior_box_var_image, sampler,
(int2)(prior_box_var_pos.x + 3, prior_box_var_pos.y));
target_box_input[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, target_box_image, sampler,
(int2)(target_box_pos.x + 0,target_box_pos.y));
target_box_input[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, target_box_image, sampler,
(int2)(target_box_pos.x + 1, target_box_pos.y));
target_box_input[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, target_box_image, sampler,
(int2)(target_box_pos.x + 2, target_box_pos.y));
target_box_input[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, target_box_image, sampler,
(int2)(target_box_pos.x + 3, target_box_pos.y));
CL_DTYPE prior_box_width = prior_box_input[2].x - prior_box_input[0].x;
CL_DTYPE prior_box_height = prior_box_input[3].x - prior_box_input[1].x;
CL_DTYPE prior_box_center_x = (prior_box_input[2].x + prior_box_input[0].x)/(CL_DTYPE)2;
CL_DTYPE prior_box_center_y = (prior_box_input[3].x + prior_box_input[1].x)/(CL_DTYPE)2;
CL_DTYPE4 target_box_center_x;
CL_DTYPE4 target_box_center_y;
CL_DTYPE4 target_box_width;
CL_DTYPE4 target_box_height;
CL_DTYPE4 output[4];
output[0] = 0.0f;
output[1] = 0.0f;
output[2] = 0.0f;
output[3] = 0.0f;
target_box_center_x.x = prior_box_var_input[0].x * target_box_input[0].x * prior_box_width + prior_box_center_x;
target_box_center_y.x = prior_box_var_input[1].x * target_box_input[1].x * prior_box_height + prior_box_center_y;
target_box_width.x = exp(prior_box_var_input[2].x * target_box_input[2].x) * prior_box_width;
target_box_height.x = exp(prior_box_var_input[3].x * target_box_input[3].x) * prior_box_height;
output[0].x = target_box_center_x.x - target_box_width.x/(half)2;
output[1].x = target_box_center_y.x - target_box_height.x/(half)2;
output[2].x = target_box_center_x.x + target_box_width.x/(half)2;
output[3].x = target_box_center_y.x + target_box_height.x/(half)2;
if(out_C - out_c * 4 >= 2){
target_box_center_x.y = prior_box_var_input[0].x * target_box_input[0].y * prior_box_width + prior_box_center_x;
target_box_center_y.y = prior_box_var_input[1].x * target_box_input[1].y * prior_box_height + prior_box_center_y;
target_box_width.y = exp(prior_box_var_input[2].x * target_box_input[2].y) * prior_box_width;
target_box_height.y = exp(prior_box_var_input[3].x * target_box_input[3].y) * prior_box_height;
output[0].y = target_box_center_x.y - target_box_width.y/(half)2;
output[1].y = target_box_center_y.y - target_box_height.y/(half)2;
output[2].y = target_box_center_x.y + target_box_width.y/(half)2;
output[3].y = target_box_center_y.y + target_box_height.y/(half)2;
}
if(out_C - out_c * 4 >= 3){
target_box_center_x.z = prior_box_var_input[0].x * target_box_input[0].z * prior_box_width + prior_box_center_x;
target_box_center_y.z = prior_box_var_input[1].x * target_box_input[1].z * prior_box_height + prior_box_center_y;
target_box_width.z = exp(prior_box_var_input[2].x * target_box_input[2].z) * prior_box_width;
target_box_height.z = exp(prior_box_var_input[3].x * target_box_input[3].z) * prior_box_height;
output[0].z = target_box_center_x.z - target_box_width.z/(half)2;
output[1].z = target_box_center_y.z - target_box_height.z/(half)2;
output[2].z = target_box_center_x.z + target_box_width.z/(half)2;
output[3].z = target_box_center_y.z + target_box_height.z/(half)2;
}
if(out_C - out_c * 4 >= 4){
target_box_center_x.w = prior_box_var_input[0].x * target_box_input[0].w * prior_box_width + prior_box_center_x;
target_box_center_y.w = prior_box_var_input[1].x * target_box_input[1].w * prior_box_height + prior_box_center_y;
target_box_width.w = exp(prior_box_var_input[2].x * target_box_input[2].w) * prior_box_width;
target_box_height.w = exp(prior_box_var_input[3].x * target_box_input[3].w) * prior_box_height;
output[0].w = target_box_center_x.w - target_box_width.w/(half)2;
output[1].w = target_box_center_y.w - target_box_height.w/(half)2;
output[2].w = target_box_center_x.w + target_box_width.w/(half)2;
output[3].w = target_box_center_y.w + target_box_height.w/(half)2;
}
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(output_pos.x + 0, output_pos.y), output[0]);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(output_pos.x + 1, output_pos.y), output[1]);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(output_pos.x + 2, output_pos.y), output[2]);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(output_pos.x + 3, output_pos.y), output[3]);
}
......@@ -33,7 +33,7 @@ add_kernel(slice_opencl OPENCL basic SRCS slice_image_compute.cc DEPS ${cl_kerne
add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(dropout_opencl OPENCL basic SRCS dropout_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(box_coder_opencl OPENCL basic SRCS box_coder_image_compute.cc DEPS ${cl_kernel_deps})
# extra
# wait to add ...
......@@ -97,6 +97,10 @@ lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc
lite_cc_test(test_pad2d_image_opencl SRCS pad2d_image_compute_test.cc
DEPS pad2d_opencl layout_opencl op_registry program context)
lite_cc_test(test_box_coder_image_opencl SRCS box_coder_image_compute_test.cc
DEPS box_coder_opencl op_registry program context)
######################
# buffer kernel #
######################
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::BoxCoderParam;
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
boxcoder_param_ = param_.get_mutable<param_t>();
if (boxcoder_param_->code_type == "decode_center_size" &&
boxcoder_param_->box_normalized == true) {
kernel_func_name_ = "decode_center_size";
} else {
printf("This code_type %s doesn't support \n",
boxcoder_param_->code_type.c_str());
return;
}
CHECK(context.cl_context() != nullptr);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(
kernel_func_name_, "image/box_coder_kernel.cl", build_options_);
}
void Run() override {
boxcoder_param_ = param_.get_mutable<param_t>();
const auto& out_dims = boxcoder_param_->proposals->dims();
auto image_shape = InitImageDimInfoWith(out_dims);
auto* out_buf =
boxcoder_param_->proposals->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "boxcoder input shape: ";
#endif
const auto* input_priorbox = boxcoder_param_->prior_box;
const auto* input_priorboxvar = boxcoder_param_->prior_box_var;
const auto* input_targetbox = boxcoder_param_->target_box;
const auto& code_type = boxcoder_param_->code_type;
if (code_type == "decode_center_size") {
auto* prior_box_image = input_priorbox->data<half_t, cl::Image2D>();
auto* prior_box_var_image =
input_priorboxvar->data<half_t, cl::Image2D>();
auto* target_box_image = input_targetbox->data<half_t, cl::Image2D>();
int new_dims[4] = {1, 1, 1, 1};
for (int i = 0; i < out_dims.size(); i++) {
new_dims[4 - out_dims.size() + i] = out_dims[i];
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
auto default_work_size =
DefaultWorkSize(out_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(image_shape["width"]),
static_cast<int64_t>(image_shape["height"])}));
int out_C = new_dims[1];
int out_H = new_dims[2];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << TargetToStr(boxcoder_param_->proposals->target());
VLOG(4) << "output shape: " << out_dims[0] << ", " << out_dims[1] << ", "
<< out_dims[2] << ", " << out_dims[3];
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "out_C = " << out_C;
VLOG(4) << "out_H = " << out_H;
VLOG(4) << "default_work_size = " << default_work_size[0] << ", "
<< default_work_size[1] << ", " << default_work_size[2];
#endif
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx++, *prior_box_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *prior_box_var_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *target_box_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_C);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_H);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
<< global_work_size[1];
#endif
}
}
std::string doc() { return "Boxcoder using cl::Image, kFP16"; }
param_t* boxcoder_param_{nullptr};
std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::opencl::BoxCoderComputeImage BoxCoder_image;
REGISTER_LITE_KERNEL(
box_coder, kOpenCL, kFP16, kImageDefault, BoxCoder_image, ImageDefault)
.BindInput("PriorBox",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("PriorBoxVar",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("TargetBox",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("OutputBox",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
void box_coder_ref(float* proposals_data,
const float* anchors_data,
const float* bbox_deltas_data,
const float* variances_data,
int axis,
bool box_normalized,
std::string code_type,
int row,
int col) {
if (code_type == "decode_center_size") {
int anchor_len = 4;
int out_len = 4;
int var_len = 4;
int delta_len = 4;
float normalized = !box_normalized ? 1.f : 0;
for (int64_t row_id = 0; row_id < row; ++row_id) {
for (int64_t col_id = 0; col_id < col; ++col_id) {
size_t delta_offset = row_id * col * delta_len + col_id * delta_len;
size_t out_offset = row_id * col * out_len + col_id * out_len;
int prior_box_offset =
axis == 0 ? col_id * anchor_len : row_id * anchor_len;
int var_offset = axis == 0 ? col_id * var_len : row_id * var_len;
auto anchor_data_tmp = anchors_data + prior_box_offset;
auto bbox_deltas_data_tmp = bbox_deltas_data + delta_offset;
auto proposals_data_tmp = proposals_data + out_offset;
auto anchor_width =
anchor_data_tmp[2] - anchor_data_tmp[0] + normalized;
auto anchor_height =
anchor_data_tmp[3] - anchor_data_tmp[1] + normalized;
auto anchor_center_x = anchor_data_tmp[0] + 0.5 * anchor_width;
auto anchor_center_y = anchor_data_tmp[1] + 0.5 * anchor_height;
float bbox_center_x = 0, bbox_center_y = 0;
float bbox_width = 0, bbox_height = 0;
auto variances_data_tmp = variances_data + var_offset;
bbox_center_x =
variances_data_tmp[0] * bbox_deltas_data_tmp[0] * anchor_width +
anchor_center_x;
bbox_center_y =
variances_data_tmp[1] * bbox_deltas_data_tmp[1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(variances_data_tmp[2] * bbox_deltas_data_tmp[2]) *
anchor_width;
bbox_height =
std::exp(variances_data_tmp[3] * bbox_deltas_data_tmp[3]) *
anchor_height;
proposals_data_tmp[0] = bbox_center_x - bbox_width / 2;
proposals_data_tmp[1] = bbox_center_y - bbox_height / 2;
proposals_data_tmp[2] = bbox_center_x + bbox_width / 2 - normalized;
proposals_data_tmp[3] = bbox_center_y + bbox_height / 2 - normalized;
}
}
} else if (code_type == "encode_center_size") {
LOG(FATAL) << "not implemented type: " << code_type;
} else {
LOG(FATAL) << "not supported type: " << code_type;
}
}
// #define BOXCODER_FP16_LOOP_TEST
// #define BOXCODER_FP16_PRINT_RESULT
TEST(box_coder_image2d, compute) {
#ifdef BOXCODER_FP16_LOOP_TEST
for (auto n : {1, 2, 3, 4}) {
for (auto m : {1, 3, 4, 8}) {
for (auto norm : {true}) {
for (auto code_type : {"decode_center_size"}) {
for (auto axis : {0}) {
#else
const int n = 1;
const int m = 1;
const bool norm = true;
const std::string code_type = "decode_center_size";
const int axis = 0;
#endif // BOXCODER_FP16_LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << m
<< " ========";
LOG(INFO) << "======== parameters: norm = " << norm
<< ", axis = " << axis << "code_type: " << code_type;
auto kernels =
KernelRegistry::Global().Create("box_coder",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel:" << kernel->doc();
lite::Tensor prior_box, prior_box_var, target_box, output_box;
operators::BoxCoderParam param;
param.prior_box = &prior_box;
param.prior_box_var = &prior_box_var;
param.target_box = &target_box;
param.proposals = &output_box;
param.axis = axis;
param.box_normalized = norm;
param.code_type = code_type;
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
kernel->SetParam(param);
std::unique_ptr<KernelContext> boxcoder_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(boxcoder_context->As<OpenCLContext>()));
kernel->SetContext(std::move(boxcoder_context));
const DDim prior_box_dims =
DDim(std::vector<DDim::value_type>{1, 1, m, 4});
const DDim prior_box_var_dims =
DDim(std::vector<DDim::value_type>{1, 1, m, 4});
const DDim target_box_dims =
DDim(std::vector<DDim::value_type>{1, n, m, 4});
const DDim out_dim =
DDim(std::vector<DDim::value_type>{1, n, m, 4});
prior_box.Resize(prior_box_dims);
prior_box_var.Resize(prior_box_var_dims);
target_box.Resize(target_box_dims);
output_box.Resize(out_dim);
std::vector<float> prior_box_data(prior_box_dims.production());
std::vector<float> prior_box_var_data(
prior_box_var_dims.production());
std::vector<float> target_box_data(target_box_dims.production());
for (int i = 0; i < prior_box_dims.production(); i++) {
prior_box_data[i] = i * 1.1 / prior_box_dims.production();
}
for (int i = 0; i < prior_box_var_dims.production(); i++) {
prior_box_var_data[i] = i * 1.2 / prior_box_var_dims.production();
}
for (int i = 0; i < target_box_dims.production(); i++) {
target_box_data[i] = i * 1.3 / target_box_dims.production();
}
LOG(INFO) << "prepare input";
CLImageConverterDefault* default_converter =
new CLImageConverterDefault();
DDim prior_box_image_shape =
default_converter->InitImageDimInfoWith(prior_box_dims);
LOG(INFO) << "prior_box_image_shape = " << prior_box_image_shape[0]
<< " " << prior_box_image_shape[1];
std::vector<half_t> prior_box_image_data(
prior_box_image_shape.production() * 4); // 4 : RGBA
default_converter->NCHWToImage(prior_box_data.data(),
prior_box_image_data.data(),
prior_box_dims);
auto* prior_box_image = prior_box.mutable_data<half_t, cl::Image2D>(
prior_box_image_shape[0],
prior_box_image_shape[1],
prior_box_image_data.data());
DDim prior_box_var_image_shape =
default_converter->InitImageDimInfoWith(prior_box_var_dims);
LOG(INFO) << "prior_box_var_image_shape = "
<< prior_box_var_image_shape[0] << " "
<< prior_box_var_image_shape[1];
std::vector<half_t> prior_box_var_image_data(
prior_box_var_image_shape.production() * 4); // 4 : RGBA
default_converter->NCHWToImage(prior_box_var_data.data(),
prior_box_var_image_data.data(),
prior_box_var_dims);
auto* prior_box_var_image =
prior_box_var.mutable_data<half_t, cl::Image2D>(
prior_box_var_image_shape[0],
prior_box_var_image_shape[1],
prior_box_var_image_data.data());
DDim target_box_image_shape =
default_converter->InitImageDimInfoWith(target_box_dims);
LOG(INFO) << "target_box_image_shape = "
<< target_box_image_shape[0] << " "
<< target_box_image_shape[1];
std::vector<half_t> target_box_image_data(
target_box_image_shape.production() * 4); // 4 : RGBA
default_converter->NCHWToImage(target_box_data.data(),
target_box_image_data.data(),
target_box_dims);
auto* target_box_image =
target_box.mutable_data<half_t, cl::Image2D>(
target_box_image_shape[0],
target_box_image_shape[1],
target_box_image_data.data());
DDim out_image_shape =
default_converter->InitImageDimInfoWith(out_dim);
LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
<< out_image_shape[1];
auto* out_image = output_box.mutable_data<half_t, cl::Image2D>(
out_image_shape[0], out_image_shape[1]);
kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.proposals->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the "
"target cl tensor.";
}
lite::Tensor out_ref_tensor;
out_ref_tensor.Resize(out_dim);
box_coder_ref(out_ref_tensor.mutable_data<float>(),
prior_box_data.data(),
target_box_data.data(),
prior_box_var_data.data(),
axis,
norm,
code_type,
target_box_dims[0],
target_box_dims[1]);
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
half_t* out_image_data =
new half_t[40000]; // [out_image_shape.production() * 4];
TargetWrapperCL::ImgcpySync(out_image_data,
out_image,
out_image_shape[0],
out_image_shape[1],
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
float* out_data = new float[out_image_shape.production() * 4];
default_converter->ImageToNCHW(
out_image_data, out_data, out_image_shape, out_dim);
// result
#ifdef BOXCODER_FP16_PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < out_dim.production(); ++eidx) {
std::cout << target_box_data[eidx] << " -> " << out_data[eidx]
<< std::endl;
}
#endif // BOXCODER_FP16_PRINT_RESULT
const float* out_ref = out_ref_tensor.data<float>();
for (int i = 0; i < out_dim.production(); i++) {
auto abs_diff = abs(out_data[i] - out_ref[i]);
auto relative_diff =
COMPUTE_RELATIVE_DIFF(out_data[i], out_ref[i]);
EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) ||
(abs_diff <= FP16_MAX_DIFF),
true);
if ((relative_diff > FP16_MAX_DIFF) &&
(abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR) << "error idx:" << i << ", in_data[" << i
<< "]: " << target_box_data[i] << ", out_data[" << i
<< "]: " << out_data[i] << ", out_ref[" << i
<< "]: " << out_ref[i] << ", abs_diff: " << abs_diff
<< ", relative_diff: " << relative_diff
<< ", FP16_MAX_DIFF: " << FP16_MAX_DIFF;
}
}
#ifdef BOXCODER_FP16_LOOP_TEST
} // axis
} // code_type
} // norm
} // m
} // n
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(box_coder, kOpenCL, kFP16, kImageDefault, ImageDefault);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册