diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 25962c3e2c78d6d197fbb8f607ff128a15af5b6e..76808d7506fe83b97892bab9eff3f6fdad6ae812 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -54,6 +54,7 @@ endif () if (SUPPORT_GPU) add_definitions(-DUSE_OPENCL_WRAPPER) add_definitions(-DMS_OPENCL_PROFILE=false) + add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200) add_compile_definitions(SUPPORT_GPU) if(OFFLINE_COMPILE) add_compile_definitions(PROGRAM_WITH_IL) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl index bb4ffce68bb3c95ec036216c83f311cc63bfd10f..e166e699c5ca2e4d224fe3afa3aa14cf64953c82 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl @@ -1,6 +1,8 @@ #define FLT half #define FLT4 half4 #define FLT16 half16 +#define READ_IMAGE read_imageh +#define WRITE_IMAGE write_imageh __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, @@ -14,17 +16,17 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 int src_w = w / 2; src_w = src_w * 2; int co = get_global_id(2); - if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return; + if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); FLT4 r2 = (FLT4)(0.f); FLT4 r3 = (FLT4)(0.f); int base_w = (co * 4 + kh + kw * 2) * src_size.z; for (int ci = 0; ci < src_size.z; ++ci) { - FLT4 x0 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); - FLT4 x1 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); - FLT4 x2 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); - FLT4 x3 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); FLT16 weight_cache = weight[base_w++]; r0 += x0.x * weight_cache.s0123; r0 += x0.y * weight_cache.s4567; @@ -46,14 +48,14 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 r3 += x3.z * weight_cache.s89ab; r3 += x3.w * weight_cache.scdef; } - FLT4 bias_val = read_imagef(biases, smp_zero, (int2)(co, 0)); + FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0)); r0 += bias_val; r1 += bias_val; r2 += bias_val; r3 += bias_val; - write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); - write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); - write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); - write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl index 6b89bb3b6bf45316297382ce1a5e67c583f10439..c121f824bdd984bf52f8abffa14f84c12de53f05 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl @@ -1,31 +1,32 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable #define FLT4 half4 #define FLT16 half16 -__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, __global FLT4 *buffer, __global FLT4 *bias, - int2 offset_ci, int2 offset_co, int has_bias) { +#define READ_IMAGE read_imageh +#define WRITE_IMAGE write_imageh +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, + __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) { int2 gid = (int2)(get_global_id(0), get_global_id(1)); int2 lid = (int2)(get_local_id(0), get_local_id(1)); - FLT4 s = (FLT4)(0.0f); + FLT4 result = (FLT4)(0.0f); bool inside = gid.x < offset_co.y; for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { - FLT4 v = x[i]; + FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0)); FLT16 w = weight[gid.x + i * offset_co.y]; - s.x += dot(v, w.s0123); - s.y += dot(v, w.s4567); - s.z += dot(v, w.s89ab); - s.w += dot(v, w.scdef); + result.x += dot(v, w.s0123); + result.y += dot(v, w.s4567); + result.z += dot(v, w.s89ab); + result.w += dot(v, w.scdef); } __local FLT4 temp[64][4]; - temp[lid.x][lid.y] = s; + temp[lid.x][lid.y] = result; barrier(CLK_LOCAL_MEM_FENCE); if (lid.y == 0 && inside) { - s += temp[lid.x][1]; - s += temp[lid.x][2]; - s += temp[lid.x][3]; + result += temp[lid.x][1]; + result += temp[lid.x][2]; + result += temp[lid.x][3]; if (has_bias != 0) { - s += bias[gid.x]; + result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0)); } - buffer[gid.x] = s; - // memory pollution? or protected by opencl + WRITE_IMAGE(output, (int2)(gid.x, 0), result); } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl index e77281772d54967f64d4c646db6dee93f8fecb81..c4b90579804bed66c143b2438af431a30a7ac85c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl @@ -1,6 +1,8 @@ #define FLT float #define FLT4 float4 #define FLT16 float16 +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, @@ -14,17 +16,17 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 int src_w = w / 2; src_w = src_w * 2; int co = get_global_id(2); - if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return; + if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); FLT4 r2 = (FLT4)(0.f); FLT4 r3 = (FLT4)(0.f); int base_w = (co * 4 + kh + kw * 2) * src_size.z; for (int ci = 0; ci < src_size.z; ++ci) { - FLT4 x0 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); - FLT4 x1 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); - FLT4 x2 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); - FLT4 x3 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1)); FLT16 weight_cache = weight[base_w++]; r0 += x0.x * weight_cache.s0123; r0 += x0.y * weight_cache.s4567; @@ -46,14 +48,14 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 r3 += x3.z * weight_cache.s89ab; r3 += x3.w * weight_cache.scdef; } - FLT4 bias_val = read_imagef(biases, smp_zero, (int2)(co, 0)); + FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0)); r0 += bias_val; r1 += bias_val; r2 += bias_val; r3 += bias_val; - write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); - write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); - write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); - write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2); + WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl index e851282c4bad7f0394e64047ce5cf25346ea9b49..1dcc884e0e70f226164fcefad869a9eb7ce8e85d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl @@ -1,30 +1,32 @@ #define FLT4 float4 #define FLT16 float16 -__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, __global FLT4 *buffer, __global FLT4 *bias, - int2 offset_ci, int2 offset_co, int has_bias) { +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, + __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) { int2 gid = (int2)(get_global_id(0), get_global_id(1)); int2 lid = (int2)(get_local_id(0), get_local_id(1)); - FLT4 s = (FLT4)(0.0f); + FLT4 result = (FLT4)(0.0f); bool inside = gid.x < offset_co.y; for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { - FLT4 v = x[i]; + FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0)); FLT16 w = weight[gid.x + i * offset_co.y]; - s.x += dot(v, w.s0123); - s.y += dot(v, w.s4567); - s.z += dot(v, w.s89ab); - s.w += dot(v, w.scdef); + result.x += dot(v, w.s0123); + result.y += dot(v, w.s4567); + result.z += dot(v, w.s89ab); + result.w += dot(v, w.scdef); } __local FLT4 temp[64][4]; - temp[lid.x][lid.y] = s; + temp[lid.x][lid.y] = result; barrier(CLK_LOCAL_MEM_FENCE); if (lid.y == 0 && inside) { - s += temp[lid.x][1]; - s += temp[lid.x][2]; - s += temp[lid.x][3]; + result += temp[lid.x][1]; + result += temp[lid.x][2]; + result += temp[lid.x][3]; if (has_bias != 0) { - s += bias[gid.x]; + result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0)); } - buffer[gid.x] = s; - // memory pollution? or protected by opencl + WRITE_IMAGE(output, (int2)(gid.x, 0), result); } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 59d617d65a3e0d52b104c291948c9220da7e4300..6aefa36a69bf0608591f55081bc7e341dee5abe3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -144,8 +144,8 @@ int Conv2dTransposeOpenCLKernel::Run() { &out_error_code); // local size should less than MAX_GROUP_SIZE std::vector local = {16, 1, 16}; - std::vector global = {UP_ROUND((size_t)oh / 2, local[0]), UP_ROUND((size_t)ow / 2, local[1]), - UP_ROUND((size_t)co / 4, local[2])}; + std::vector global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]), + UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND((size_t)co / 4, local[2])}; cl_int2 kernel_size = {kh, kw}; cl_int2 stride = {2, 2}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index 199d833c44c44eb0ac08647b052e573c8ccdaf17..9a73ef48d7c3d635fa73bf242d9993c8b3a76feb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -50,22 +50,24 @@ int MatMulOpenCLKernel::Init() { ocl_runtime->LoadSource(program_name, source); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); #endif - int ci = inputs_[1]->shape()[1]; + auto weight_format = inputs_[1]->GetFormat(); + if (weight_format != schema::Format_NHWC) { + MS_LOG(ERROR) << "weight format(" << weight_format << ") " + << "format not support!"; + return 1; + } + int ci = inputs_[1]->shape()[3]; int co = inputs_[1]->shape()[0]; sizeCI = {ci, UP_DIV(ci, 4)}; sizeCO = {co, UP_DIV(co, 4)}; auto allocator = ocl_runtime->GetAllocator(); padWeight_ = reinterpret_cast(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * 16 * sizeof(FLOAT_T))); padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); - if (hasBias_) { - bias_ = reinterpret_cast(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T))); - bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); - } + bias_ = reinterpret_cast(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T))); + bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); PadWeight(); allocator->UnmapBuffer(padWeight_); - if (hasBias_) { - allocator->UnmapBuffer(bias_); - } + allocator->UnmapBuffer(bias_); outputs_[0]->SetFormat(schema::Format_NHWC4); MS_LOG(DEBUG) << kernel_name << " Init Done!"; return 0; @@ -98,6 +100,10 @@ void MatMulOpenCLKernel::PadWeight() { for (int i = sizeCO.s[0]; i < sizeCO.s[1] * 4; i++) { bias_[i] = 0; } + } else { + for (int i = 0; i < sizeCO.s[1] * 4; i++) { + bias_[i] = 0; + } } } @@ -114,18 +120,34 @@ int MatMulOpenCLKernel::Run() { std::vector local = {64, 4}; std::vector global = {UP_ROUND(sizeCO.s[1], local[0]), 4}; - ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); - ocl_runtime->SetKernelArg(kernel_, 2, outputs_[0]->Data()); - if (hasBias_) { - ocl_runtime->SetKernelArg(kernel_, 3, bias_); - } else { - ocl_runtime->SetKernelArg(kernel_, 3, nullptr); + cl::ImageFormat image_format; + { + image_format.image_channel_order = CL_RGBA; +#ifdef ENABLE_FP16 + image_format.image_channel_data_type = CL_HALF_FLOAT; +#else + image_format.image_channel_data_type = CL_FLOAT; +#endif } + cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code; + cl::Image2D img_input(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCI.s[1], 1, + 0, inputs_[0]->Data(), &in_error_code); + cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCO.s[1], 1, + 0, bias_, &in_error_code_bias); + cl::Image2D img_out(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, sizeCO.s[1], 1, 0, nullptr, + &out_error_code); + + ocl_runtime->SetKernelArg(kernel_, 0, img_input); + ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); + ocl_runtime->SetKernelArg(kernel_, 2, img_bias); + ocl_runtime->SetKernelArg(kernel_, 3, img_out); ocl_runtime->SetKernelArg(kernel_, 4, sizeCI); ocl_runtime->SetKernelArg(kernel_, 5, sizeCO); ocl_runtime->SetKernelArg(kernel_, 6, hasBias_ ? 1 : 0); ocl_runtime->RunKernel(kernel_, global, local, nullptr); + auto origin = cl::array{0, 0, 0}; + auto region = cl::array{(size_t)(sizeCO.s[1]), 1, 1}; + ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(img_out, CL_TRUE, origin, region, 0, 0, outputs_[0]->Data()); return 0; } @@ -151,4 +173,3 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector +#include + +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +class TestConv2dTransposeOpenCL : public mindspore::Common { + public: + TestConv2dTransposeOpenCL() {} +}; + +TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { + // setbuf(stdout, NULL); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + int pad = 0; + int n = 1; + int h = 240; + int w = 240; + int kh = 2; + int kw = 2; + int ci = 128; + int co = 128; + int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1; + int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1; + + size_t input_size; + std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + + size_t weight_size; + std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin"; + auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); + + size_t bias_size; + std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin"; + auto bias_data = reinterpret_cast(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size)); + + lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci}); + tensor_x->SetData(input_data); + + lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci}); + tensor_w->SetData(weight_data); + + lite::tensor::Tensor *tensor_bias = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co}); + tensor_bias->SetData(bias_data); + + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, oh, ow, co}); + std::vector inputs{tensor_x, tensor_w, tensor_bias}; + std::vector outputs{tensor_out}; + ConvParameter *opParameter = new ConvParameter(); + opParameter->kernel_h_ = kh; + opParameter->kernel_w_ = kw; + opParameter->stride_h_ = 2; + opParameter->stride_w_ = 2; + opParameter->pad_h_ = pad; + opParameter->pad_w_ = pad; + opParameter->input_channel_ = ci; + opParameter->output_channel_ = co; + auto *arith_kernel = + new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + arith_kernel->Init(); + + std::vector kernels{arith_kernel}; + auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels); + pGraph->Init(); + pGraph->Run(); + + printf("==================output data=================\n"); + float *output_data = reinterpret_cast(tensor_out->Data()); + std::cout << std::endl; + size_t output_size; + std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + int size_n = oh * ow * co; + size_n = size_n > 100 ? 100 : size_n; + for (int i = 0; i < size_n; i++) { + std::cout << output_data[i] << ", "; + if ((i + 1) % co == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl; + + // compare + CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); + + MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc index 62463c6d825a304ddbf5dd54fe22e297ca48eefb..50d46b2a681fc6cbf722f7af3764af1e6a5d1129 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc @@ -41,38 +41,39 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) { std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin"; auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); - lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, ci}); + lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, ci}); + tensor_x->SetData(input_data); - lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, ci}); + lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, 1, 1, ci}); tensor_w->SetData(weight_data); - lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co}); + lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, co}); std::vector inputs{tensor_x, tensor_w}; std::vector outputs{tensor_out}; auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false); arith_kernel->Init(); std::vector kernels{arith_kernel}; - auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels); pGraph->Init(); - - memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci); pGraph->Run(); + size_t output_size; + std::string output_path = "./test_data/matmul/matmul_fp32_output.bin"; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); printf("==================output data=================\n"); float *output_data = reinterpret_cast(tensor_out->Data()); std::cout << std::endl; - for (int i = 0; i < co; i++) { - std::cout << output_data[i] << ", "; + int size_n = co; + size_n = size_n > 100 ? 100 : size_n; + for (int i = 0; i < size_n; i++) { + std::cout << output_data[i] << " "; } std::cout << std::endl; - size_t output_size; - std::string output_path = "./test_data/matmul/matmul_fp32_output.bin"; - auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); // compare - CompareOutputData(output_data, correct_data, co * sizeof(float), 0.00001); + CompareOutputData(output_data, correct_data, co, 0.00001); delete input_data; delete weight_data;