提交 9c9b721b · 作者: mindspore-ci-bot · 提交者: Gitee

!5109 conv2d transpose support fp16

Merge pull request !5109 from chenzupeng/master-lite
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases,
__write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding,
......
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void to_format_NCHW_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
......
......@@ -16,6 +16,7 @@
#include <string>
#include <set>
#include "src/common/utils.h"
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
......@@ -41,6 +42,7 @@ int Conv2dTransposeOpenCLKernel::Init() {
}
std::string kernel_name = "conv2d_transpose2x2";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
enable_fp16_ = ocl_runtime->GetFp16Enable();
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name);
#else
......@@ -70,13 +72,18 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
int div_ci = UP_DIV(ci, C4NUM);
int div_co = UP_DIV(co, C4NUM);
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
auto data_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
using FLT = float;
if (enable_fp16_) {
using FLT = float16_t;
}
// IHWO to OHWI4(I)4(O)(converter format is IHWO)
// init padWeight_(buffer mem)
padWeight_ =
reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * sizeof(FLOAT_t)));
padWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kWeightIndex)->Data());
padWeight_ = allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto origin_weight = in_tensors_.at(kWeightIndex)->Data();
auto weight_dtype = in_tensors_.at(kWeightIndex)->data_type();
int index = 0;
for (int co_i = 0; co_i < div_co; co_i++) {
for (int kh_i = 0; kh_i < kh; kh_i++) {
......@@ -88,9 +95,19 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
int ci_offset = ci_i * C4NUM + ci4_i;
if (co_offset < co && ci_offset < ci) {
int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * ci + co_offset;
padWeight_[index++] = origin_weight[ori_index];
if (enable_fp16_) {
if (weight_dtype == kNumberTypeFloat32) {
reinterpret_cast<float16_t *>(padWeight_)[index++] =
lite::Float32ToShort(reinterpret_cast<float *>(origin_weight)[ori_index]);
} else {
reinterpret_cast<float16_t *>(padWeight_)[index++] =
reinterpret_cast<float16_t *>(origin_weight)[ori_index];
}
} else {
reinterpret_cast<float *>(padWeight_)[index++] = reinterpret_cast<float *>(origin_weight)[ori_index];
}
} else {
padWeight_[index++] = 0.;
reinterpret_cast<FLT *>(padWeight_)[index++] = 0.;
}
}
}
......@@ -104,17 +121,24 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
size_t im_dst_x, im_dst_y;
im_dst_x = div_co;
im_dst_y = 1;
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(im_dst_x * im_dst_y * C4NUM * sizeof(FLOAT_t), img_size));
bias_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
memset(bias_, 0x00, div_co * C4NUM * sizeof(FLOAT_t));
bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size);
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
memset(bias_, 0x00, div_co * C4NUM * sizeof(data_size));
auto bias_dtype = in_tensors_[2]->data_type();
if (in_tensors_.size() >= 3) {
memcpy(bias_, in_tensors_[2]->Data(), co * sizeof(FLOAT_t));
if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) {
auto fdata = reinterpret_cast<float *>(in_tensors_[2]->Data());
for (int i = 0; i < co; i++) {
reinterpret_cast<float16_t *>(bias_)[i] = lite::Float32ToShort(fdata[i]);
}
} else {
memcpy(bias_, in_tensors_[2]->Data(), co * data_size);
}
}
allocator->UnmapBuffer(bias_);
}
......@@ -123,11 +147,10 @@ int Conv2dTransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *i
size_t im_dst_x, im_dst_y;
im_dst_x = UP_DIV(out_tensors_[0]->Channel() * out_tensors_[0]->Width(), C4NUM);
im_dst_y = out_tensors_[0]->Height();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
......@@ -197,4 +220,5 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::t
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator)
} // namespace mindspore::kernel
......@@ -42,8 +42,9 @@ class Conv2dTransposeOpenCLKernel : public OpenCLKernel {
private:
ConvParameter *parameter_;
cl::Kernel kernel_;
FLOAT_t *padWeight_;
FLOAT_t *bias_;
void *padWeight_;
void *bias_;
bool enable_fp16_{false};
};
} // namespace mindspore::kernel
......
......@@ -128,11 +128,11 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size
MS_LOG(ERROR) << "Unsupported format. " << out_tensors_[0]->GetFormat();
}
img_size->clear();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
auto enable_fp16_ = lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable();
size_t img_dtype = CL_FLOAT;
#endif
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
......
......@@ -50,10 +50,12 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t> &img_size)
auto svm_capabilities = ocl_runtime->GetSVMCapabilities();
size_t img_pitch = 0;
size_t dtype_size = 1;
if (!img_size.empty()) {
dtype_size = img_size[2] == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4);
uint32_t image_alignment = ocl_runtime->GetImagePitchAlignment();
img_pitch = (img_size[0] + image_alignment - 1) / image_alignment * image_alignment;
size = img_pitch * img_size[1] * sizeof(cl_float4);
size = img_pitch * img_size[1] * dtype_size;
}
if (size > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
......@@ -107,7 +109,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t> &img_size)
if (!img_size.empty()) {
cl::ImageFormat image_format(CL_RGBA, img_size[2]);
cl::Image2D *image = new (std::nothrow) cl::Image2D(*ocl_runtime->Context(), image_format, *buffer, img_size[0],
img_size[1], img_pitch * sizeof(cl_float4), &ret);
img_size[1], img_pitch * dtype_size, &ret);
if (image == nullptr || ret != CL_SUCCESS) {
delete buffer;
UnLock();
......
......@@ -29,23 +29,26 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
TestConv2dTransposeOpenCL() {}
};
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
void RunTestCase(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
if (fp16) {
ocl_runtime->SetFp16Enable(true);
}
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int pad = 0;
int n = 1;
int h = 240;
int w = 240;
int kh = 2;
int kw = 2;
int ci = 128;
int co = 128;
int pad = shape[0];
int n = shape[1];
int h = shape[2];
int w = shape[3];
int kh = shape[4];
int kw = shape[5];
int ci = shape[6];
int co = shape[7];
int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1;
int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;
size_t input_size;
std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin";
std::string input_path = file_path[0];
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
......@@ -53,7 +56,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
}
size_t weight_size;
std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin";
std::string weight_path = file_path[1];
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
if (weight_data == nullptr) {
MS_LOG(ERROR) << "weight_data load error.";
......@@ -61,14 +64,15 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
}
size_t bias_size;
std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin";
std::string bias_path = file_path[2];
auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size));
if (bias_data == nullptr) {
MS_LOG(ERROR) << "bias_data load error.";
return;
}
std::vector<int> input_shape = {n, h, w, ci};
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
......@@ -76,7 +80,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
}
std::vector<int> weight_shape = {co, kh, kw, ci};
auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), weight_shape);
auto tensor_w_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape);
auto tensor_w = tensor_w_ptr.get();
if (tensor_w == nullptr) {
MS_LOG(ERROR) << "tensor_w create error.";
......@@ -85,7 +90,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
tensor_w->SetData(weight_data);
std::vector<int> bias_shape = {co};
auto tensor_bias_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), bias_shape);
auto tensor_bias_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape);
auto tensor_bias = tensor_bias_ptr.get();
if (tensor_bias == nullptr) {
MS_LOG(ERROR) << "tensor_bias create error.";
......@@ -94,7 +100,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
tensor_bias->SetData(bias_data);
std::vector<int> out_shape = {1, oh, ow, co};
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
......@@ -116,17 +123,18 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
opParameter->pad_w_ = pad;
opParameter->input_channel_ = ci;
opParameter->output_channel_ = co;
auto arith_kernel_ptr = std::make_unique<kernel::Conv2dTransposeOpenCLKernel>(
auto op_kernel_ptr = std::make_unique<kernel::Conv2dTransposeOpenCLKernel>(
reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
auto arith_kernel = arith_kernel_ptr.get();
if (arith_kernel == nullptr) {
MS_LOG(ERROR) << "arith_kernel create error.";
auto op_kernel = op_kernel_ptr.get();
if (op_kernel == nullptr) {
MS_LOG(ERROR) << "op_kernel create error.";
return;
}
arith_kernel->Init();
op_kernel->set_name("DeConv");
op_kernel->Init();
inputs[0]->MallocData(allocator);
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
std::vector<kernel::LiteKernel *> kernels{op_kernel};
std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
auto pGraph = pGraph_ptr.get();
......@@ -138,13 +146,16 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
pGraph->Run();
using FLT = float;
if (fp16) {
using FLT = float16_t;
}
std::cout << "==================output data=================" << std::endl;
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
FLT *output_data = reinterpret_cast<FLT *>(tensor_out->Data());
std::cout << std::endl;
size_t output_size;
std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin";
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
std::string output_path = file_path[3];
auto correct_data = reinterpret_cast<FLT *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
if (correct_data == nullptr) {
MS_LOG(ERROR) << "correct_data create error.";
return;
......@@ -152,7 +163,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
int size_n = oh * ow * co;
size_n = size_n > 100 ? 100 : size_n;
for (int i = 0; i < size_n; i++) {
std::cout << output_data[i] << ", ";
std::cout << output_data[i] << ", " << correct_data[i] << " ";
if ((i + 1) % co == 0) {
std::cout << std::endl;
}
......@@ -160,10 +171,43 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
std::cout << std::endl;
// compare
CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
CommonTest::CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
lite::opencl::OpenCLRuntime::DeleteInstance();
}
// fp32 case: 2x2 transposed convolution on a 1x240x240x128 input with 128
// output channels, checked against the fp32 reference files.
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
  // RunTestCase reads the entries positionally: pad, n, h, w, kh, kw, ci, co.
  const std::vector<int> dims = {
    0,    // pad
    1,    // n
    240,  // h
    240,  // w
    2,    // kh
    2,    // kw
    128,  // ci
    128,  // co
  };
  // RunTestCase reads the entries positionally: input, weight, bias, expected output.
  const std::vector<std::string> data_files = {
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin",
  };
  // Last argument is the fp16 flag; false runs the fp32 path.
  RunTestCase(dims, data_files, false);
}
// fp16 case: same geometry as the fp32 test (2x2 kernel, 1x240x240x128 input,
// 128 output channels), checked against the fp16 reference files.
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16) {
  // RunTestCase reads the entries positionally: pad, n, h, w, kh, kw, ci, co.
  const std::vector<int> dims = {
    0,    // pad
    1,    // n
    240,  // h
    240,  // w
    2,    // kh
    2,    // kw
    128,  // ci
    128,  // co
  };
  // RunTestCase reads the entries positionally: input, weight, bias, expected output.
  const std::vector<std::string> data_files = {
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_input.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_weight.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_bias.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_output.bin",
  };
  // Last argument is the fp16 flag; true makes RunTestCase enable fp16 on the runtime.
  RunTestCase(dims, data_files, true);
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册