提交 7543cce4 编写于 作者: D DannyIsFunny

Merge remote-tracking branch 'origin' into test_result

......@@ -302,10 +302,10 @@ void elementwise_add_grad_broadcast<float>(const float* dout_grad,
int pre,
int n,
int post) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
if (y_grad != nullptr) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
......@@ -582,10 +582,10 @@ void elementwise_sub_grad<float>(const float* dout_grad,
float* x_grad,
float* y_grad,
int num) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, num);
}
if (y_grad) {
if (y_grad != nullptr) {
int cnt = num >> 4;
int remain = num & 0x0f;
float32x4_t minus = vdupq_n_f32(-1);
......@@ -624,10 +624,10 @@ void elementwise_sub_grad_broadcast<float>(const float* dout_grad,
int pre,
int n,
int post) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
if (y_grad != nullptr) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
......
......@@ -30,6 +30,143 @@ __kernel void elementwise_mul(__global image2d_t input,
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Multiplies every pixel of `input` by a pixel taken from row 0 of `bias`.
//
// Work-item (x, y) = (get_global_id(0), get_global_id(1)) handles pixel
// (x, y) of both `input` and `outputImage`.
//
// Args:
//   input        image that is read at (x, y)
//   bias         image that is read at (x / w, 0)
//   outputImage  image that receives in * bias at (x, y)
//   w            number of consecutive columns that share one bias pixel;
//                used as an integer divisor, so it must be non-zero
__kernel void channel_mul(__read_only image2d_t input,
                          __read_only image2d_t bias,
                          __write_only image2d_t outputImage,
                          int w) {
  // Pixels are addressed with integer coordinates, so the sampler has to use
  // unnormalized coordinates.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  const int2 coords = (int2)(get_global_id(0), get_global_id(1));
  const int2 coords_bias = (int2)(coords.x / w, 0);

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Multiplies every pixel of `input` lane-wise by a 4-vector that is gathered
// from the .x lanes of four consecutive pixels on row 0 of `bias`.
//
// Original author note: "etc : 1 1 1 72 / run time Y [value,0,0,0] * 72".
// NOTE(review): presumably this describes a bias of dims 1x1x1x72 stored as
// 72 pixels that each carry one value in the .x lane -- confirm against the
// host-side layout code.
//
// Work-item (x, y) = (get_global_id(0), get_global_id(1)) handles pixel
// (x, y). The four bias pixels used for column x sit at columns
// (x / w) * 4 + {0, 1, 2, 3} of row 0.
//
// Args:
//   input        image that is read at (x, y)
//   bias         image whose .x lanes supply the multiplier
//   outputImage  image that receives the product at (x, y)
//   w            number of consecutive columns that share one multiplier;
//                used as an integer divisor, so it must be non-zero
__kernel void channel_mul_d2(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w) {
  // Pixels are addressed with integer coordinates, so the sampler has to use
  // unnormalized coordinates.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  const int2 coords = (int2)(get_global_id(0), get_global_id(1));

  // First of the four bias columns that belong to this input column.
  const int bias_x = coords.x / w * 4;
  const int2 coords_bias0 = (int2)(bias_x, 0);
  const int2 coords_bias1 = (int2)(bias_x + 1, 0);
  const int2 coords_bias2 = (int2)(bias_x + 2, 0);
  const int2 coords_bias3 = (int2)(bias_x + 3, 0);

  CL_DTYPE4 biase0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias0);
  CL_DTYPE4 biase1 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias1);
  CL_DTYPE4 biase2 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias2);
  CL_DTYPE4 biase3 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias3);

  // Only the .x lane of each bias pixel is used; pack them into one vector.
  CL_DTYPE4 biase = {biase0.x, biase1.x, biase2.x, biase3.x};

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  // in * biase + 0, expressed with mad() exactly as before.
  CL_DTYPE4 output = mad(in, biase, 0);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Multiplies every pixel of `input` by a pixel taken from row 0 of `bias`.
//
// Original author note: "c 1 1".
// NOTE(review): presumably the multiplier has dims [c, 1, 1] -- confirm
// against the host code that selects this kernel.
//
// Work-item (x, y) = (get_global_id(0), get_global_id(1)) handles pixel
// (x, y) of both `input` and `outputImage`.
//
// Args:
//   input        image that is read at (x, y)
//   bias         image that is read at (x / w, 0)
//   outputImage  image that receives in * bias at (x, y)
//   w            number of consecutive columns that share one bias pixel;
//                used as an integer divisor, so it must be non-zero
__kernel void channel_mul_d3(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w) {
  // Pixels are addressed with integer coordinates, so the sampler has to use
  // unnormalized coordinates.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  const int2 coords = (int2)(get_global_id(0), get_global_id(1));
  const int2 coords_bias = (int2)(coords.x / w, 0);

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
// Multiplies every pixel of `input` by a pixel taken from row 0 of `bias`.
//
// Work-item (x, y) = (get_global_id(0), get_global_id(1)) handles pixel
// (x, y) of both `input` and `outputImage`.
//
// Args:
//   input        image that is read at (x, y)
//   bias         image that is read at (x / w, 0)
//   outputImage  image that receives in * bias at (x, y)
//   w            number of consecutive columns that share one bias pixel;
//                used as an integer divisor, so it must be non-zero
__kernel void channel_mul_d4(__read_only image2d_t input,
                             __read_only image2d_t bias,
                             __write_only image2d_t outputImage,
                             int w) {
  // Pixels are addressed with integer coordinates, so the sampler has to use
  // unnormalized coordinates.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  const int2 coords = (int2)(get_global_id(0), get_global_id(1));
  const int2 coords_bias = (int2)(coords.x / w, 0);

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
#if 0 // TODO(ysh329): comment code below
// Pixel-wise product of two images (this copy sits inside the surrounding
// `#if 0` region, so it is not compiled at the moment).
//
// Work-item (x, y) = (get_global_id(0), get_global_id(1)) reads pixel (x, y)
// from both `input` and `bias` and writes their product to `outputImage`.
__kernel void elementwise_mul(__read_only image2d_t input,
                              __read_only image2d_t bias,
                              __write_only image2d_t outputImage) {
  // Pixels are addressed with integer coordinates, so the sampler has to use
  // unnormalized coordinates.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  const int2 coords = (int2)(get_global_id(0), get_global_id(1));

  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords);
  CL_DTYPE4 output = in * biase;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
__kernel void channel_mul_d1(__read_only image2d_t input,
__read_only image2d_t bias,
......@@ -184,4 +321,4 @@ __kernel void channel_mul_d4(__read_only image2d_t input,
WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
}
#endif
......@@ -14,14 +14,127 @@ limitations under the License. */
#include <cl_common.h>
// onnx/pytorch instancenorm by lijian
//
// For one (n, c) tile of `input`, computes the mean and the variance over
// in_width x in_height pixels and writes
//   (pixel - mean) / sqrt(variance + epsilon)
// to the same position of `output` as half4 (optionally passed through
// activation() when RELU is defined).
//
// Parameters:
//   in_width, in_height   extent of one tile; also the divisor of both sums
//   in_c_group            splits get_global_id(0) into n (quotient) and
//                         c (remainder)
//   local_work_size_x/y   loop strides and the factors used to linearise the
//                         local id
//                         NOTE(review): presumably equal to the work-group
//                         sizes in dimensions 1 and 2 -- confirm on the host
//   epsilon               added to the variance before the square root
//   input, output         source and destination images
__kernel void instance_norm_onnx(__private const int in_width,
__private const int in_height,
__private const int in_c_group,
__private const int local_work_size_x,
__private const int local_work_size_y,
__private const float epsilon,
__read_only image2d_t input,
__write_only image2d_t output) {
// Global dimension 0 enumerates (n, c) pairs.
const int out_cn = get_global_id(0);
const int n = out_cn / in_c_group;
const int c = out_cn % in_c_group;
// Local dimensions 1 and 2 give this work-item's starting column and row.
const int w = get_local_id(1);
const int h = get_local_id(2);
// Linear slot of this work-item inside shared_mem.
const int local_id = w * local_work_size_y + h;
const int local_total_size = local_work_size_x * local_work_size_y;
// NOTE(review): the next seven lines look like the signature of a different
// kernel (`instance_norm`) that ended up inside this body, presumably removed
// lines from the diff this text was taken from -- confirm against the real
// file before relying on this listing.
__kernel void instance_norm(__read_only image2d_t input,
__write_only image2d_t output,
__read_only image2d_t scale,
__read_only image2d_t bias,
const float epsilon,
const int in_h,
const int in_w){
// NOTE(review): this sampler requests normalized coordinates while every read
// below passes integer int2 coordinates -- check this against the OpenCL C
// rules for read_image with integer coordinates.
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// Scratch buffer for the reductions; its size is chosen at build time.
#ifdef LOCAL_MEM_128
__local float4 shared_mem[128];
#elif defined(LOCAL_MEM_64)
__local float4 shared_mem[64];
#else
__local float4 shared_mem[256];
#endif
// Top-left pixel of the tile handled by this (n, c) pair.
int xOffset = c * in_width;
int yOffset = n * in_height;
// ---- Pass 1: mean ----
// Each work-item accumulates a strided subset of the tile.
float4 sum = 0.0f;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
sum += read_imagef(input, sampler, (int2)(xOffset + xIndex, yOffset + yIndex));
}
}
shared_mem[local_id] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
// Fold all partial sums into the first 32 slots: slot k collects
// k + 32, k + 64, ... Work-items with local_id >= 32 add zero to their slot.
sum = 0.0f;
if (local_id < 32) {
for (int i = local_id + 32; i < local_total_size; i += 32) {
sum += shared_mem[i];
}
}
shared_mem[local_id] += sum;
barrier(CLK_LOCAL_MEM_FENCE);
// Work-item 0 adds up the (at most 32) remaining slots and stores the mean
// in slot 0.
sum = 0.0f;
if (local_id == 0) {
int top = min(32, local_total_size);
for (int i = 0; i < top; i += 1) {
sum += shared_mem[i];
}
shared_mem[0] = sum / (in_width * in_height);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float4 mean_val = shared_mem[0];
// Everybody must have read the mean before shared_mem is overwritten below.
barrier(CLK_LOCAL_MEM_FENCE);
// ---- Pass 2: variance (mean of squared deviations) ----
sum = 0.0f;
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
float4 temp = read_imagef(input, sampler, (int2)(xOffset + xIndex, yOffset + yIndex)) - mean_val;
sum += temp * temp;
}
}
shared_mem[local_id] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
// Same two-step reduction as in pass 1.
sum = 0.0f;
if (local_id < 32) {
for (int i = local_id + 32; i < local_total_size; i += 32) {
sum += shared_mem[i];
}
}
shared_mem[local_id] += sum;
barrier(CLK_LOCAL_MEM_FENCE);
sum = 0.0f;
if (local_id == 0) {
int top = min(32, local_total_size);
for (int i = 0; i < top; i += 1) {
sum += shared_mem[i];
}
shared_mem[0] = sum / (in_width * in_height);
}
barrier(CLK_LOCAL_MEM_FENCE);
// Slot 0 now holds the variance.
const float4 sigma = sqrt(shared_mem[0] + (float4)(epsilon));
float4 s = 1 / sigma;
// ---- Pass 3: normalise and write ----
for (int xIndex = w; xIndex < in_width; xIndex += local_work_size_x) {
for (int yIndex = h; yIndex < in_height; yIndex += local_work_size_y) {
int2 intout_pos = (int2)(xOffset + xIndex, yOffset + yIndex);
float4 in_val = read_imagef(input, sampler, intout_pos);
half4 out_val = convert_half4((in_val - mean_val) * s);
#ifdef RELU
out_val = activation(out_val);
#endif
write_imageh(output, intout_pos, out_val);
}
}
}
// paddle instancenorm by zhangxi
__kernel void instance_norm_paddle(__read_only image2d_t input,
__write_only image2d_t output,
__read_only image2d_t scale,
__read_only image2d_t bias,
const float epsilon,
const int in_h,
const int in_w){
__local CL_DTYPE4 saved_mean[1024];
__local CL_DTYPE4 saved_variance[1024];
const int lid = get_local_id(0);
......
......@@ -128,6 +128,12 @@ bool CLRuntime::InitializePlatform() {
}
bool CLRuntime::InitializeDevice() {
// ===================== BASIC =====================
// CL_DEVICE_TYPE_GPU
// CL_DEVICE_NAME
// CL_DEVICE_SUPPORT
// CL_DEVICE_MAX_COMPUTE_UNITS
// CL_DEVICE_MAX_CLOCK_FREQUENCY
std::vector<cl::Device> all_devices;
status_ = platform_->getDevices(CL_DEVICE_TYPE_GPU, &all_devices);
CL_CHECK_ERROR(status_);
......@@ -140,27 +146,153 @@ bool CLRuntime::InitializeDevice() {
auto device_name = device_->getInfo<CL_DEVICE_NAME>();
LOG(INFO) << "Using device: " << device_name;
cl_device_type device_type = device_->getInfo<CL_DEVICE_TYPE>();
// Turns an OpenCL device-type value into a short label for the log line
// below; any value without its own case yields "Unknown".
auto device_type_to_str = [](cl_device_type t) -> std::string {
  switch (t) {
    case CL_DEVICE_TYPE_CPU:
      return "CPU";
    case CL_DEVICE_TYPE_GPU:
      return "GPU";
    case CL_DEVICE_TYPE_ACCELERATOR:
      return "Accelerator";
    case CL_DEVICE_TYPE_DEFAULT:
      return "Default";
    default:
      return "Unknown";
  }
};
LOG(INFO) << "device_type:" << device_type_to_str(device_type);
device_info_["CL_DEVICE_TYPE"] = device_type;
auto max_units = device_->getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
LOG(INFO) << "The chosen device has " << max_units << " compute units.";
device_info_["CL_DEVICE_MAX_COMPUTE_UNITS"] = max_units;
auto max_clock_freq = device_->getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>();
LOG(INFO) << "CL_DEVICE_MAX_CLOCK_FREQUENCY:" << max_clock_freq;
device_info_["CL_DEVICE_MAX_CLOCK_FREQUENCY"] = max_clock_freq;
// ===================== MEMORY =====================
// CL_DEVICE_LOCAL_MEM_SIZE
// CL_DEVICE_GLOBAL_MEM_CACHE_SIZE
// CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE
// CL_DEVICE_GLOBAL_MEM_SIZE
auto local_mem_kb =
static_cast<float>(device_->getInfo<CL_DEVICE_LOCAL_MEM_SIZE>()) / 1024;
LOG(INFO) << "The local memory size of the chosen device is " << local_mem_kb
<< " KB.";
device_info_["CL_DEVICE_LOCAL_MEM_SIZE_KB"] = local_mem_kb;
auto global_mem_cache_size_kb =
static_cast<float>(device_->getInfo<CL_DEVICE_GLOBAL_MEM_CACHE_SIZE>()) /
1024;
LOG(INFO) << "CL_DEVICE_GLOBAL_MEM_CACHE_SIZE(KB):"
<< global_mem_cache_size_kb << " KB.";
device_info_["CL_DEVICE_GLOBAL_MEM_CACHE_SIZE_KB"] = global_mem_cache_size_kb;
auto global_mem_cacheline_size_kb =
static_cast<float>(
device_->getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>()) /
1024;
LOG(INFO) << "CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE(KB):"
<< global_mem_cacheline_size_kb << " KB.";
device_info_["CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE_KB"] =
global_mem_cacheline_size_kb;
auto global_mem_size_kb =
static_cast<float>(device_->getInfo<CL_DEVICE_GLOBAL_MEM_SIZE>()) / 1024;
LOG(INFO) << "CL_DEVICE_GLOBAL_MEM_SIZE(KB):" << global_mem_size_kb << " KB.";
device_info_["CL_DEVICE_GLOBAL_MEM_SIZE_KB"] = global_mem_size_kb;
// ===================== WORK_GROUP =====================
// CL_DEVICE_MAX_WORK_GROUP_SIZE
// CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS
// CL_DEVICE_MAX_WORK_ITEM_SIZES
auto max_work_group_size = device_->getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
LOG(INFO) << "CL_DEVICE_MAX_WORK_GROUP_SIZE:" << max_work_group_size;
device_info_["CL_DEVICE_MAX_WORK_GROUP_SIZE"] = max_work_group_size;
auto max_dims_num = device_->getInfo<CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS>();
LOG(INFO) << "CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS:" << max_dims_num;
device_info_["CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS"] = max_dims_num;
auto max_work_item_sizes = device_->getInfo<CL_DEVICE_MAX_WORK_ITEM_SIZES>();
for (size_t i = 0; i < max_work_item_sizes.size(); ++i) {
LOG(INFO) << "max_work_item_sizes[" << i << "]:" << max_work_item_sizes[i];
std::string dim_key = "CL_DEVICE_MAX_WORK_ITEM_SIZES_" + std::to_string(i);
device_info_[dim_key] = max_work_item_sizes[i];
}
// ===================== BUFFER =====================
// CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE
auto max_constant_buffer_size_kb =
static_cast<float>(
device_->getInfo<CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE>()) /
1024;
LOG(INFO) << "CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE:"
<< max_constant_buffer_size_kb;
device_info_["CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE"] =
max_constant_buffer_size_kb;
// ===================== IMAGE =====================
// CL_DEVICE_IMAGE_SUPPORT
// CL_DEVICE_IMAGE2D_MAX_HEIGHT
// CL_DEVICE_IMAGE2D_MAX_WIDTH
auto image_support = device_->getInfo<CL_DEVICE_IMAGE_SUPPORT>();
if (image_support) {
LOG(INFO) << "The chosen device supports image processing.";
device_info_["CL_DEVICE_IMAGE_SUPPORT"] = 1;
} else {
LOG(INFO) << "The chosen device doesn't support image processing!";
device_info_["CL_DEVICE_IMAGE_SUPPORT"] = 0;
return false;
}
auto image2d_max_height = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
LOG(INFO) << "CL_DEVICE_IMAGE2D_MAX_HEIGHT:" << image2d_max_height;
device_info_["CL_DEVICE_IMAGE2D_MAX_HEIGHT"] = image2d_max_height;
auto image2d_max_width = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
LOG(INFO) << "CL_DEVICE_IMAGE2D_MAX_WIDTH:" << image2d_max_width;
device_info_["CL_DEVICE_IMAGE2D_MAX_WIDTH"] = image2d_max_width;
// ===================== OTHERS / EXTENSION / VERSION =====================
// CL_DEVICE_EXTENSIONS
// CL_DEVICE_ADDRESS_BITS
auto ext_data = device_->getInfo<CL_DEVICE_EXTENSIONS>();
VLOG(4) << "The extensions supported by this device: " << ext_data;
if (ext_data.find("cl_khr_fp16") != std::string::npos) {
LOG(INFO) << "The chosen device supports the half data type.";
device_info_["CL_DEVICE_EXTENSIONS_FP16"] = 1;
} else {
LOG(INFO) << "The chosen device doesn't support the half data type!";
device_info_["CL_DEVICE_EXTENSIONS_FP16"] = 0;
}
auto max_units = device_->getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
LOG(INFO) << "The chosen device has " << max_units << " compute units.";
auto local_mem = device_->getInfo<CL_DEVICE_LOCAL_MEM_SIZE>();
LOG(INFO) << "The local memory size of the chosen device is "
<< static_cast<float>(local_mem) / 1024 << " KB.";
auto address_bits = device_->getInfo<CL_DEVICE_ADDRESS_BITS>();
LOG(INFO) << "CL_DEVICE_ADDRESS_BITS:" << address_bits;
device_info_["CL_DEVICE_ADDRESS_BITS"] = address_bits;
auto driver_version = device_->getInfo<CL_DRIVER_VERSION>();
LOG(INFO) << "CL_DRIVER_VERSION:" << driver_version;
return true;
}
// Returns the table of device properties, filling it on first use.
//
// The table is populated by InitializeDevice(); when it is still empty that
// function is run once here. Its boolean result is not inspected, so the
// returned map may be only partly filled if initialization bailed out early.
std::map<std::string, size_t>& CLRuntime::GetDeviceInfo() {
  if (device_info_.empty()) {
    InitializeDevice();
  }
  return device_info_;
}
} // namespace lite
} // namespace paddle
......@@ -55,6 +55,8 @@ class CLRuntime {
void set_cl_path(std::string cl_path) { cl_path_ = cl_path; }
std::map<std::string, size_t>& GetDeviceInfo();
private:
CLRuntime() = default;
......@@ -84,6 +86,8 @@ class CLRuntime {
return queue;
}
std::map<std::string, size_t> device_info_;
std::string cl_path_;
std::shared_ptr<cl::Platform> platform_{nullptr};
......
......@@ -51,7 +51,7 @@ void* TargetMalloc(TargetType target, size_t size) {
return data;
}
void TargetFree(TargetType target, void* data) {
void TargetFree(TargetType target, void* data, std::string free_flag) {
switch (target) {
case TargetType::kHost:
case TargetType::kX86:
......@@ -66,7 +66,11 @@ void TargetFree(TargetType target, void* data) {
#endif // LITE_WITH_CUDA
#ifdef LITE_WITH_OPENCL
case TargetType::kOpenCL:
TargetWrapperCL::Free(data);
if (free_flag == "cl_use_image2d_") {
TargetWrapperCL::FreeImage(data);
} else {
TargetWrapperCL::Free(data);
}
break;
#endif // LITE_WITH_OPENCL
#ifdef LITE_WITH_FPGA
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <string>
#include "lite/api/paddle_place.h"
#include "lite/core/target_wrapper.h"
#include "lite/utils/logging.h"
......@@ -39,7 +40,9 @@ LITE_API void* TargetMalloc(TargetType target, size_t size);
// Free memory for a specific Target. All the targets should be an element in
// the `switch` here.
void LITE_API TargetFree(TargetType target, void* data);
void LITE_API TargetFree(TargetType target,
void* data,
std::string free_flag = "");
// Copy a buffer from host to another target.
void TargetCopy(TargetType target, void* dst, const void* src, size_t size);
......@@ -108,6 +111,9 @@ class Buffer {
data_ = TargetMalloc(target, size);
target_ = target;
space_ = size;
#ifdef LITE_WITH_OPENCL
cl_use_image2d_ = false;
#endif
}
}
......@@ -119,15 +125,15 @@ class Buffer {
const size_t img_w,
const size_t img_h,
void* host_ptr = nullptr) {
size_t size = sizeof(T) * img_w * img_h *
4; // 4 for RGBA, un-used for opencl Image2D
if (target != target_ || cl_image2d_width_ < img_w ||
cl_image2d_height_ < img_h) {
CHECK_EQ(own_data_, true) << "Can not reset unowned buffer.";
Free();
data_ = TargetWrapperCL::MallocImage<T>(img_w, img_h, host_ptr);
target_ = target;
space_ = size; // un-used for opencl Image2D
space_ = sizeof(T) * img_w * img_h *
4; // un-used for opencl Image2D, 4 for RGBA,
cl_use_image2d_ = true;
cl_image2d_width_ = img_w;
cl_image2d_height_ = img_h;
}
......@@ -136,7 +142,11 @@ class Buffer {
void Free() {
if (space_ > 0 && own_data_) {
TargetFree(target_, data_);
if (!cl_use_image2d_) {
TargetFree(target_, data_);
} else {
TargetFree(target_, data_, "cl_use_image2d_");
}
}
data_ = nullptr;
target_ = TargetType::kHost;
......@@ -155,6 +165,7 @@ class Buffer {
private:
// memory it actually malloced.
size_t space_{0};
bool cl_use_image2d_{false}; // only used for OpenCL Image2D
size_t cl_image2d_width_{0}; // only used for OpenCL Image2D
size_t cl_image2d_height_{0}; // only used for OpenCL Image2D
void* data_{nullptr};
......
......@@ -76,8 +76,8 @@ void ElementwiseAddGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_grad_data = param.OutGrad->data<float>();
float* x_grad_data;
float* y_grad_data;
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
......@@ -122,8 +122,8 @@ void ElementwiseSubGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_data = param.OutGrad->data<float>();
float* x_grad_data;
float* y_grad_data;
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
......@@ -137,9 +137,15 @@ void ElementwiseSubGradCompute::Run() {
if (!param.XGrad || !param.YGrad) {
CHECK(param.XGrad || param.YGrad);
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, y_dims.production());
return;
if (param.XGrad) {
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, x_dims.production());
return;
} else {
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, y_dims.production());
return;
}
}
if (x_dims.size() < y_dims.size()) {
......
......@@ -67,8 +67,8 @@ lite_cc_test(test_reshape_image_opencl SRCS reshape_image_compute_test.cc
lite_cc_test(test_concat_image_opencl SRCS concat_image_compute_test.cc
DEPS concat_opencl layout_opencl op_registry program context)
lite_cc_test(test_elementwise_mul_image_opencl SRCS elementwise_mul_image_compute_test.cc
DEPS elementwise_mul_opencl op_registry program context)
#lite_cc_test(test_elementwise_mul_image_opencl SRCS elementwise_mul_image_compute_test.cc
# DEPS elementwise_mul_opencl op_registry program context)
lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc
DEPS layout_opencl op_registry program context)
......@@ -89,8 +89,8 @@ lite_cc_test(test_bilinear_interp_image_opencl SRCS bilinear_interp_image_comput
lite_cc_test(test_slice_image_opencl SRCS slice_image_compute_test.cc
DEPS slice_opencl op_registry program context)
lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
DEPS instance_norm_opencl op_registry program context)
#lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
# DEPS instance_norm_opencl op_registry program context)
lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc
DEPS dropout_opencl op_registry program context)
......
......@@ -14,8 +14,8 @@
#include "lite/kernels/opencl/conv_image_compute.h"
#include <iomanip>
#include <sstream>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/op_registry.h"
......@@ -78,9 +78,27 @@ void ConvImageCompute::PrepareForRun() {
VLOG(3) << "dilation_equal:" << dilation_equal;
VLOG(3) << "padding :" << paddings[0] << " " << paddings[1] << " "
<< paddings[2] << " " << paddings[3];
CHECK(pad_equal && stride_equal && dilation_equal);
// general gws..
auto out_image_shape = InitImageDimInfoWith(output_dims);
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
default_c_blk_ = default_work_size[0];
default_w_blk_ = default_work_size[1];
default_nh_blk_ = default_work_size[2];
c_blk_ = default_c_blk_;
w_blk_ = default_w_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
if (kernel_h == 1 && kernel_w == 1) {
// conv2d_1x1
if (param.x->dims()[1] % 4 == 0) {
......@@ -99,6 +117,15 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d1x1opt;
{
// calc 1x1 gws
w_blk_ = maptofactor(default_w_blk_, 4);
c_blk_ = default_c_blk_;
nh_blk_ = default_nh_blk_;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#define DEPTH_CONV_USE_SPL
#ifdef DEPTH_CONV_USE_SPL
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
......@@ -107,9 +134,38 @@ void ConvImageCompute::PrepareForRun() {
if (stride_h == 1 && dilations[0] == 1) {
kernel_func_names_.push_back("depth_conv2d_3x3s1");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1;
{
// depthwise spl gws s1
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
c_blk_ = c_block;
w_blk_ = w_blk;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
} else {
kernel_func_names_.push_back("depth_conv2d_3x3");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3;
{
// depthwise spl gws
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
c_blk_ = c_block;
w_blk_ = w;
nh_blk_ = nh;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
}
kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl");
......@@ -157,6 +213,22 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d3x3opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
} else if (kernel_h == 5 && kernel_w == 5) {
#define CONV_5x5_OPT
#ifndef CONV_5x5_OPT
......@@ -189,6 +261,21 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d5x5opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#endif
#undef CONV_5x5_OPT
} else if (kernel_h == 7 && kernel_w == 7) {
......@@ -223,6 +310,21 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d7x7opt;
{
int w_blk_size = 5;
int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size;
int h_blk_size = 1;
int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size;
c_blk_ = default_c_blk_;
w_blk_ = w_blk;
nh_blk_ = h_blk;
global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
static_cast<size_t>(w_blk_),
static_cast<size_t>(nh_blk_)};
}
#endif
#undef CONV_7x7_OPT
......@@ -270,9 +372,36 @@ void ConvImageCompute::PrepareForRun() {
context.cl_context()->AddKernel(
kernel_func_names_[i], kernel_func_paths_[i], build_options_[i]);
}
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
std::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
size_t max_work_group_size = 0;
kernel_.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
VLOG(4) << "max_work_group_size: " << max_work_group_size;
if (max_work_group_size > 0 && use_lws) {
// local_work_size_ = context.cl_context()->LocalWorkSizeConv1x1(
// global_work_size_, max_work_group_size);
local_work_size_ = context.cl_context()->LocalWorkSize(global_work_size_,
max_work_group_size);
VLOG(4) << "local_work_size_[3D]: {" << local_work_size_[0] << ","
<< local_work_size_[1] << "," << local_work_size_[2] << "}";
}
}
void ConvImageCompute::Conv2d1x1opt() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -302,16 +431,28 @@ void ConvImageCompute::Conv2d1x1opt() {
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
// const std::vector<size_t>& default_work_size =
// DefaultWorkSize(output_dims,
// DDim(std::vector<DDim::value_type>{
// static_cast<int64_t>(out_image_shape["width"]),
// static_cast<int64_t>(out_image_shape["height"])}));
// int c_block = default_work_size[0];
// int w = default_work_size[1];
// int nh = default_work_size[2];
// int maped_w = maptofactor(w, 4);
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
// auto global_work_size_ =
// cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
// static_cast<size_t>(maped_w),
// static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d_1x1 params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -331,9 +472,9 @@ void ConvImageCompute::Conv2d1x1opt() {
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
// VLOG(4) << "default work size{c_block, w, nh}: "
// << "{" << c_block << ", " << w << ", " << nh << ""
// << "}";
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
......@@ -350,27 +491,14 @@ void ConvImageCompute::Conv2d1x1opt() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
std::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int maped_w = maptofactor(w, 4);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "maped_w: " << maped_w;
VLOG(4) << "hasbias: " << has_bias;
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, maped_w);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -401,49 +529,87 @@ void ConvImageCompute::Conv2d1x1opt() {
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
status = kernel.setArg(++arg_idx, default_w_blk_);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(maped_w),
static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
#endif
size_t max_work_group_size = 0;
kernel.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
cl::NDRange local_work_size = cl::NullRange;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "max_work_group_size: " << max_work_group_size;
#endif
if (max_work_group_size > 0 && use_lws) {
local_work_size = context.cl_context()->LocalWorkSize(global_work_size,
max_work_group_size);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "local_work_size[3D]: {" << local_work_size[0] << ","
<< local_work_size[1] << "," << local_work_size[2] << "}";
#endif
}
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
local_work_size,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
#ifdef PROFILE_CONV_KERNEL
bool use_profile = false;
auto GetCurrentUS = []() -> double {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
};
double start = GetCurrentUS();
if (use_profile) {
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
} else {
int count = 50;
double sumtime = 0;
if (!use_profile) {
count = 1;
}
for (size_t i = 0; i < count; i++) {
start = GetCurrentUS();
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
if (use_profile) {
event_->wait();
double duration = GetCurrentUS() - start;
sumtime += duration;
}
}
auto dims_string = [](DDimLite dims) -> std::string {
std::ostringstream stream;
stream << "[" << dims[0] << "," << dims[1] << "," << dims[2] << ","
<< dims[3] << "]";
return stream.str();
};
if (use_profile) {
// LOG(INFO) << "input: " << input_dims;
// LOG(INFO) << "filter: " << filter_dims;
// LOG(INFO) << "output: " << output_dims;
std::cout << std::setw(25) << std::left << dims_string(input_dims)
<< std::setw(25) << std::left << dims_string(filter_dims)
<< std::setw(25) << std::left << dims_string(output_dims)
<< std::setw(25) << std::left << sumtime / count << std::endl;
} else {
dims_string(input_dims);
}
}
#endif
}
void ConvImageCompute::Conv2d3x3() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -486,24 +652,14 @@ void ConvImageCompute::Conv2d3x3() {
} else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) {
new_groups = input_channel / filter_channel;
}
/* TODO(ysh329): mobile has no case below
else {
LOG(FATAL) << "Not support conv3x3 case with"
<< " input_dims:" << input_dims << " output_dims:" <<
output_dims
<< " filter_dims:" << filter_dims;
}
*/
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
/* TODO(ysh329): mobile has no case below
else {
LOG(FATAL) << "Not support conv3x3 case with"
<< " input_dims:" << input_dims << " output_dims:" <<
output_dims
<< " filter_dims:" << filter_dims;
}
*/
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d params ============";
......@@ -527,9 +683,9 @@ void ConvImageCompute::Conv2d3x3() {
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "param.groups(groups):" << param.groups;
VLOG(4) << "new_groups:" << new_groups;
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
// VLOG(4) << "default work size{c_block, w, nh}: "
// << "{" << c_block << ", " << w << ", " << nh << ""
// << "}";
#endif
CHECK_GE(dilations.size(), 2);
......@@ -544,26 +700,15 @@ void ConvImageCompute::Conv2d3x3() {
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -607,21 +752,16 @@ void ConvImageCompute::Conv2d3x3() {
status = kernel.setArg(++arg_idx, new_groups);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
global_work_size_,
cl::NullRange,
nullptr,
event_.get());
......@@ -630,6 +770,8 @@ void ConvImageCompute::Conv2d3x3() {
}
void ConvImageCompute::Conv2d3x3opt() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -657,24 +799,6 @@ void ConvImageCompute::Conv2d3x3opt() {
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
int w_blk_size = 5;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
// default_work_size[1] = w_blk;
int h_blk_size = 1;
int h_blk = (nh + h_blk_size - 1) / h_blk_size;
// default_work_size[2] = h_blk;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -692,9 +816,6 @@ void ConvImageCompute::Conv2d3x3opt() {
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
......@@ -710,24 +831,15 @@ void ConvImageCompute::Conv2d3x3opt() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, h_blk);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -763,38 +875,17 @@ void ConvImageCompute::Conv2d3x3opt() {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(w_blk),
static_cast<size_t>(h_blk)};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
#endif
size_t max_work_group_size = 0;
kernel.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
cl::NDRange local_work_size = cl::NullRange;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "max_work_group_size: " << max_work_group_size;
#endif
if (max_work_group_size > 0 && use_lws) {
local_work_size = context.cl_context()->LocalWorkSize(global_work_size,
max_work_group_size);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "local_work_size[3D]: {" << local_work_size[0] << ","
<< local_work_size[1] << "," << local_work_size[2] << "}";
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
}
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
local_work_size,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
......@@ -802,6 +893,8 @@ void ConvImageCompute::Conv2d3x3opt() {
}
void ConvImageCompute::Conv2d5x5() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -833,16 +926,6 @@ void ConvImageCompute::Conv2d5x5() {
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -863,9 +946,6 @@ void ConvImageCompute::Conv2d5x5() {
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
......@@ -881,25 +961,15 @@ void ConvImageCompute::Conv2d5x5() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -933,21 +1003,16 @@ void ConvImageCompute::Conv2d5x5() {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
global_work_size_,
cl::NullRange,
nullptr,
event_.get());
......@@ -956,6 +1021,8 @@ void ConvImageCompute::Conv2d5x5() {
}
void ConvImageCompute::Conv2d5x5opt() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -984,22 +1051,6 @@ void ConvImageCompute::Conv2d5x5opt() {
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
int w_blk_size = 5;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
// default_work_size[1] = w_blk;
int h_blk_size = 1;
int h_blk = (nh + h_blk_size - 1) / h_blk_size;
// default_work_size[2] = h_blk;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d params ============";
......@@ -1018,9 +1069,6 @@ void ConvImageCompute::Conv2d5x5opt() {
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
......@@ -1035,22 +1083,14 @@ void ConvImageCompute::Conv2d5x5opt() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, h_blk);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -1083,38 +1123,13 @@ void ConvImageCompute::Conv2d5x5opt() {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(w_blk),
static_cast<size_t>(h_blk)};
// VLOG(4) << "out_image: " << out_image;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
#endif
size_t max_work_group_size = 0;
kernel.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
cl::NDRange local_work_size = cl::NullRange;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "max_work_group_size: " << max_work_group_size;
#endif
if (max_work_group_size > 0 && use_lws) {
local_work_size = context.cl_context()->LocalWorkSize(global_work_size,
max_work_group_size);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "local_work_size[3D]: {" << local_work_size[0] << ","
<< local_work_size[1] << "," << local_work_size[2] << "}";
#endif
}
// VLOG(4) << "out_image: " << out_image;
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
local_work_size,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
......@@ -1122,6 +1137,8 @@ void ConvImageCompute::Conv2d5x5opt() {
}
void ConvImageCompute::Conv2d7x7() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -1153,16 +1170,6 @@ void ConvImageCompute::Conv2d7x7() {
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -1183,9 +1190,6 @@ void ConvImageCompute::Conv2d7x7() {
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
......@@ -1201,25 +1205,15 @@ void ConvImageCompute::Conv2d7x7() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -1253,21 +1247,16 @@ void ConvImageCompute::Conv2d7x7() {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
global_work_size_,
cl::NullRange,
nullptr,
event_.get());
......@@ -1275,6 +1264,8 @@ void ConvImageCompute::Conv2d7x7() {
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d7x7opt() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -1302,23 +1293,6 @@ void ConvImageCompute::Conv2d7x7opt() {
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
int w_blk_size = 5;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
// default_work_size[1] = w_blk;
int h_blk_size = 1;
int h_blk = (nh + h_blk_size - 1) / h_blk_size;
// default_work_size[2] = h_blk;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ conv2d 7x7 params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -1336,9 +1310,6 @@ void ConvImageCompute::Conv2d7x7opt() {
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
......@@ -1353,24 +1324,15 @@ void ConvImageCompute::Conv2d7x7opt() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, h_blk);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -1403,39 +1365,19 @@ void ConvImageCompute::Conv2d7x7opt() {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(w_blk),
static_cast<size_t>(h_blk)};
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
#endif
size_t max_work_group_size = 0;
kernel.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
cl::NDRange local_work_size = cl::NullRange;
if (max_work_group_size > 0 && use_lws) {
local_work_size = context.cl_context()->LocalWorkSize(global_work_size,
max_work_group_size);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "local_work_size[3D]: {" << local_work_size[0] << ","
<< local_work_size[1] << "," << local_work_size[2] << "}";
#endif
}
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
local_work_size,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::DepthwiseConv2d3x3s1() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
......@@ -1444,8 +1386,6 @@ void ConvImageCompute::DepthwiseConv2d3x3s1() {
auto strides = param.strides;
auto dilations = *param.dilations;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<half_t, cl::Image2D>();
......@@ -1459,26 +1399,15 @@ void ConvImageCompute::DepthwiseConv2d3x3s1() {
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
auto global_work_size = cl::NDRange(c_block, w_blk, nh);
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w_blk));
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
......@@ -1516,28 +1445,11 @@ void ConvImageCompute::DepthwiseConv2d3x3s1() {
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
size_t max_work_group_size = 0;
kernel.getWorkGroupInfo<size_t>(CLRuntime::Global()->device(),
CL_KERNEL_WORK_GROUP_SIZE,
&max_work_group_size);
cl::NDRange local_work_size = cl::NullRange;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "max_work_group_size: " << max_work_group_size;
#endif
if (max_work_group_size > 0 && use_lws) {
local_work_size = context.cl_context()->LocalWorkSize(global_work_size,
max_work_group_size);
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "local_work_size[3D]: {" << local_work_size[0] << ","
<< local_work_size[1] << "," << local_work_size[2] << "}";
#endif
}
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
local_work_size,
global_work_size_,
local_work_size_,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
......@@ -1545,6 +1457,8 @@ void ConvImageCompute::DepthwiseConv2d3x3s1() {
}
void ConvImageCompute::DepthwiseConv2d3x3() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
......@@ -1555,8 +1469,6 @@ void ConvImageCompute::DepthwiseConv2d3x3() {
int offset = filter_dims[2] / 2 - paddings[0];
int input_c_block = (x_dims[1] + 3) / 4;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<half_t, cl::Image2D>();
......@@ -1570,21 +1482,10 @@ void ConvImageCompute::DepthwiseConv2d3x3() {
auto* output_img = param.output->mutable_data<half_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
auto global_work_size = cl::NDRange(c_block, w, nh);
auto kernel = kernel_;
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "setArg";
VLOG(4) << "c_block = " << c_block;
VLOG(4) << "w = " << w;
VLOG(4) << "nh = " << nh;
VLOG(4) << "strides = " << strides[0];
VLOG(4) << "offset = " << offset;
VLOG(4) << "dilations = " << dilations[0];
......@@ -1597,11 +1498,11 @@ void ConvImageCompute::DepthwiseConv2d3x3() {
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w));
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
......@@ -1641,7 +1542,7 @@ void ConvImageCompute::DepthwiseConv2d3x3() {
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
global_work_size_,
cl::NullRange,
nullptr,
event_.get());
......@@ -1650,6 +1551,8 @@ void ConvImageCompute::DepthwiseConv2d3x3() {
}
void ConvImageCompute::DepthwiseConv2d() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
......@@ -1681,16 +1584,6 @@ void ConvImageCompute::DepthwiseConv2d() {
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
......@@ -1710,9 +1603,6 @@ void ConvImageCompute::DepthwiseConv2d() {
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
#endif
CHECK_GE(dilations.size(), 2);
......@@ -1730,25 +1620,15 @@ void ConvImageCompute::DepthwiseConv2d() {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
#endif
auto kernel = kernel_;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
status = kernel.setArg(arg_idx, c_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
status = kernel.setArg(++arg_idx, w_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
status = kernel.setArg(++arg_idx, nh_blk_);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
......@@ -1786,21 +1666,16 @@ void ConvImageCompute::DepthwiseConv2d() {
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
#ifndef LITE_SHUTDOWN_LOG
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << ","
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
global_work_size_,
cl::NullRange,
nullptr,
event_.get());
......@@ -1809,7 +1684,7 @@ void ConvImageCompute::DepthwiseConv2d() {
}
// Runs the convolution by invoking, on this object, the member function that
// impl_ points to. NOTE(review): impl_ is presumably assigned during kernel
// preparation to one of the Conv2d*/DepthwiseConv2d* methods -- the assignment
// is not visible in this part of the file, confirm there.
void ConvImageCompute::Run() {
  (this->*impl_)();
}
#undef PROFILE_CONV_KERNEL
} // namespace opencl
} // namespace kernels
} // namespace lite
......
......@@ -59,6 +59,19 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
// Event handed to enqueueNDRangeKernel (as event_.get()) and then recorded in
// the context's wait list together with the output image.
std::shared_ptr<cl::Event> event_{new cl::Event};
// Filter data; the conv implementations read it as a half_t cl::Image2D.
Tensor filter_gpu_image_;
// Bias data; read as a half_t cl::Image2D when the op has a bias.
Tensor bias_gpu_image_;
// Global NDRange passed to enqueueNDRangeKernel by the conv implementations.
// Starts as {1, 1, 1}; NOTE(review): the real value is presumably computed
// once during kernel preparation -- that code is not visible here.
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
// Passed as the first three kernel arguments (indices 0, 1, 2) by the conv
// implementations. NOTE(review): presumably the channel-block, width-block
// and batch*height-block counts, going by the names -- confirm where set.
int c_blk_ = 1;
int w_blk_ = 1;
int nh_blk_ = 1;
// default_w_blk_ is passed as a trailing kernel argument in Conv2d1x1opt.
// NOTE(review): these three look like the un-tiled default work sizes
// (before per-kernel blocking is applied) -- confirm against the code that
// assigns them; default_c_blk_ / default_nh_blk_ are not used in the visible
// part of the file.
int default_c_blk_ = 1;
int default_w_blk_ = 1;
int default_nh_blk_ = 1;
// Kernel object; each conv implementation copies it into a local (`auto kernel
// = kernel_`) before setting arguments and enqueueing.
cl::Kernel kernel_;
// Local NDRange passed to enqueueNDRangeKernel by the *opt and
// DepthwiseConv2d3x3s1 implementations; the other implementations pass
// cl::NullRange instead. Starts as {1, 1, 1}.
cl::NDRange local_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
// NOTE(review): in the visible code this flag gates whether a local work size
// is derived via LocalWorkSize(global, max_work_group_size); confirm it is
// still consulted wherever local_work_size_ is now computed.
bool use_lws{true};
};
......
......@@ -44,28 +44,31 @@ class ElementwiseMulImageCompute
ele_param_ = param_.get_mutable<param_t>();
auto* y = ele_param_->Y;
auto* x = ele_param_->X;
auto y_dims = y->dims();
auto bias_dims = y->dims();
auto x_dims = x->dims();
if (y_dims == x_dims) {
if (bias_dims == x_dims) {
kernel_func_name_ = "elementwise_mul";
} else if (y_dims.size() == 1) {
kernel_func_name_ = "channel_mul_d1";
} else if (y_dims.size() == 2) {
if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) {
kernel_func_name_ = "channel_mul_d2_nc";
} else {
const int bias_dim_size = bias_dims.size();
if (bias_dim_size == 1) {
kernel_func_name_ = "channel_mul_d1";
} else if (bias_dim_size == 2) {
kernel_func_name_ = "channel_mul_d2";
} else if (bias_dim_size == 3) {
kernel_func_name_ = "channel_mul_d3";
} else if (bias_dim_size == 4) {
kernel_func_name_ = "channel_mul_d4";
} else {
kernel_func_name_ = "channel_mul_d2_hw";
LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
<< " y_dims:" << bias_dims;
}
} else if (y_dims.size() == 4 || x_dims.size() == 4) {
kernel_func_name_ = "channel_mul_d4";
} else {
LOG(FATAL) << "ElementwiseMul not supported y_dims.size():"
<< y_dims.size()
<< ", x_dims.size():" << ele_param_->X->dims().size();
}
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
VLOG(4) << "y_dims:" << y_dims;
VLOG(4) << "y_dims.size():" << y_dims.size();
VLOG(4) << "x_dims:" << x_dims;
VLOG(4) << "bias_dims:" << bias_dims;
VLOG(4) << "bias_dims.size():" << bias_dims.size();
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
......@@ -114,79 +117,67 @@ class ElementwiseMulImageCompute
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
auto y_dims = y->dims();
auto bias_dims = y->dims();
auto x_dims = x->dims();
if (y_dims == x_dims) {
// kernel: elementwise_mul(channel_mul_d4)
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 1 || y_dims.size() == 4) {
auto tensor_w = x_dims[x_dims.size() - 1];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w;
#endif
// kernel: channel_mul_d1 / channel_mul_d4
cl_int status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
if (bias_dims == x_dims) {
// kernel_func_name_ = "elementwise_mul";
cl_int status = kernel.setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
status = kernel.setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
status = kernel.setArg(2, *out_img);
CL_CHECK_FATAL(status);
} else if (y_dims.size() == 2) {
if (x_dims[0] == y_dims[0] && x_dims[1] == y_dims[1]) {
auto tensor_w = x_dims[x_dims.size() - 1];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "tensor_w:" << tensor_w;
#endif
// kernel: channel_mul_d2_nc
cl_int status = kernel.setArg(arg_idx, *x_img);
} else {
const int bias_dim_size = bias_dims.size();
if (bias_dim_size == 1) {
// kernel_func_name_ = "channel_mul_d1";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
status = kernel.setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
status = kernel.setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
status = kernel.setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else {
auto y_tensor_h = y->dims()[0];
auto y_tensor_w = y->dims()[1];
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "y_tensor_w:" << y_tensor_w << " y_tensor_h:" << y_tensor_h;
#endif
// kernel: channel_mul_d2_hw
cl_int status = kernel.setArg(arg_idx, *x_img);
} else if (bias_dim_size == 2) {
// kernel_func_name_ = "channel_mul_d2";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else if (bias_dim_size == 3) {
// kernel_func_name_ = "channel_mul_d3";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *y_img);
status = kernel.setArg(3, tensor_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
} else if (bias_dim_size == 4) {
// kernel_func_name_ = "channel_mul_d4";
const int tensor_w = x_dims[x_dims.size() - 1];
cl_int status = kernel.setArg(0, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_w));
status = kernel.setArg(1, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(y_tensor_h));
status = kernel.setArg(2, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(3, tensor_w);
CL_CHECK_FATAL(status);
} else {
LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
<< " y_dims:" << bias_dims;
}
} else if (x_dims.size() == 4) {
auto tensor_w = y_dims[y_dims.size() - 1];
VLOG(4) << "tensor_w:" << tensor_w;
// kernel: channel_mul_d4
cl_int status = kernel.setArg(arg_idx, *y_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
CL_CHECK_FATAL(status);
} else {
LOG(FATAL) << "ElementwiseMul not supported y_dims.size():"
<< y_dims.size();
}
auto global_work_size =
......
......@@ -38,6 +38,115 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
return "InstanceNorm using cl::Image2D(ImageDefault/RGBA), kFP16";
}
#if 1 // onnx/pytorch version
// Prepares the instance-norm kernel (onnx/pytorch variant): derives the build
// options from the output tensor and registers the kernel with the OpenCL
// context.
//
// Only out_dims[2] (named out_h here) influences the build options:
//   out_h == 128 -> appends " -DLOCAL_MEM_128"
//   out_h == 64  -> appends " -DLOCAL_MEM_64"
//   out_h  > 256 -> LOG(FATAL)
// Any other height keeps the default build options.
void PrepareForRun() override {
  instance_norm_param_ = param_.get_mutable<param_t>();
  auto out = instance_norm_param_->out;
  const auto out_dims = out->dims();

  // Run() indexes out_dims[0..3]; reject any other rank up front instead of
  // reading past the end of the dims.
  if (out_dims.size() != 4) {
    LOG(FATAL) << "Unsupported output dims size:" << out_dims.size()
               << " of instance norm";
  }
  const int out_h = out_dims[2];

  // TODO(ysh329): add instance_norm + relu pass
  // std::string build_options_ += "-DRELU";
  if (out_h == 128) {
    build_options_ += " -DLOCAL_MEM_128";
  } else if (out_h == 64) {
    build_options_ += " -DLOCAL_MEM_64";
  } else if (out_h > 256) {
    LOG(FATAL) << "Unsupported input height:" << out_h << " of instance norm";
  }

  auto& context = ctx_->As<OpenCLContext>();
  context.cl_context()->AddKernel(
      kernel_func_name_, "image/instance_norm_kernel.cl", build_options_);
  VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
}
// Launches the instance-norm kernel (onnx/pytorch variant) on the image
// backing `x`, writing into the image backing `out`.
//
// Launch geometry:
//   global = {out_n * out_c_group, lws1, lws2}
//   local  = {1,                   lws1, lws2}
// Dimensions 1 and 2 of the global range equal the local range, so each index
// along dimension 0 is handled by exactly one work-group.
// lws1 = min(CL_DEVICE_MAX_WORK_ITEM_SIZES_1, 256, out_w); lws2 is fixed to 1.
//
// Kernel arguments, in order: out_w, out_h, out_c_group, lws1, lws2, epsilon,
// input image, output image.
//
// The enqueue is asynchronous: the event is recorded in the context wait list
// keyed by the output image.
void Run() override {
  auto& context = ctx_->As<OpenCLContext>();
  CHECK(context.cl_context() != nullptr);

  auto* x = instance_norm_param_->x;
  auto* out = instance_norm_param_->out;
  const auto out_dims = out->dims();
  const int out_n = out_dims[0];
  // Channels are packed four per group; round up.
  const int out_c_group = (out_dims[1] + 3) / 4;
  const int out_h = out_dims[2];
  const int out_w = out_dims[3];
  float epsilon = instance_norm_param_->epsilon;

  auto device_info = CLRuntime::Global()->GetDeviceInfo();
  int max_work_item_size1 = device_info["CL_DEVICE_MAX_WORK_ITEM_SIZES_1"];
  int lws0 = 1;
  int lws1 = std::min(max_work_item_size1, std::min(256, out_w));
  int lws2 = 1;

  auto global_work_size =
      cl::NDRange{static_cast<cl::size_type>(out_n * out_c_group),
                  static_cast<cl::size_type>(lws1),
                  static_cast<cl::size_type>(lws2)};
  auto local_work_size = cl::NDRange{static_cast<cl::size_type>(lws0),
                                     static_cast<cl::size_type>(lws1),
                                     static_cast<cl::size_type>(lws2)};

#ifndef LITE_SHUTDOWN_LOG
  VLOG(4) << "global_work_size:" << static_cast<int>(global_work_size[0])
          << " " << static_cast<int>(global_work_size[1]) << " "
          << static_cast<int>(global_work_size[2]);
  VLOG(4) << "local_work_size:" << static_cast<int>(local_work_size[0]) << " "
          << static_cast<int>(local_work_size[1]) << " "
          << static_cast<int>(local_work_size[2]);
  VLOG(4) << "out_w:" << out_w;
  VLOG(4) << "out_h:" << out_h;
  VLOG(4) << "out_c_group:" << out_c_group;
  VLOG(4) << "lws1:" << lws1;
  VLOG(4) << "lws2:" << lws2;
  VLOG(4) << "epsilon:" << epsilon;
#endif

  auto out_image_shape = InitImageDimInfoWith(out_dims);
  auto* x_img = x->data<half_t, cl::Image2D>();
  auto* out_img = out->mutable_data<half_t, cl::Image2D>(
      out_image_shape["width"], out_image_shape["height"]);

  // Kernels are looked up by function name concatenated with build options.
  STL::stringstream kernel_key;
  kernel_key << kernel_func_name_ << build_options_;
  auto kernel = context.cl_context()->GetKernel(kernel_key.str());

  cl_int status = kernel.setArg(0, out_w);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(1, out_h);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(2, out_c_group);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(3, lws1);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(4, lws2);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(5, epsilon);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(6, *x_img);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(7, *out_img);
  CL_CHECK_FATAL(status);

  status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
      kernel,
      cl::NullRange,
      global_work_size,
      local_work_size,
      nullptr,
      event_.get());
  CL_CHECK_FATAL(status);
  context.cl_wait_list()->emplace(out_img, event_);
}
#else // paddle version
void PrepareForRun() override {
instance_norm_param_ = param_.get_mutable<param_t>();
auto channel = instance_norm_param_->scale->dims()[0];
......@@ -79,7 +188,6 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
void Run() override {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* x = instance_norm_param_->x;
auto* out = instance_norm_param_->out;
auto in_dims = x->dims();
......@@ -131,7 +239,6 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
auto* scale_img = scale_image_.data<half_t, cl::Image2D>();
auto* bias_img = bias_image_.data<half_t, cl::Image2D>();
float epsilon = instance_norm_param_->epsilon;
int arg_idx = 0;
cl_int status = kernel.setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status);
......@@ -158,10 +265,11 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
}
#endif
protected:
param_t* instance_norm_param_{nullptr};
std::string kernel_func_name_{"instance_norm"};
std::string kernel_func_name_{"instance_norm_onnx"};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
Tensor scale_image_;
......
......@@ -360,6 +360,7 @@ function make_x86 {
-DWITH_GPU=OFF \
-DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON \
-DCMAKE_BUILD_TYPE=Release \
-DLITE_WITH_XPU=$BUID_XPU \
-DXPU_SDK_ROOT=$XPU_SDK_ROOT
......
......@@ -130,7 +130,7 @@ build_for_arm_linux() {
-B"../build/release/arm-linux" \
-DCMAKE_BUILD_TYPE="${MODE}" \
-DCMAKE_TOOLCHAIN_FILE="./tools/toolchains/arm-linux-gnueabihf.cmake" \
-DCMAKE_CXX_FLAGS="-std=c++14 -mcpu=cortex-a53 -mtune=cortex-a53 -ftree-vectorize -funsafe-math-optimizations -pipe -mlittle-endian " \
-DCMAKE_CXX_FLAGS=" " \
-DNET="${NETS}" \
-D"V7"=true
else
......@@ -138,7 +138,7 @@ build_for_arm_linux() {
-B"../build/release/arm-linux" \
-DCMAKE_BUILD_TYPE="${MODE}" \
-DCMAKE_TOOLCHAIN_FILE="./tools/toolchains/arm-linux-gnueabihf.cmake" \
-DCMAKE_CXX_FLAGS="-std=c++14 -mcpu=cortex-a53 -mtune=cortex-a53 -ftree-vectorize -funsafe-math-optimizations -pipe -mlittle-endian " \
-DCMAKE_CXX_FLAGS=" " \
-DNET="${NETS}" \
-D"V7"=true
fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册