diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl index 2014bb589d0d7b9420ba4219ece5054f85fd70ec..eea2139ee2f332408e798cad26a38ce69b01c805 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose2x2.cl @@ -1,3 +1,4 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl index 31b6e02b559a2649e5b94adaef5220a353907f3d..fecdcb488743b369167d3b4c4364c91719c656ed 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl @@ -1,3 +1,4 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void to_format_NCHW_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size, int4 shape) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index dcf776efe654c0a6bc6fb056430949e06820dfec..ae751adc316f5453a1f710d24a52950c17037e10 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -16,6 +16,7 @@ #include #include +#include "src/common/utils.h" #include "src/kernel_registry.h" #include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h" @@ -41,6 +42,7 @@ int Conv2dTransposeOpenCLKernel::Init() { } std::string kernel_name = "conv2d_transpose2x2"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + enable_fp16_ = ocl_runtime->GetFp16Enable(); #ifdef PROGRAM_WITH_IL kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name); #else @@ -70,13 +72,18 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { int div_ci = UP_DIV(ci, C4NUM); int div_co = UP_DIV(co, C4NUM); auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); + auto data_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float); + using FLT = float; + if (enable_fp16_) { + using FLT = float16_t; + } // IHWO to OHWI4(I)4(O)(converter format is IHWO) // init padWeight_(buffer mem) - padWeight_ = - reinterpret_cast(allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * sizeof(FLOAT_t))); - padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); - auto origin_weight = reinterpret_cast(in_tensors_.at(kWeightIndex)->Data()); + padWeight_ = allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size); + padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true); + auto origin_weight = in_tensors_.at(kWeightIndex)->Data(); + auto weight_dtype = in_tensors_.at(kWeightIndex)->data_type(); int index = 0; for (int co_i = 0; co_i < div_co; co_i++) { for (int kh_i = 0; kh_i < kh; kh_i++) { @@ -88,9 +95,19 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { int ci_offset = ci_i * C4NUM + ci4_i; if (co_offset < co && ci_offset < ci) { int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * ci + co_offset; - padWeight_[index++] = origin_weight[ori_index]; + if (enable_fp16_) { + if (weight_dtype == kNumberTypeFloat32) { + reinterpret_cast(padWeight_)[index++] = + lite::Float32ToShort(reinterpret_cast(origin_weight)[ori_index]); + } else { + reinterpret_cast(padWeight_)[index++] = + reinterpret_cast(origin_weight)[ori_index]; + } + } else { + reinterpret_cast(padWeight_)[index++] = reinterpret_cast(origin_weight)[ori_index]; + } } else { - padWeight_[index++] = 0.; + reinterpret_cast(padWeight_)[index++] = 0.; } } } @@ -104,17 +121,24 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { size_t im_dst_x, im_dst_y; im_dst_x = div_co; im_dst_y = 1; -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } std::vector img_size{im_dst_x, im_dst_y, img_dtype}; - bias_ = reinterpret_cast(allocator->Malloc(im_dst_x * im_dst_y * C4NUM * sizeof(FLOAT_t), img_size)); - bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); - memset(bias_, 0x00, div_co * C4NUM * sizeof(FLOAT_t)); + bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size); + bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); + memset(bias_, 0x00, div_co * C4NUM * sizeof(data_size)); + auto bias_dtype = in_tensors_[2]->data_type(); if (in_tensors_.size() >= 3) { - memcpy(bias_, in_tensors_[2]->Data(), co * sizeof(FLOAT_t)); + if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) { + auto fdata = reinterpret_cast(in_tensors_[2]->Data()); + for (int i = 0; i < co; i++) { + reinterpret_cast(bias_)[i] = lite::Float32ToShort(fdata[i]); + } + } else { + memcpy(bias_, in_tensors_[2]->Data(), co * data_size); + } } allocator->UnmapBuffer(bias_); } @@ -123,11 +147,10 @@ int Conv2dTransposeOpenCLKernel::GetImageSize(size_t idx, std::vector *i size_t im_dst_x, im_dst_y; im_dst_x = UP_DIV(out_tensors_[0]->Channel() * out_tensors_[0]->Width(), C4NUM); im_dst_y = out_tensors_[0]->Height(); -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } img_size->clear(); std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; @@ -197,4 +220,5 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector *img_size MS_LOG(ERROR) << "Unsupported format. " << out_tensors_[0]->GetFormat(); } img_size->clear(); -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else + auto enable_fp16_ = lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable(); size_t img_dtype = CL_FLOAT; -#endif + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; return RET_OK; diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc index fc9e1caf1518e04a55c0f2941eb2ff1a21b37062..404e9a372ef8b3b4ea3cbb4be887c82ea2d754dc 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc @@ -50,10 +50,12 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector &img_size) auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); size_t img_pitch = 0; + size_t dtype_size = 1; if (!img_size.empty()) { + dtype_size = img_size[2] == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4); uint32_t image_alignment = ocl_runtime->GetImagePitchAlignment(); img_pitch = (img_size[0] + image_alignment - 1) / image_alignment * image_alignment; - size = img_pitch * img_size[1] * sizeof(cl_float4); + size = img_pitch * img_size[1] * dtype_size; } if (size > MAX_MALLOC_SIZE) { MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; @@ -107,7 +109,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector &img_size) if (!img_size.empty()) { cl::ImageFormat image_format(CL_RGBA, img_size[2]); cl::Image2D *image = new (std::nothrow) cl::Image2D(*ocl_runtime->Context(), image_format, *buffer, img_size[0], - img_size[1], img_pitch * sizeof(cl_float4), &ret); + img_size[1], img_pitch * dtype_size, &ret); if (image == nullptr || ret != CL_SUCCESS) { delete buffer; UnLock(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc index 4ec2d16697a30add053ba832bc7bc71a2bb2e5a6..e814347820b46249843752a5e4f2ea85e7ebfb5e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc @@ -29,23 +29,26 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest { TestConv2dTransposeOpenCL() {} }; -TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { +void RunTestCase(const std::vector shape, const std::vector file_path, bool fp16) { auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + if (fp16) { + ocl_runtime->SetFp16Enable(true); + } ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); - int pad = 0; - int n = 1; - int h = 240; - int w = 240; - int kh = 2; - int kw = 2; - int ci = 128; - int co = 128; + int pad = shape[0]; + int n = shape[1]; + int h = shape[2]; + int w = shape[3]; + int kh = shape[4]; + int kw = shape[5]; + int ci = shape[6]; + int co = shape[7]; int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1; int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1; size_t input_size; - std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin"; + std::string input_path = file_path[0]; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); if (input_data == nullptr) { MS_LOG(ERROR) << "input_data load error."; @@ -53,7 +56,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { } size_t weight_size; - std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin"; + std::string weight_path = file_path[1]; auto weight_data = reinterpret_cast(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); if (weight_data == nullptr) { MS_LOG(ERROR) << "weight_data load error."; @@ -61,14 +64,15 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { } size_t bias_size; - std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin"; + std::string bias_path = file_path[2]; auto bias_data = reinterpret_cast(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size)); if (bias_data == nullptr) { MS_LOG(ERROR) << "bias_data load error."; return; } std::vector input_shape = {n, h, w, ci}; - auto tensor_x_ptr = std::make_unique(TypeId(kNumberTypeFloat32), input_shape); + auto tensor_x_ptr = + std::make_unique(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); auto tensor_x = tensor_x_ptr.get(); if (tensor_x == nullptr) { MS_LOG(ERROR) << "tensor_x create error."; @@ -76,7 +80,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { } std::vector weight_shape = {co, kh, kw, ci}; - auto tensor_w_ptr = std::make_unique(TypeId(kNumberTypeFloat32), weight_shape); + auto tensor_w_ptr = + std::make_unique(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape); auto tensor_w = tensor_w_ptr.get(); if (tensor_w == nullptr) { MS_LOG(ERROR) << "tensor_w create error."; @@ -85,7 +90,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { tensor_w->SetData(weight_data); std::vector bias_shape = {co}; - auto tensor_bias_ptr = std::make_unique(TypeId(kNumberTypeFloat32), bias_shape); + auto tensor_bias_ptr = + std::make_unique(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape); auto tensor_bias = tensor_bias_ptr.get(); if (tensor_bias == nullptr) { MS_LOG(ERROR) << "tensor_bias create error."; @@ -94,7 +100,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { tensor_bias->SetData(bias_data); std::vector out_shape = {1, oh, ow, co}; - auto tensor_out_ptr = std::make_unique(TypeId(kNumberTypeFloat32), out_shape); + auto tensor_out_ptr = + std::make_unique(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape); auto tensor_out = tensor_out_ptr.get(); if (tensor_out == nullptr) { MS_LOG(ERROR) << "tensor_out create error."; @@ -116,17 +123,18 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { opParameter->pad_w_ = pad; opParameter->input_channel_ = ci; opParameter->output_channel_ = co; - auto arith_kernel_ptr = std::make_unique( + auto op_kernel_ptr = std::make_unique( reinterpret_cast(opParameter), inputs, outputs); - auto arith_kernel = arith_kernel_ptr.get(); - if (arith_kernel == nullptr) { - MS_LOG(ERROR) << "arith_kernel create error."; + auto op_kernel = op_kernel_ptr.get(); + if (op_kernel == nullptr) { + MS_LOG(ERROR) << "op_kernel create error."; return; } - arith_kernel->Init(); + op_kernel->set_name("DeConv"); + op_kernel->Init(); inputs[0]->MallocData(allocator); - std::vector kernels{arith_kernel}; + std::vector kernels{op_kernel}; std::vector inputs_g{tensor_x}; auto pGraph_ptr = std::make_unique(inputs_g, outputs, kernels, kernels, kernels); auto pGraph = pGraph_ptr.get(); @@ -138,13 +146,16 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { pGraph->Init(); memcpy(inputs[0]->Data(), input_data, input_size); pGraph->Run(); - + using FLT = float; + if (fp16) { + using FLT = float16_t; + } std::cout << "==================output data=================" << std::endl; - float *output_data = reinterpret_cast(tensor_out->Data()); + FLT *output_data = reinterpret_cast(tensor_out->Data()); std::cout << std::endl; size_t output_size; - std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin"; - auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + std::string output_path = file_path[3]; + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); if (correct_data == nullptr) { MS_LOG(ERROR) << "correct_data create error."; return; @@ -152,7 +163,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { int size_n = oh * ow * co; size_n = size_n > 100 ? 100 : size_n; for (int i = 0; i < size_n; i++) { - std::cout << output_data[i] << ", "; + std::cout << output_data[i] << ", " << correct_data[i] << " "; if ((i + 1) % co == 0) { std::cout << std::endl; } @@ -160,10 +171,43 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { std::cout << std::endl; // compare - CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); + CommonTest::CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); inputs[0]->SetData(nullptr); outputs[0]->SetData(nullptr); MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; lite::opencl::OpenCLRuntime::DeleteInstance(); } +TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { + int pad = 0; + int n = 1; + int h = 240; + int w = 240; + int kh = 2; + int kw = 2; + int ci = 128; + int co = 128; + std::vector shape = {pad, n, h, w, kh, kw, ci, co}; + std::vector file_path = {"./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin"}; + RunTestCase(shape, file_path, false); +} + +TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16) { + int pad = 0; + int n = 1; + int h = 240; + int w = 240; + int kh = 2; + int kw = 2; + int ci = 128; + int co = 128; + std::vector shape = {pad, n, h, w, kh, kw, ci, co}; + std::vector file_path = {"./test_data/conv2d_transpose/conv2d_transpose_fp16_input.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp16_weight.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp16_bias.bin", + "./test_data/conv2d_transpose/conv2d_transpose_fp16_output.bin"}; + RunTestCase(shape, file_path, true); +} } // namespace mindspore