未验证 提交 fe127825 编写于 作者: H HappyAngel 提交者: GitHub

[LITE][OPENCL] Add concat kernel (#2841)

add concat ut, test=develop

* fix axis compute, test=develop

* add other axis, test=develop

* fix ut. test=develop
上级 3d0b463a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void concat2(__global const CL_DTYPE* x_data0, __global const CL_DTYPE* x_data1, __global CL_DTYPE* out_data,
int size, int axis_size, int pre_size, int post_size, int total, int total0, int total1) {
const int index = get_global_id(0);
if (index < size){
for (int i = 0; i < pre_size; i++){
int offset_out = index * post_size + i * total;
int offset_in = index * post_size + i * total0;
// memcpy(out_data + offset_out, x_data0 + offset_in, post_size);
CL_DTYPE* dst = out_data + offset_out;
CL_DTYPE* src = x_data0 + offset_in;
for (int k = 0; k < post_size; k++){
*dst++ = *src++;
}
}
}else if (index < axis_size){
for (int i = 0; i < pre_size; i++){
int offset_out = index * post_size + i * total;
int offset_in = index * post_size + i * total1;
// memcpy(out_data + offset_out, x_data1 + offset_in, post_size);
CL_DTYPE* dst = out_data + offset_out;
CL_DTYPE* src = x_data1 + offset_in;
for (int k = 0; k < post_size; k++){
*dst++ = *src++;
}
}
}
}
__kernel void concat_mul(__global const CL_DTYPE* x_data, __global CL_DTYPE* out_data,
int axis_size, int pre_size, int post_size, int start, int total, int total0) {
const int index = get_global_id(0);
if (index < axis_size){
for (int i = 0; i < pre_size; i++){
int offset_out = (start + index) * post_size + i * total;
int offset_in = index * post_size + i * total0;
// memcpy(out_data + offset_out, x_data + offset_in, post_size);
CL_DTYPE* dst = out_data + offset_out;
CL_DTYPE* src = x_data + offset_in;
for (int k = 0; k < post_size; k++){
*dst++ = *src++;
}
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void concat2(__read_only image2d_t input0,
__read_only image2d_t input1,
__write_only image2d_t output,
int axis_size, int flag, int width) {
const int x = get_global_id(0); // image_width cxw/4
const int y = get_global_id(1); // image_height nxh
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int xx = x / width;
if (flag == 0){
xx = y / width;
}
if (xx < axis_size){
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(x, y));
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}else{
int new_val = xx - axis_size;
new_val *= width;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(new_val, y));
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}
// WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}
__kernel void concat_mul(__read_only image2d_t input0,
__write_only image2d_t output,
int axis_size, int flag, int width, int start) {
const int x = get_global_id(0); // image_width cxw/4
const int y = get_global_id(1); // image_height nxh
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int xx = x / width;
if (flag == 0){
xx = y / width;
}
if (xx < axis_size && xx >= start){
xx -= start;
xx *= width;
CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(xx, y));
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
}
}
......@@ -20,6 +20,7 @@ add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc
add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps} cl_image_converter)
add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
add_kernel(concat_opencl OPENCL basic SRCS concat_compute.cc DEPS ${cl_kernel_deps})
add_kernel(nearest_interp_opencl OPENCL basic SRCS nearest_interp_compute.cc DEPS ${cl_kernel_deps})
lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc
......@@ -83,6 +84,11 @@ lite_cc_test(test_conv_image2d_opencl SRCS conv_image2d_compute_test.cc
lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc
DEPS layout_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_concat_opencl SRCS concat_compute_test.cc
DEPS concat_opencl layout_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_nearest_interp_opencl SRCS nearest_interp_compute_test.cc
DEPS nearest_interp_opencl layout_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/opencl/concat_compute.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
template <>
void ConcatCompute<PRECISION(kFloat),
DATALAYOUT(kImageDefault)>::PrepareForRun() {
auto& context = ctx_->As<OpenCLContext>();
concat_param_ = param_.get_mutable<param_t>();
if (concat_param_->x.size() == 2) {
kernel_func_name_ = "concat2";
} else {
kernel_func_name_ = "concat_mul";
}
context.cl_context()->AddKernel(
kernel_func_name_, "image/concat_kernel.cl", build_options_);
// UpdateParams<kFloat, kImageDefault>();
auto axis = concat_param_->axis;
auto inputs = concat_param_->x;
auto out_dims = concat_param_->output->dims();
auto* axis_tensor = concat_param_->axis_tensor;
if (axis_tensor != nullptr) {
// auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
// axis = axis_tensor_data[0];
}
auto in_dims = inputs[0]->dims();
axis_size_ = out_dims[axis];
axis_ = axis;
for (int i = 0; i < axis; i++) {
pre_size_ *= in_dims[i];
}
for (int i = axis + 1; i < in_dims.size(); i++) {
post_size_ *= in_dims[i];
}
for (int i = 1; i < inputs.size(); i++) {
auto dims = inputs[i]->dims();
// auto flag = CHECK_EQ_OR_FALSE(in_dims.size(), dims.size());
if (in_dims.size() != dims.size()) {
printf("input shape must be same \n");
return;
}
for (int i = 0; i < dims.size(); i++) {
if (i != axis) {
if (in_dims[i] != dims[i]) {
printf("input shape must be same \n");
return;
}
}
}
}
}
template <>
void ConcatCompute<PRECISION(kFloat), DATALAYOUT(kImageDefault)>::Run() {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.output->dims();
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_buf = param.output->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
const auto& y_dims = param.output->dims(); // useless: check dim only
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto inputs = param.x;
int arg_idx = 0;
int width = inputs[0]->dims()[-1];
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(image_shape["width"]),
static_cast<cl::size_type>(image_shape["height"])};
VLOG(4) << TargetToStr(param.output->target());
VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int flag = 1; // cxw
switch (axis_) {
case 0:
width = x_dims[2]; // n
flag = 0;
break;
case 1:
width = x_dims[3]; // c
break;
case 2:
width = x_dims[0]; // h
flag = 0;
break;
case 3:
case -1:
width = x_dims[1]; // w
break;
default:
printf("this axis: %d does not support \n", axis_);
}
if (inputs.size() == 2) {
auto* x_buf0 = inputs[0]->data<float, cl::Image2D>();
auto* x_buf1 = inputs[1]->data<float, cl::Image2D>();
cl_int status = kernel.setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status =
kernel.setArg(++arg_idx, static_cast<int>(inputs[0]->dims()[axis_]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_context()->GetCommandQueue().finish();
} else {
auto start = 0;
for (int i = 0; i < inputs.size(); i++) {
arg_idx = 0;
auto* x_buf = inputs[i]->data<float, cl::Image2D>();
cl_int status = kernel.setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, axis_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, flag);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width);
CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_context()->GetCommandQueue().finish();
start += inputs[i]->dims()[axis_];
}
}
}
template <>
std::string ConcatCompute<PRECISION(kFloat), DATALAYOUT(kImageDefault)>::doc() {
return "Concat using cl::Image, kFloat";
}
template <>
void ConcatCompute<PRECISION(kFloat), DATALAYOUT(kNCHW)>::PrepareForRun() {
auto& context = ctx_->As<OpenCLContext>();
concat_param_ = param_.get_mutable<param_t>();
if (concat_param_->x.size() == 2) {
kernel_func_name_ = "concat2";
} else {
kernel_func_name_ = "concat_mul";
}
context.cl_context()->AddKernel(
kernel_func_name_, "buffer/concat_kernel.cl", build_options_);
// UpdateParams<kFloat, kImageDefault>();
auto axis = concat_param_->axis;
auto inputs = concat_param_->x;
auto out_dims = concat_param_->output->dims();
auto* axis_tensor = concat_param_->axis_tensor;
if (axis_tensor != nullptr) {
// auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
// axis = axis_tensor_data[0];
}
auto in_dims = inputs[0]->dims();
axis_size_ = out_dims[axis];
axis_ = axis;
for (int i = 0; i < axis; i++) {
pre_size_ *= in_dims[i];
}
for (int i = axis + 1; i < in_dims.size(); i++) {
post_size_ *= in_dims[i];
}
for (int i = 1; i < inputs.size(); i++) {
auto dims = inputs[i]->dims();
if (in_dims.size() != dims.size()) {
printf("input shape must be same \n");
return;
}
for (int i = 0; i < dims.size(); i++) {
if (i != axis) {
if (in_dims[i] != dims[i]) {
printf("input shape must be same \n");
return;
}
}
}
}
}
template <>
void ConcatCompute<PRECISION(kFloat), DATALAYOUT(kNCHW)>::Run() {
auto& param = *param_.get_mutable<param_t>();
const auto& x_dims = param.output->dims();
auto image_shape = InitImageDimInfoWith(x_dims);
auto* out_buf =
param.output->mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
const auto& y_dims = param.output->dims(); // useless: check dim only
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto inputs = param.x;
int arg_idx = 0;
auto global_work_size = cl::NDRange{axis_size_};
int total = axis_size_ * post_size_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
if (inputs.size() == 2) {
auto* x_buf0 = inputs[0]->data<float, cl::Buffer>();
auto* x_buf1 = inputs[1]->data<float, cl::Buffer>();
auto axis0 = inputs[0]->dims()[axis_];
int total0 = axis0 * post_size_;
int total1 = (axis_size_ - axis0) * post_size_;
cl_int status = kernel.setArg(arg_idx, *x_buf0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_buf1);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<int>(axis0));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, axis_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, pre_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, post_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total1);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} else {
auto start = 0;
for (int i = 0; i < inputs.size(); i++) {
arg_idx = 0;
int size = inputs[i]->dims()[axis_];
auto* x_buf = inputs[i]->data<float, cl::Buffer>();
global_work_size = cl::NDRange{static_cast<size_t>(size)};
int total0 = size * post_size_;
cl_int status = kernel.setArg(arg_idx, *x_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_buf);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<int>(size));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, pre_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, post_size_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, start);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0);
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
start += size;
}
}
}
template <>
std::string ConcatCompute<PRECISION(kFloat), DATALAYOUT(kNCHW)>::doc() {
return "Concat using cl::Buffer, kFloat";
}
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::opencl::ConcatCompute<PRECISION(kFloat),
DATALAYOUT(kNCHW)>
Concat_buffer;
typedef paddle::lite::kernels::opencl::ConcatCompute<PRECISION(kFloat),
DATALAYOUT(kImageDefault)>
Concat_image;
REGISTER_LITE_KERNEL(
concat, kOpenCL, kFloat, kImageDefault, Concat_image, ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kInt32),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kInt32),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
template <PrecisionType Ptype, DataLayoutType layout>
class ConcatCompute : public KernelLite<TARGET(kOpenCL), Ptype, layout> {
public:
using param_t = operators::ConcatParam;
void PrepareForRun() override;
void Run() override;
std::string doc(); // override;
// protected:
// void UpdateParams();
int axis_size_ = 1;
int post_size_ = 1;
int pre_size_ = 1;
int axis_ = 1;
param_t* concat_param_{nullptr};
std::string kernel_func_name_{};
std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/image_helper.h"
namespace paddle {
namespace lite {
template <typename dtype>
void concat2_compute_ref(const dtype *in0,
const dtype *in1,
const int axis,
const DDim in0_dim,
const DDim in1_dim,
const DDim out_dim,
dtype *out_data) {
int pre_size = 1;
int post_size = 1;
for (int i = 0; i < axis; i++) {
pre_size *= in0_dim[i];
}
for (int i = axis + 1; i < in0_dim.size(); i++) {
post_size *= in0_dim[i];
}
int axis_size = out_dim[axis];
for (int i = 0; i < pre_size; i++) {
for (int j = 0; j < axis_size; j++) {
if (j < in0_dim[axis]) {
memcpy(out_data, in0, sizeof(dtype) * post_size);
in0 += post_size;
out_data += post_size;
}
}
}
}
template <typename dtype>
void concat_mul_compute_ref(std::vector<const dtype *> ins_data,
std::vector<const DDim> ins_dim,
int axis,
const DDim out_dim,
dtype *out_data) {
int pre_size = 1;
int post_size = 1;
for (int i = 0; i < axis; i++) {
pre_size *= ins_dim[0][i];
}
for (int i = axis + 1; i < ins_dim[0].size(); i++) {
post_size *= ins_dim[0][i];
}
int axis_size = out_dim[axis];
for (int i = 0; i < pre_size; i++) {
for (int j = 0; j < ins_data.size(); j++) {
int size = post_size * ins_dim[j][axis];
memcpy(out_data, ins_data[j], sizeof(dtype) * size);
out_data += size;
}
}
}
#if 1 // concat_buffer
TEST(opencl_concat_buffer, compute) {
// prepare data
const DDim x0_dim = DDim(std::vector<DDim::value_type>{1, 2, 3, 4});
const DDim x1_dim = DDim(std::vector<DDim::value_type>{1, 2, 3, 4});
const DDim x2_dim = DDim(std::vector<DDim::value_type>{1, 2, 3, 4});
const DDim out_dim = DDim(std::vector<DDim::value_type>{1, 6, 3, 4});
lite::Tensor x0, x1, x2, out, out_ref;
x0.Resize(x0_dim);
x1.Resize(x1_dim);
x2.Resize(x2_dim);
out.Resize(out_dim);
out_ref.Resize(out_dim);
auto *x0_data = x0.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *x1_data = x1.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *x2_data = x2.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-10, 10);
auto *mapped_x0 = static_cast<float *>(
TargetWrapperCL::Map(x0_data, 0, sizeof(float) * x0_dim.production()));
auto *mapped_x1 = static_cast<float *>(
TargetWrapperCL::Map(x1_data, 0, sizeof(float) * x1_dim.production()));
auto *mapped_x2 = static_cast<float *>(
TargetWrapperCL::Map(x2_data, 0, sizeof(float) * x2_dim.production()));
for (int i = 0; i < x0_dim.production(); i++) {
mapped_x0[i] = dist(engine);
}
for (int i = 0; i < x1_dim.production(); i++) {
mapped_x1[i] = dist(engine);
}
for (int i = 0; i < x2_dim.production(); i++) {
mapped_x2[i] = dist(engine);
}
// set param and kernel, then run
operators::ConcatParam param;
std::vector<lite::Tensor *> ins;
ins.push_back(&x0);
ins.push_back(&x1);
ins.push_back(&x2);
auto axis = 1;
param.x = ins;
param.output = &out;
param.axis = axis;
std::vector<const float *> ins_data;
std::vector<const DDim> ins_dim;
ins_data.push_back(mapped_x0);
ins_data.push_back(mapped_x1);
ins_data.push_back(mapped_x2);
ins_dim.push_back(x0_dim);
ins_dim.push_back(x1_dim);
ins_dim.push_back(x2_dim);
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
auto kernels = KernelRegistry::Global().Create(
"concat", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
kernel->SetParam(param);
std::unique_ptr<KernelContext> concat_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(concat_context->As<OpenCLContext>()));
kernel->SetContext(std::move(concat_context));
kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = param.output->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
// run compute ref and check
auto *out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
concat_mul_compute_ref<float>(ins_data, ins_dim, axis, out_dim, out_ref_data);
auto *out_data = out.mutable_data<float, cl::Buffer>();
auto *mapped_out = static_cast<float *>(
TargetWrapperCL::Map(out_data, 0, sizeof(float) * out_dim.production()));
for (int i = 0; i < out_dim.production(); i++) {
EXPECT_NEAR(mapped_out[i], out_ref_data[i], 1e-6);
}
TargetWrapperCL::Unmap(out_data, mapped_out);
TargetWrapperCL::Unmap(x0_data, mapped_x0);
TargetWrapperCL::Unmap(x1_data, mapped_x1);
TargetWrapperCL::Unmap(x2_data, mapped_x2);
}
#endif // concat_buffer
// #define LOOP_TEST
// #define PRINT_RESULT
TEST(concat_image2d_fp32, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> concat(img) -> "
"layout(img2buf) "
"-> host";
#ifdef LOOP_TEST
for (int n = 1; n <= 100; n += 33) {
for (auto c : {1, 3}) {
for (int h = 12; h <= 100; h += 13) {
for (int w = 12; w <= 100; w += 25) {
for (atuo &axis : {0, 1, 2, 3}) {
#else
const int n = 1;
const int c = 2;
const int h = 3;
const int w = 4;
const int axis = 1;
#endif // LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c
<< " " << h << " " << w << " ========";
LOG(INFO) << "======== axis: " << axis;
// set layout kernels
auto buf_to_img_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto buf_to_img_kernels1 =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels = KernelRegistry::Global().Create(
"layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW));
auto concat_img_kernels =
KernelRegistry::Global().Create("concat",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(buf_to_img_kernels1.empty());
ASSERT_FALSE(img_to_buf_kernels.empty());
ASSERT_FALSE(concat_img_kernels.empty());
auto buf_to_img_kernel = std::move(buf_to_img_kernels.front());
auto buf_to_img_kernel1 = std::move(buf_to_img_kernels1.front());
auto img_to_buf_kernel = std::move(img_to_buf_kernels.front());
auto concat_img_kernel = std::move(concat_img_kernels.front());
LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc();
LOG(INFO) << "get 1st-1 kernel: " << buf_to_img_kernel1->doc();
LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: " << concat_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
lite::Tensor x0, x1, y, concat_in0, concat_in1, concat_out, y_ref;
operators::LayoutParam BufferToImageParam0, BufferToImageParam1;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam0.x = &x0;
BufferToImageParam0.y = &concat_in0;
BufferToImageParam1.x = &x1;
BufferToImageParam1.y = &concat_in1;
ImageToBufferParam.x = &concat_out;
ImageToBufferParam.y = &y;
std::vector<lite::Tensor *> ins;
operators::ConcatParam concatParam;
ins.push_back(&concat_in0);
ins.push_back(&concat_in1);
concatParam.x = ins;
concatParam.axis = axis;
concatParam.output = &concat_out;
const DDim x0_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
DDim x1_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
DDim out_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
x1_dim[axis] += 2;
out_dim[axis] = x0_dim[axis] + x1_dim[axis];
x0.Resize(x0_dim);
x1.Resize(x1_dim);
y.Resize(out_dim);
concat_in0.Resize(x0_dim);
concat_in1.Resize(x1_dim);
concat_out.Resize(out_dim);
y_ref.Resize(out_dim);
auto concat_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(out_dim);
auto concat_image2d_shape_in0 =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x0_dim);
auto concat_image2d_shape_in1 =
paddle::lite::kernels::opencl::InitImageDimInfoWith(x1_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data0 = x0.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *x_data1 = x1.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data = y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref = y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x0 = static_cast<float *>(TargetWrapperCL::Map(
x_data0, 0, sizeof(float) * x0_dim.production()));
auto *mapped_x1 = static_cast<float *>(TargetWrapperCL::Map(
x_data1, 0, sizeof(float) * x1_dim.production()));
auto *mapped_y = static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * out_dim.production()));
for (int i = 0; i < x0_dim.production(); ++i) {
mapped_x0[i] = static_cast<int>(i) - x0_dim.production() / 2;
}
for (int i = 0; i < x1_dim.production(); ++i) {
mapped_x1[i] = static_cast<int>(i) - x1_dim.production() / 2;
}
for (int i = 0; i < out_dim.production(); ++i) {
mapped_y[i] = static_cast<int>(0);
}
auto *concat_in_data0 = concat_in0.mutable_data<float, cl::Image2D>(
concat_image2d_shape_in0["width"],
concat_image2d_shape_in0["height"]);
auto *concat_in_data1 = concat_in1.mutable_data<float, cl::Image2D>(
concat_image2d_shape_in1["width"],
concat_image2d_shape_in1["height"]);
auto *concat_out_data = concat_out.mutable_data<float, cl::Image2D>(
concat_image2d_shape["width"], concat_image2d_shape["height"]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam0);
std::unique_ptr<KernelContext> buf_to_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(std::move(buf_to_img_context));
buf_to_img_kernel1->SetParam(BufferToImageParam1);
std::unique_ptr<KernelContext> buf_to_img_context1(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context1->As<OpenCLContext>()));
buf_to_img_kernel1->SetContext(std::move(buf_to_img_context1));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(std::move(img_to_buf_context));
concat_img_kernel->SetParam(concatParam);
std::unique_ptr<KernelContext> concat_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(concat_img_context->As<OpenCLContext>()));
concat_img_kernel->SetContext(std::move(concat_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
buf_to_img_kernel1->Launch();
LOG(INFO) << "run kernel: concat_img_kernel";
concat_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// compute ref cp_u
std::vector<const float *> ins_ptr;
std::vector<const DDim> in_dim;
ins_ptr.push_back(mapped_x0);
ins_ptr.push_back(mapped_x1);
in_dim.push_back(x0_dim);
in_dim.push_back(x1_dim);
concat_mul_compute_ref<float>(
ins_ptr, in_dim, axis, out_dim, y_data_ref);
// result
#ifdef PRINT_RESULT
LOG(INFO) << "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < out_dim.production(); ++eidx) {
std::cout << mapped_x0[eidx] << ", " << mapped_x1[eidx] << " -> "
<< mapped_y[eidx] << std::endl;
}
#endif // PRINT_RESULT
// check result: compare kernel output and cpu output(y_data_ref)
for (int eidx = 0; eidx < out_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) {
LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
<< " / " << x0_dim.production() << ", y_data_ref["
<< eidx << "]:" << y_data_ref[eidx] << ", mapped_y["
<< eidx << "]:" << mapped_y[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data0, mapped_x0);
TargetWrapperCL::Unmap(x_data1, mapped_x1);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef LOOP_TEST
} // axis
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
// concat buffer
USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def);
// concat image2d fp32
USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW);
USE_LITE_KERNEL(concat, kOpenCL, kFloat, kImageDefault, ImageDefault);
......@@ -307,7 +307,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
#endif // LITE_WITH_ARM
// TODO(chenjiaoAngel): fix multi-threds, diff: 3x3 depthwise conv
#if 1 /// 3x3dw
#if 1 // 3x3dw
TEST(TestConv3x3DW, test_conv3x3_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
......@@ -449,7 +449,7 @@ TEST(TestConv3x3s1, test_conv_3x3s1) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
if (cin == 1 && cout ==1) {
if (cin == 1 && cout == 1) {
continue;
}
const float leakey_relu_scale = 8.88;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册