提交 60743492 编写于 作者: W wandongdong

add image2d for depthwise

上级 53381282
#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable
#define ACCUM_FLT4 float4
#ifdef ENABLE_FP16
#define FLT half
#define FLT4 half4
#define TO_FLT4 convert_half4
#else
#define FLT float
#define FLT2 float2
#define FLT3 float3
#define FLT4 float4
#define TO_FLT4 convert_float4
#define TO_ACCUM_TYPE convert_float4
#define TO_ACCUM_FLT convert_float
#define READ_IMAGE read_imagef
#define WRITE_IMAGE write_imagef
__constant sampler_t smp_edge = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void DepthwiseConv2d_NC4HW4(
__global float4* src_data,
__global FLT4* filters,
__global FLT4* biases,
#endif
__constant sampler_t sampler_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void DepthwiseConv2d_IMG_NC4HW4(
__read_only image2d_t src_data,
__global FLT4* filter,
__global FLT4* bias,
float relu_clip1,
__global float4* dst_data,
int2 kernel_size,
int2 stride,
int2 padding,
int2 dilation,
int4 src_size,
int4 dst_size
) {
__write_only image2d_t dst_data,
int2 kernel_size,
int2 stride,
int2 padding,
int2 dilation,
int4 src_size,
int4 dst_size) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offseted = X * stride.x + padding.x;
int y_offseted = Y * stride.y + padding.y;
int fx_c = Z * kernel_size.x * kernel_size.y;
......@@ -40,37 +35,160 @@ __global float4* dst_data,
int x_c = x_offseted + kx * dilation.x;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 f = filters[fx_c];
FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))];
r += TO_ACCUM_TYPE(src_final * f);
FLT4 f = filter[fx_c];
//FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))];
FLT4 src_final =read_imagef(src_data, sampler_zero, (int2)(x_c, (Z * src_size.y + y_c)));
r += TO_FLT4(src_final * f);
};
fx_c++;
}
}
FLT4 bias_val = biases[Z];
FLT4 bias_val = bias[Z];
FLT4 res0 = TO_FLT4(r) + bias_val;
res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1));
//dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0;
write_imagef(dst_data, (int2)(X, (Z * dst_size.y + Y)), res0);
}
// Depthwise 2-D convolution whose input and output tensors live in image2d
// objects, followed by a bias add and a clamp to [0, relu_clip1].
//
// One work-item produces one FLT4 texel of the output:
//   global id 0 -> output column, id 1 -> output row, id 2 -> 4-channel slice.
// Work-items outside dst_size.xyz exit immediately, so the launch grid may be
// rounded up past the output extent.
//
// Addressing used by this kernel:
//   src texel for (col, row, slice) is at image coord (slice + col * src_size.z, row)
//   dst texel for (col, row, slice) is at image coord (col * dst_size.z + slice, row)
//   filter taps for a slice are contiguous, starting at
//   slice * kernel_size.x * kernel_size.y, ordered row-major (ky outer, kx inner).
//
// `padding` is added to the scaled output coordinate as-is, so the caller is
// expected to pass it already negated for a conventional "pad by p" convolution
// (NOTE(review): confirm against the host-side argument setup).
__kernel void DepthwiseConv2d_IMG_NHWC4(__read_only image2d_t src_data,
                                        __global FLT4 *filter,
                                        __global FLT4 *bias,
                                        float relu_clip1,
                                        __write_only image2d_t dst_data,
                                        int2 kernel_size,
                                        int2 stride,
                                        int2 padding,
                                        int2 dilation,
                                        int4 src_size,
                                        int4 dst_size) {
  const int out_col = get_global_id(0);
  const int out_row = get_global_id(1);
  const int slice = get_global_id(2);
  if (out_col >= dst_size.x || out_row >= dst_size.y || slice >= dst_size.z) {
    return;
  }

  // Top-left input coordinate covered by this output texel's window.
  const int first_col = out_col * stride.x + padding.x;
  const int first_row = out_row * stride.y + padding.y;

  FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
  // `tap` advances on every window position, including skipped ones, so that
  // it always matches the (ky, kx) position inside this slice's filter.
  int tap = slice * kernel_size.x * kernel_size.y;
  for (int ky = 0; ky < kernel_size.y; ++ky) {
    const int in_row = first_row + ky * dilation.y;
    const bool row_inside = in_row >= 0 && in_row < src_size.y;
    for (int kx = 0; kx < kernel_size.x; ++kx, ++tap) {
      const int in_col = first_col + kx * dilation.x;
      const bool col_inside = in_col >= 0 && in_col < src_size.x;
      if (!(row_inside && col_inside)) {
        // Window position falls in the padding region: contributes nothing.
        continue;
      }
      const FLT4 weight = filter[tap];
      const FLT4 texel = read_imagef(src_data, sampler_zero, (int2)(slice + in_col * src_size.z, in_row));
      acc += TO_FLT4(texel * weight);
    }
  }

  const FLT4 biased = TO_FLT4(acc) + bias[slice];
  const FLT4 clipped = clamp(biased, (FLT)(0.0f), (FLT)(relu_clip1));
  write_imagef(dst_data, (int2)(out_col * dst_size.z + slice, out_row), clipped);
}
// Single-tap (1x1) depthwise convolution over image2d-backed tensors, followed
// by a bias add and a clamp to [0, relu_clip1].
//
// One work-item produces one FLT4 texel of the output:
//   global id 0 -> output column X, id 1 -> output row Y, id 2 -> 4-channel slice Z.
// Work-items outside dst_size.xyz exit immediately.
//
// The filter holds exactly one FLT4 per slice (filter[Z]); kernel_size and
// dilation are accepted only to keep the argument list identical to the
// general DepthwiseConv2d_IMG_NHWC4 kernel and are not read.
//
// Image addressing is the same as in DepthwiseConv2d_IMG_NHWC4:
//   src texel (x, y, Z) -> image coord (Z + x * src_size.z, y)
//   dst texel (X, Y, Z) -> image coord (X * dst_size.z + Z, Y)
// The previous revision addressed both images as (Z, (y * width + x) * slices),
// i.e. it reused the flat buffer index as the image row. That does not match
// an image that is (width * slices) texels wide and `height` texels tall: the
// read landed outside the image and returned the sampler's border value, and
// the write targeted out-of-range coordinates.
__kernel void DepthwiseConv2d_IMG_NHWC4_1x1(
    __read_only image2d_t src_data,
    __global FLT4* filter,
    __global FLT4* bias,
    float relu_clip1,
    __write_only image2d_t dst_data,
    int2 kernel_size,
    int2 stride,
    int2 padding,
    int2 dilation,
    int4 src_size,
    int4 dst_size) {
  int X = get_global_id(0);
  int Y = get_global_id(1);
  int Z = get_global_id(2);
  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;

  FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
  // Input coordinate sampled by this output texel; `padding` is added as
  // passed, so it may move the coordinate outside the source.
  int x_c = X * stride.x + padding.x;
  int y_c = Y * stride.y + padding.y;
  bool outside_x = x_c < 0 || x_c >= src_size.x;
  bool outside_y = y_c < 0 || y_c >= src_size.y;
  if (!outside_x && !outside_y) {
    FLT4 f = filter[Z];
    FLT4 src_final = read_imagef(src_data, sampler_zero, (int2)(Z + x_c * src_size.z, y_c));
    r += TO_FLT4(src_final * f);
  }

  FLT4 bias_val = bias[Z];
  FLT4 res0 = TO_FLT4(r) + bias_val;
  res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1));
  write_imagef(dst_data, (int2)(X * dst_size.z + Z, Y), res0);
}
// Depthwise 2-D convolution over plain global-memory buffers, followed by a
// bias add and a clamp to [0, relu_clip1].
//
// One work-item produces one FLT4 element of the output:
//   global id 0 -> output column, id 1 -> output row, id 2 -> 4-channel slice.
// Work-items outside dst_size.xyz exit immediately.
//
// Addressing used by this kernel (all indices are in FLT4 elements):
//   src[(slice * src_size.y + row) * src_size.x + col]
//   dst[(slice * dst_size.y + row) * dst_size.x + col]
//   filter taps for a slice are contiguous, starting at
//   slice * kernel_size.x * kernel_size.y, ordered row-major (ky outer, kx inner).
//
// `padding` is added to the scaled output coordinate as-is
// (NOTE(review): the caller presumably passes it negated - confirm on the host side).
__kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *src_data,
                                         __global FLT4 *filter,
                                         __global FLT4 *bias,
                                         float relu_clip1,
                                         __global FLT4 *dst_data,
                                         int2 kernel_size,
                                         int2 stride,
                                         int2 padding,
                                         int2 dilation,
                                         int4 src_size,
                                         int4 dst_size) {
  const int out_col = get_global_id(0);
  const int out_row = get_global_id(1);
  const int slice = get_global_id(2);
  if (out_col >= dst_size.x || out_row >= dst_size.y || slice >= dst_size.z) {
    return;
  }

  // Top-left input coordinate covered by this output element's window.
  const int first_col = out_col * stride.x + padding.x;
  const int first_row = out_row * stride.y + padding.y;

  FLT4 acc = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
  // `tap` advances on every window position, including skipped ones, so that
  // it always matches the (ky, kx) position inside this slice's filter.
  int tap = slice * kernel_size.x * kernel_size.y;
  for (int ky = 0; ky < kernel_size.y; ++ky) {
    const int in_row = first_row + ky * dilation.y;
    const bool row_inside = in_row >= 0 && in_row < src_size.y;
    for (int kx = 0; kx < kernel_size.x; ++kx, ++tap) {
      const int in_col = first_col + kx * dilation.x;
      const bool col_inside = in_col >= 0 && in_col < src_size.x;
      if (!(row_inside && col_inside)) {
        // Window position falls in the padding region: contributes nothing.
        continue;
      }
      const FLT4 weight = filter[tap];
      const FLT4 value = src_data[(slice * src_size.y + in_row) * src_size.x + in_col];
      acc += TO_FLT4(value * weight);
    }
  }

  const FLT4 biased = TO_FLT4(acc) + bias[slice];
  const FLT4 clipped = clamp(biased, (FLT)(0.0f), (FLT)(relu_clip1));
  dst_data[(slice * dst_size.y + out_row) * dst_size.x + out_col] = clipped;
}
__kernel void DepthwiseConv2d_NHWC4(
__global float4* src_data,
__global FLT4* filters,
__global FLT4* biases,
__kernel void DepthwiseConv2d_BUF_NHWC4(
__global FLT4* src_data,
__global FLT4* filter,
__global FLT4* bias,
float relu_clip1,
__global float4* dst_data,
int2 kernel_size,
int2 stride,
int2 padding,
int2 dilation,
int4 src_size,
int4 dst_size
) {
__global FLT4* dst_data,
int2 kernel_size,
int2 stride,
int2 padding,
int2 dilation,
int4 src_size,
int4 dst_size) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offseted = X * stride.x + padding.x;
int y_offseted = Y * stride.y + padding.y;
int fx_c = Z * kernel_size.x * kernel_size.y;
......@@ -81,14 +199,53 @@ __global float4* dst_data,
int x_c = x_offseted + kx * dilation.x;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 f = filters[fx_c];
FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
r += TO_ACCUM_TYPE(src_final * f);
FLT4 f = filter[fx_c];
FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
r += TO_FLT4(src_final * f);
};
fx_c++;
}
}
FLT4 bias_val = biases[Z];
FLT4 bias_val = bias[Z];
FLT4 res0 = TO_FLT4(r) + bias_val;
res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1));
dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0;
}
__kernel void DepthwiseConv2d_BUF_NHWC4_1x1(
__global FLT4* src_data,
__global FLT4* filter,
__global FLT4* bias,
float relu_clip1,
__global FLT4* dst_data,
int2 kernel_size,
int2 stride,
int2 padding,
int2 dilation,
int4 src_size,
int4 dst_size) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offseted = X * stride.x + padding.x;
int y_offseted = Y * stride.y + padding.y;
int fx_c = Z;
{
int y_c = y_offseted;
bool outside_y = y_c < 0 || y_c >= src_size.y;
{
int x_c = x_offseted;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 f = filter[fx_c];
FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
r += TO_FLT4(src_final * f);
};
}
}
FLT4 bias_val = bias[Z];
FLT4 res0 = TO_FLT4(r) + bias_val;
res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1));
dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0;
......
......@@ -21,9 +21,12 @@
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/arm/opclib/pack.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl.inc"
#include "src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl.inc"
#endif
using mindspore::kernel::KERNEL_ARCH::kGPU;
......@@ -35,20 +38,31 @@ namespace mindspore::kernel {
int DepthwiseConv2dOpenCLKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
std::string kernel_name = "DepthwiseConv2d_NHWC4";
std::string kernel_name = "DepthwiseConv2d";
auto in_format = inputs_[0]->GetFormat();
outputs_[0]->SetFormat(in_format);
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") " << "format not support!";
}
if (mem_type_ == MEM_TYPE::BUF) {
kernel_name += "_BUF";
} else {
kernel_name += "_IMG";
}
if (in_format == schema::Format_NC4HW4) {
kernel_name = "DepthwiseConv2d_NC4HW4";
kernel_name += "_NC4HW4";
} else if (in_format == schema::Format_NHWC4) {
kernel_name += "_NHWC4";
}
auto parameter = reinterpret_cast<ConvParameter *>(opParameter);
if (parameter->kernel_h_ == 1) {
kernel_name += "_1x1";
}
#ifdef PROGRAM_WITH_IL
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
#else
std::string program_name = "DepthwiseConv2d";
std::set<std::string> build_options;
std::set <std::string> build_options;
#ifdef ENABLE_FP16
std::string source = depthwise_conv2d_source_fp16;
#else
......@@ -61,8 +75,9 @@ int DepthwiseConv2dOpenCLKernel::Init() {
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
}
int DepthwiseConv2dOpenCLKernel::InitBuffer() {
auto parameter = reinterpret_cast<ConvParameter*>(opParameter);
auto parameter = reinterpret_cast<ConvParameter *>(opParameter);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto allocator = ocl_runtime->GetAllocator();
......@@ -89,54 +104,101 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
size_t up_co_size = C4NUM * CO4 * sizeof(FLOAT_t);
memset_s(bias_data_, up_co_size, 0, up_co_size);
auto ori_bias = reinterpret_cast<FLOAT_t *>(inputs_.at(kBiasIndex)->Data());
memcpy_s(bias_data_, outputs_[0]->Channel() * sizeof(FLOAT_t), ori_bias, outputs_[0]->Channel() * sizeof(FLOAT_t));
memcpy_s(bias_data_, outputs_[0]->Channel() * sizeof(FLOAT_t), ori_bias,
outputs_[0]->Channel() * sizeof(FLOAT_t));
allocator->UnmapBuffer(bias_data_);
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
}
return 0;
}
// Resize hook: this kernel keeps no shape-dependent state to rebuild here, so
// the call does nothing and always returns 0 (the same value Init() and Run()
// return on their normal path).
int DepthwiseConv2dOpenCLKernel::ReSize() { return 0; }
int DepthwiseConv2dOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->Name() << " Running!";
auto parameter = reinterpret_cast<ConvParameter*>(opParameter);
auto parameter = reinterpret_cast<ConvParameter *>(opParameter);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM);
std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
std::vector<size_t> local = {1, 1, 1};
std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4};
std::vector <size_t> local = {1, 1, CO4};
float relu_clip1 = 6.0;
cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_};
cl_int2 stride = {parameter->stride_h_, parameter->stride_w_};
cl_int2 padding = {-parameter->pad_h_, -parameter->pad_w_};
cl_int2 dilation = {parameter->dilation_h_, parameter->dilation_w_};
cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int)CI4, inputs_[0]->Batch()};
cl_int4 dst_size = {(cl_int)outputs_[0]->Width(), (cl_int)outputs_[0]->Height(), (cl_int)CO4,
(cl_int)outputs_[0]->Batch()};
ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int) CI4, inputs_[0]->Batch()};
cl_int4 dst_size = {(cl_int) outputs_[0]->Width(), (cl_int) outputs_[0]->Height(), (cl_int) CO4,
(cl_int) outputs_[0]->Batch()};
ocl_runtime->SetKernelArg(kernel_, 1, packed_weight_);
ocl_runtime->SetKernelArg(kernel_, 2, bias_data_);
ocl_runtime->SetKernelArg(kernel_, 3, relu_clip1);
ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, 5, kernel_size);
ocl_runtime->SetKernelArg(kernel_, 6, stride);
ocl_runtime->SetKernelArg(kernel_, 7, padding);
ocl_runtime->SetKernelArg(kernel_, 8, dilation);
ocl_runtime->SetKernelArg(kernel_, 9, src_size);
ocl_runtime->SetKernelArg(kernel_, 10, dst_size);
if (mem_type_ == MEM_TYPE::BUF) {
ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data());
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
} else {
cl::ImageFormat image_format;
{
image_format.image_channel_order = CL_RGBA;
image_format.image_channel_data_type = CL_FLOAT;
}
cl_int in_error_code;
size_t im_src_x, im_src_y;
size_t im_dst_x, im_dst_y;
if (inputs_[0]->GetFormat() == schema::Format_NHWC4) {
im_src_x = inputs_[0]->Width() * CI4;
im_src_y = inputs_[0]->Height();
im_dst_x = outputs_[0]->Width() * CO4;
im_dst_y = outputs_[0]->Height();
} else {
im_src_y = inputs_[0]->Height() * CI4;
im_src_x = inputs_[0]->Width();
im_dst_y = outputs_[0]->Height() * CO4;
im_dst_x = outputs_[0]->Width();
}
cl::Image2D in_mem(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format,
im_src_x, im_src_y, 0, inputs_[0]->Data(), &in_error_code);
cl_int out_error_code;
cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format,
im_dst_x, im_dst_y, 0, nullptr, &out_error_code);
if (in_error_code != CL_SUCCESS) {
MS_LOG(DEBUG) << "in Image2D Failed, error=" << in_error_code;
return 1;
}
if (out_error_code != CL_SUCCESS) {
MS_LOG(DEBUG) << "out Image2D Failed, error= " << out_error_code;
return 1;
}
auto origin = cl::array < cl::size_type,
3U > {0, 0, 0};
auto region = cl::array < cl::size_type,
3U > {im_dst_x, im_dst_y, 1};
ocl_runtime->SetKernelArg(kernel_, 0, in_mem);
ocl_runtime->SetKernelArg(kernel_, 4, out_mem);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0,
outputs_[0]->Data());
}
return 0;
}
kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
auto ret = kernel->Init();
if (0 != ret) {
......@@ -147,6 +209,7 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::t
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator
)
} // namespace mindspore::kernel
......@@ -18,12 +18,10 @@
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
namespace mindspore::kernel {
class DepthwiseConv2dOpenCLKernel : public LiteKernel {
......@@ -31,18 +29,25 @@ class DepthwiseConv2dOpenCLKernel : public LiteKernel {
explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs),
packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {}
packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {}
~DepthwiseConv2dOpenCLKernel() override {};
int Init() override;
int ReSize() override;
int Run() override;
int InitBuffer();
private:
FLOAT_t *packed_weight_;
FLOAT_t *bias_data_;
cl::Kernel kernel_;
enum class MEM_TYPE {
BUF, IMG
} mem_type_{MEM_TYPE::BUF};
};
} // namespace mindspore::kernel
......
......@@ -264,6 +264,8 @@ if (SUPPORT_GPU)
set(TEST_SRC
${TEST_SRC}
${TEST_DIR}/ut/stc/runtime/kernel/opencl/matmul_tests.cc
${TEST_DIR}/ut/stc/runtime/kernel/opencl/depthwise_conv2d_tests.cc
${TEST_DIR}/ut/stc/runtime/kernel/opencl/concat_tests.cc
${TEST_DIR}/ut/stc/runtime/kernel/opencl/softmax_cl_tests.cc
)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册