Commit afd17a51 authored by mindspore-ci-bot, committed by Gitee

!4086 Support image2d for opencl runtime

Merge pull request !4086 from wandongdong/master
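The change in a nutshell: tensors used by the OpenCL runtime can now be backed by a cl::Image2D (one CL_RGBA texel per 4-channel slice) instead of a plain cl::Buffer. Below is a minimal standalone sketch of that NHWC4-to-image layout, assuming a default OpenCL device and an illustrative 1x4x4x8 tensor; none of it is code from this patch.

// Standalone sketch (illustrative shapes, not patch code): pack an NHWC4
// tensor into a cl::Image2D that is (W * C4) texels wide and H texels high,
// then read it back to confirm the layout round-trips.
#include <CL/cl2.hpp>
#include <iostream>
#include <vector>

int main() {
  cl::Device device = cl::Device::getDefault();
  cl::Context context(device);
  cl::CommandQueue queue(context, device);

  const size_t H = 4, W = 4, C = 8, C4 = (C + 3) / 4;  // C4 = UP_DIV(C, 4)
  std::vector<float> host(H * W * C4 * 4, 1.0f);       // NHWC4-packed host data

  cl::ImageFormat fmt(CL_RGBA, CL_FLOAT);               // 4 channels per texel
  cl_int err = CL_SUCCESS;
  cl::Image2D image(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, fmt,
                    W * C4, H, 0, host.data(), &err);
  if (err != CL_SUCCESS) {
    std::cerr << "Image2D creation failed: " << err << std::endl;
    return 1;
  }

  std::vector<float> readback(host.size());
  cl::array<cl::size_type, 3> origin{0, 0, 0};
  cl::array<cl::size_type, 3> region{W * C4, H, 1};
  queue.enqueueReadImage(image, CL_TRUE, origin, region, 0, 0, readback.data());
  std::cout << "first texel value: " << readback[0] << std::endl;
  return 0;
}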
......@@ -20,6 +20,7 @@
#include <vector>
#include "src/runtime/kernel/arm/fp32/arithmetic.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "ir/anf.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/base/concat_base.h"
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "src/ir/tensor.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "schema/model_generated.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
......
......@@ -17,10 +17,12 @@
#include "src/runtime/kernel/opencl/kernel/depthwise_conv2d.h"
#include <string>
#include <set>
#include <utility>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h"
#include "src/runtime/kernel/arm/opclib/pack.h"
#include "include/errorcode.h"
#ifndef PROGRAM_WITH_IL
......@@ -29,9 +31,12 @@
#endif
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
......@@ -72,8 +77,8 @@ int DepthwiseConv2dOpenCLKernel::Init() {
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
this->InitBuffer();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
MS_LOG(DEBUG) << kernel_name << " Init Done! mem type=" << static_cast<int>(mem_type_);
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::InitBuffer() {
......@@ -109,10 +114,46 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
}
return 0;
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::ReSize() { return 0; }
int DepthwiseConv2dOpenCLKernel::ReSize() {
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* img_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (inputs_[0]->GetFormat() == schema::Format_NHWC4) {
im_dst_x = outputs_[0]->Width() * CO4;
im_dst_y = outputs_[0]->Height();
} else {
im_dst_y = outputs_[0]->Height() * CO4;
im_dst_x = outputs_[0]->Width();
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
*img_size = {im_dst_x, im_dst_y, img_dtype};
return RET_OK;
}
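As a worked example of the image-size arithmetic above (illustrative shape, not from this patch): an NHWC4 output with H = 112, W = 112, C = 96 gives CO4 = 24, so the backing image is 112 * 24 = 2688 texels wide and 112 texels high, one CL_RGBA texel per 4-channel slice. A compile-time check of that arithmetic:

#include <cstddef>

// Sketch only: the same UP_DIV arithmetic as GetImageSize, with made-up shape values.
constexpr size_t UpDiv(size_t x, size_t y) { return (x + y - 1) / y; }
static_assert(UpDiv(96, 4) == 24, "CO4 = UP_DIV(C, 4)");
static_assert(112 * UpDiv(96, 4) == 2688, "image width = W * CO4 texels for NHWC4");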
int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t>* global_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
*global_size = std::move(global);
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t>& global_size,
std::vector<size_t>* local_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
std::vector<size_t> local = {1, 1, CO4};
*local_size = std::move(local);
return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->Name() << " Running!";
......@@ -120,8 +161,9 @@ int DepthwiseConv2dOpenCLKernel::Run() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM);
std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
std::vector<size_t> local = {1, 1, CO4};
std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
std::vector<size_t> local;
GetLocalSize(0, global, &local);
float relu_clip1 = 6.0;
cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_};
......@@ -141,53 +183,10 @@ int DepthwiseConv2dOpenCLKernel::Run() {
ocl_runtime->SetKernelArg(kernel_, 8, dilation);
ocl_runtime->SetKernelArg(kernel_, 9, src_size);
ocl_runtime->SetKernelArg(kernel_, 10, dst_size);
if (mem_type_ == MEM_TYPE::BUF) {
ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data());
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
} else {
cl::ImageFormat image_format;
{
image_format.image_channel_order = CL_RGBA;
image_format.image_channel_data_type = CL_FLOAT;
}
cl_int in_error_code;
size_t im_src_x, im_src_y;
size_t im_dst_x, im_dst_y;
if (inputs_[0]->GetFormat() == schema::Format_NHWC4) {
im_src_x = inputs_[0]->Width() * CI4;
im_src_y = inputs_[0]->Height();
im_dst_x = outputs_[0]->Width() * CO4;
im_dst_y = outputs_[0]->Height();
} else {
im_src_y = inputs_[0]->Height() * CI4;
im_src_x = inputs_[0]->Width();
im_dst_y = outputs_[0]->Height() * CO4;
im_dst_x = outputs_[0]->Width();
}
cl::Image2D in_mem(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, im_src_x,
im_src_y, 0, inputs_[0]->Data(), &in_error_code);
cl_int out_error_code;
cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, im_dst_x, im_dst_y, 0, nullptr,
&out_error_code);
if (in_error_code != CL_SUCCESS) {
MS_LOG(DEBUG) << "in Image2D Failed, error=" << in_error_code;
return 1;
}
if (out_error_code != CL_SUCCESS) {
MS_LOG(DEBUG) << "out Image2D Failed, error= " << out_error_code;
return 1;
}
auto origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{im_dst_x, im_dst_y, 1};
ocl_runtime->SetKernelArg(kernel_, 0, in_mem);
ocl_runtime->SetKernelArg(kernel_, 4, out_mem);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0,
outputs_[0]->Data());
}
return 0;
ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data());
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
......
......@@ -18,17 +18,17 @@
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
namespace mindspore::kernel {
class DepthwiseConv2dOpenCLKernel : public LiteKernel {
class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
public:
explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs),
const std::vector<lite::tensor::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs),
packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {}
~DepthwiseConv2dOpenCLKernel() override {}
......@@ -41,13 +41,18 @@ class DepthwiseConv2dOpenCLKernel : public LiteKernel {
int InitBuffer();
int GetImageSize(size_t idx, std::vector<size_t>* img_size) override;
int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) override;
int GetLocalSize(size_t idx, const std::vector<size_t>& global_size,
std::vector<size_t>* local_size) override;
private:
FLOAT_t *packed_weight_;
FLOAT_t *bias_data_;
cl::Kernel kernel_;
enum class MEM_TYPE {
BUF, IMG
} mem_type_{MEM_TYPE::BUF};
} mem_type_{MEM_TYPE::IMG};
};
} // namespace mindspore::kernel
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/opclib/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/opclib/fp32/pooling.h"
#include "src/runtime/opencl/opencl_runtime.h"
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/opclib/fp32/softmax.h"
#include "src/runtime/opencl/opencl_runtime.h"
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_
#define MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class OpenCLKernel : public LiteKernel {
public:
explicit OpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs) {}
virtual int Init() { return -1; }
virtual int Prepare() { return -1; }
virtual int InferShape() { return -1; }
virtual int ReSize() { return -1; }
virtual int Run() { return -1; }
virtual int GetImageSize(size_t idx, std::vector<size_t>* img_size) { return -1; }
virtual int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) { return -1; }
virtual int GetLocalSize(size_t idx, const std::vector<size_t>& global_size,
std::vector<size_t>* local_size) { return -1; }
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_
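For illustration, a hedged sketch of a kernel subclass built on the new base class; it only shows how the hooks are overridden and is not part of this patch.

// Illustrative subclass sketch (hypothetical kernel, not patch code): override
// only the hooks the kernel supports; the base-class defaults above return -1
// for anything unsupported.
#include <vector>
#include "CL/cl2.hpp"
#include "src/runtime/kernel/opencl/opencl_kernel.h"

namespace mindspore::kernel {
class IdentityOpenCLKernel : public OpenCLKernel {
 public:
  IdentityOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                       const std::vector<lite::tensor::Tensor *> &outputs)
      : OpenCLKernel(parameter, inputs, outputs) {}
  int Init() override { return 0; }
  int Run() override { return 0; }
  int GetImageSize(size_t idx, std::vector<size_t> *img_size) override {
    // Report the (width, height, dtype) triple the executor uses to allocate an Image2D.
    *img_size = {1, 1, CL_FLOAT};
    return 0;
  }
};
}  // namespace mindspore::kernel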
......@@ -32,9 +32,10 @@ int SubGraphOpenCLKernel::Init() {
}
// Map buffer for write, it is not necessary for fine-grained
for (auto &tensor : inputs_) {
void *data = allocator_->MapBuffer(tensor->Data(), CL_MAP_WRITE, nullptr, true);
void *data = tensor->Data();
// It is required with coarse-grained SVM
if (data != nullptr) {
data = allocator_->MapBuffer(data, CL_MAP_WRITE, nullptr, true);
tensor->SetData(data);
} else {
MS_LOG(ERROR) << "OpenCL kernel must use GPU buffer pointer, "
......
......@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_allocator.h"
namespace mindspore::kernel {
......
......@@ -21,6 +21,7 @@
#include <vector>
#include "CL/cl2.hpp"
#include "utils/log_adapter.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
namespace mindspore::kernel {
......@@ -81,7 +82,6 @@ std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size
std::string CLErrorCode(cl_int error_code);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_
......
......@@ -18,6 +18,7 @@
#include <utility>
#include "utils/log_adapter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "include/errorcode.h"
namespace mindspore::lite::opencl {
......@@ -61,7 +62,7 @@ void *OpenCLAllocator::Malloc(size_t size) {
auto svm_capabilities = ocl_runtime->GetSVMCapabilities();
void *host_ptr = nullptr;
void *device_ptr = nullptr;
if (svm_capabilities) {
if (svm_capabilities && svm_on_) {
cl_svm_mem_flags flags = (svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) ? CL_MEM_SVM_FINE_GRAIN_BUFFER : 0;
flags |= (svm_capabilities & CL_DEVICE_SVM_ATOMICS) ? CL_MEM_SVM_ATOMICS : 0;
flags = flags | CL_MEM_READ_WRITE;
......@@ -69,7 +70,7 @@ void *OpenCLAllocator::Malloc(size_t size) {
} else {
cl_int ret = CL_SUCCESS;
cl::Buffer *buffer =
new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret);
new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")";
UnLock();
......@@ -77,7 +78,13 @@ void *OpenCLAllocator::Malloc(size_t size) {
}
device_ptr = static_cast<void *>(buffer);
host_ptr = ocl_runtime->MapBuffer(*buffer, CL_MAP_READ | CL_MAP_WRITE, size);
ocl_runtime->UnmapBuffer(*buffer, host_ptr);
if (host_ptr == nullptr) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr;
UnLock();
return nullptr;
}
cl::Memory *mem = buffer;
ocl_runtime->UnmapBuffer(*mem, host_ptr);
}
std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>();
mem_buf->size_ = size;
......@@ -90,6 +97,113 @@ void *OpenCLAllocator::Malloc(size_t size) {
return host_ptr;
}
void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size) {
if (size > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
return nullptr;
}
auto ocl_runtime = opencl::OpenCLRuntime::GetInstance();
Lock();
auto iter = free_list_.lower_bound(size);
if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) {
auto mem_buf = iter->second;
bool is_match{mem_buf->img_size.size() == img_size.size()};
for (size_t i = 0; i < img_size.size() && is_match; ++i) {
is_match = img_size[i] == mem_buf->img_size[i];
}
if (is_match) {
free_list_.erase(iter);
allocated_list_[mem_buf->host_ptr_] = mem_buf;
UnLock();
MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_
<< ", host addr: " << mem_buf->host_ptr_ << ", device addr: " << mem_buf->device_ptr_;
return mem_buf->host_ptr_;
}
}
void *host_ptr = nullptr;
void *device_ptr = nullptr;
cl_int ret = CL_SUCCESS;
// CL_HALF_FLOAT, CL_FLOAT
cl::ImageFormat image_format(CL_RGBA, img_size[2]);
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE,
image_format, img_size[0], img_size[1], 0, nullptr, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
UnLock();
return nullptr;
}
device_ptr = static_cast<void *>(buffer);
std::vector<size_t> region{img_size[0], img_size[1], 1};
host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region);
if (host_ptr == nullptr) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr;
UnLock();
return nullptr;
}
cl::Memory *mem = buffer;
ocl_runtime->UnmapBuffer(*mem, host_ptr);
std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>();
mem_buf->size_ = size;
mem_buf->device_ptr_ = device_ptr;
mem_buf->host_ptr_ = host_ptr;
mem_buf->img_size = img_size;
MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_
<< ", device addr: " << mem_buf->device_ptr_;
allocated_list_[host_ptr] = mem_buf.release();
UnLock();
return host_ptr;
}
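A usage sketch of the new image-backed allocation path; the shapes and function name are illustrative, while the allocator calls are the ones declared in opencl_allocator.h in this patch.

// Usage sketch: allocate an Image2D-backed host/device pair via the new
// Malloc overload, then write NHWC4-packed data through the map/unmap path.
#include <vector>
#include "CL/cl2.hpp"
#include "src/runtime/opencl/opencl_allocator.h"

void FillImageTensor(mindspore::lite::opencl::OpenCLAllocator *allocator,
                     size_t height, size_t width, size_t c4, size_t byte_size) {
  // Same (width, height, dtype) triple that OpenCLKernel::GetImageSize produces.
  std::vector<size_t> img_size{width * c4, height, CL_FLOAT};
  void *host_ptr = allocator->Malloc(byte_size, img_size);
  if (host_ptr == nullptr) {
    return;  // allocation or map failure is already logged by the allocator
  }
  auto *data = reinterpret_cast<float *>(allocator->MapBuffer(host_ptr, CL_MAP_WRITE, nullptr, true));
  if (data == nullptr) {
    return;
  }
  for (size_t i = 0; i < byte_size / sizeof(float); ++i) {
    data[i] = 0.0f;  // fill with NHWC4-packed values here
  }
  allocator->UnmapBuffer(data);  // MapBuffer re-keys the allocation to the new pointer
}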
void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::vector<size_t>& img_size) {
if (size > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
return nullptr;
}
auto ocl_runtime = opencl::OpenCLRuntime::GetInstance();
Lock();
auto iter = free_list_.lower_bound(size);
if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) {
auto mem_buf = iter->second;
free_list_.erase(iter);
allocated_list_[mem_buf->host_ptr_] = mem_buf;
UnLock();
MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_
<< ", device addr: " << mem_buf->device_ptr_;
return mem_buf->host_ptr_;
}
void *host_ptr = nullptr;
void *device_ptr = nullptr;
cl_int ret = CL_SUCCESS;
// CL_HALF_FLOAT, CL_FLOAT
cl::ImageFormat image_format(CL_RGBA, img_size[2]);
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format,
img_size[0], img_size[1], 0, data, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
UnLock();
return nullptr;
}
device_ptr = static_cast<void *>(buffer);
std::vector<size_t> region{img_size[0], img_size[1], 1};
host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region);
if (host_ptr == nullptr) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr;
UnLock();
return nullptr;
}
cl::Memory *mem = buffer;
ocl_runtime->UnmapBuffer(*mem, host_ptr);
std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>();
mem_buf->size_ = size;
mem_buf->device_ptr_ = device_ptr;
mem_buf->host_ptr_ = host_ptr;
mem_buf->img_size = img_size;
MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_
<< ", device addr: " << mem_buf->device_ptr_;
allocated_list_[host_ptr] = mem_buf.release();
UnLock();
return host_ptr;
}
void OpenCLAllocator::Free(void *buf) {
if (buf == nullptr) {
return;
......@@ -163,7 +277,7 @@ void OpenCLAllocator::Clear() {
void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue, bool sync) {
auto ocl_runtime = opencl::OpenCLRuntime::GetInstance();
auto svm_capabilities = ocl_runtime->GetSVMCapabilities();
if (svm_capabilities) {
if (svm_capabilities && svm_on_) {
if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) {
auto it = allocated_list_.find(host_ptr);
if (it == allocated_list_.end()) {
......@@ -178,11 +292,25 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue,
auto it = allocated_list_.find(host_ptr);
if (it == allocated_list_.end()) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr;
UnLock();
return nullptr;
}
MemBuf *mem_buf = it->second;
cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_);
void *new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync);
void *new_host_ptr{nullptr};
if (mem_buf->img_size.empty()) {
cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_);
new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync);
} else {
cl::ImageFormat image_format(CL_RGBA, mem_buf->img_size[2]);
std::vector<size_t> region{mem_buf->img_size[0], mem_buf->img_size[1], 1};
cl::Image2D *buffer = static_cast<cl::Image2D *>(mem_buf->device_ptr_);
new_host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region);
}
if (new_host_ptr == nullptr) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << mem_buf->device_ptr_ << ", host_ptr=" << host_ptr;
UnLock();
return nullptr;
}
mem_buf->host_ptr_ = new_host_ptr;
allocated_list_.erase(it);
allocated_list_[new_host_ptr] = mem_buf;
......@@ -208,5 +336,40 @@ int OpenCLAllocator::UnmapBuffer(void *host_ptr, void *command_queue) {
return ocl_runtime->UnmapBuffer(*buffer, it->second->host_ptr_, static_cast<cl::CommandQueue *>(command_queue));
}
MEM_TYPE OpenCLAllocator::GetMemType(void *host_ptr) {
MEM_TYPE mem_type{MEM_TYPE::BUF};
Lock();
auto it = allocated_list_.find(host_ptr);
if (it == allocated_list_.end()) {
MS_LOG(ERROR) << "Can not found buffer :" << host_ptr;
UnLock();
return mem_type;
}
MemBuf *mem_buf = it->second;
if (mem_buf->img_size.empty()) {
mem_type = MEM_TYPE::BUF;
} else {
mem_type = MEM_TYPE::IMG;
}
UnLock();
return mem_type;
}
int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector<size_t>* img_size) {
Lock();
auto it = allocated_list_.find(host_ptr);
if (it == allocated_list_.end()) {
MS_LOG(ERROR) << "Can not found buffer :" << host_ptr;
UnLock();
return RET_OK;
}
MemBuf *mem_buf = it->second;
if (!mem_buf->img_size.empty()) {
*img_size = mem_buf->img_size;
}
UnLock();
return RET_OK;
}
} // namespace mindspore::lite::opencl
......@@ -39,18 +39,27 @@ struct OpenclMemory {
OpenCLMemoryType mem_type{MS_HOST_BUFFER | MS_CL_BUFFER};
};
enum class MEM_TYPE : char {
BUF, IMG
};
class OpenCLAllocator : public Allocator {
public:
OpenCLAllocator();
~OpenCLAllocator() override;
void SetContext(const AllocatorContext &ctx) override;
void *Malloc(size_t size) override;
void *Malloc(size_t size, const std::vector<size_t>& img_size);
void *CreateImageFromHost(void *host_ptr, size_t size, const std::vector<size_t>& img_size);
void Free(void *ptr) override;
size_t GetTotalSize() override;
void Clear() override;
void *GetDeviceBuffer(void *buffer);
void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true);
int UnmapBuffer(void *host_ptr, void *command_queue = nullptr);
MEM_TYPE GetMemType(void *host_ptr);
int GetImageSize(void *host_ptr, std::vector<size_t>* img_size);
private:
void Lock();
......@@ -59,6 +68,7 @@ class OpenCLAllocator : public Allocator {
size_t size_;
void *device_ptr_;
void *host_ptr_;
std::vector<size_t> img_size;
};
std::mutex lock;
......@@ -68,6 +78,7 @@ class OpenCLAllocator : public Allocator {
// 6 is empirical value
int shift_factor_ = 6;
bool lock_flag_ = false;
bool svm_on_{false};
};
} // namespace mindspore::lite::opencl
......
......@@ -15,9 +15,10 @@
*/
#include "src/runtime/opencl/opencl_executor.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/arm/opclib/pack.h"
#include "include/errorcode.h"
#include "src/common/ms_tensor_utils.h"
#include "include/errorcode.h"
namespace mindspore::lite::opencl {
int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Tensor *> &outputs,
......@@ -29,23 +30,32 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
MS_LOG(ERROR) << "Graph input tensor is nullptr";
return RET_ERROR;
}
if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4) {
if (inTensor->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "Model input should be NHWC, actual is " << schema::EnumNameFormat(inTensor->GetFormat());
return RET_ERROR;
} else {
TransformTensorLayout(inTensor, schema::Format_NHWC4);
// TransformTensorLayout(inTensor, schema::Format_NC4HW4);
}
if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4 &&
inTensor->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "input should be NHWC/NHWC4/NC4HW4, actual is " << schema::EnumNameFormat(inTensor->GetFormat());
return RET_ERROR;
} else {
TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NHWC4, true);
// TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NC4HW4, true);
}
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
OpenCLAllocator* op_allocator = reinterpret_cast<OpenCLAllocator*>(allocator);
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
kernel::OpenCLKernel *op_kernel = reinterpret_cast<kernel::OpenCLKernel*>(kernel);
auto &outputs = kernel->GetOutputs();
for (auto *output : outputs) {
for (auto i = 0; i < outputs.size(); ++i) {
auto *output = outputs.at(i);
MS_ASSERT(nullptr != output);
output->MallocData();
if (is_image2d_out_) {
std::vector<size_t> img_size;
op_kernel->GetImageSize(i, &img_size);
auto data_ptr = op_allocator->Malloc(output->Size(), img_size);
output->SetData(data_ptr);
} else {
output->MallocData(allocator);
}
}
session::CallBackParam callbackParam;
callbackParam.name_callback_param = kernel->Name();
......@@ -81,21 +91,22 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
return RET_ERROR;
}
if (outTensor->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "Model output tensor should be NHWC";
TransformTensorLayout(outTensor, outTensor->GetFormat(), schema::Format_NHWC, false);
}
}
return RET_OK;
}
int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format) {
int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format, bool trans_dir) {
MS_ASSERT(nullptr != tensor);
MS_ASSERT(4 == tensor->shape().size());
auto data_type = tensor->data_type();
switch (data_type) {
case kNumberTypeInt8:
return TransformTensorLayoutUint8(tensor, dst_format);
return TransformTensorLayoutUint8(tensor, src_format, dst_format, trans_dir);
case kNumberTypeFloat32:
return TransformTensorLayoutFp32(tensor, dst_format);
return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir);
default:
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format);
......@@ -104,21 +115,103 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format
return RET_OK;
}
int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format) {
int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format, bool trans_dir) {
MS_ASSERT(nullptr != tensor);
MS_ASSERT(nullptr != allocator_);
MS_ASSERT(4 == tensor->shape().size());
if (trans_dir) {
if (is_image2d_out_) {
return TransformTensorLayoutToImage(tensor, src_format, dst_format);
} else {
return TransformTensorLayoutToBuffer(tensor, src_format, dst_format);
}
} else {
if (is_image2d_out_) {
return TransformTensorLayoutFromImage(tensor, src_format, dst_format);
} else {
return TransformTensorLayoutToBuffer(tensor, src_format, dst_format);
}
}
}
int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format) {
if (dst_format == schema::Format_NHWC4) {
auto *src_data = tensor->Data();
auto *dst_data = allocator_->Malloc(tensor->Size());
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;
size_t C4 = UP_DIV(tensor->Channel(), C4NUM);
std::vector<size_t> img_size{tensor->Width() * C4, (size_t)tensor->Height(), CL_FLOAT};
if (src_format == schema::Format_NHWC) {
auto *dst_data = allocator_->Malloc(tensor->Size(), img_size);
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;
}
dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true));
PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel());
tensor->SetData(dst_data);
allocator_->Free(src_data);
allocator_->UnmapBuffer(dst_data);
}
dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true));
PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel());
tensor->SetData(dst_data);
tensor->SetFormat(dst_format);
return RET_OK;
} else if (dst_format == schema::Format_NHWC) {
// TODO(wandongdong): add support !!
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
}
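For reference, the NHWC to NHWC4 packing that this buffer path relies on amounts to copying each pixel's C channels and zero-filling up to the next multiple of four; a standalone sketch follows (not the opclib PackNHWCToNHWC4Fp32 source).

#include <cstring>
#include <vector>

// Standalone illustration of NHWC -> NHWC4 packing: the channel count is
// padded to a multiple of 4 and the tail slots stay zero.
std::vector<float> PackNhwcToNhwc4(const float *src, int batch, int plane, int channel) {
  const int c4 = (channel + 3) / 4 * 4;
  std::vector<float> dst(static_cast<size_t>(batch) * plane * c4, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int p = 0; p < plane; ++p) {
      const float *in = src + (static_cast<size_t>(b) * plane + p) * channel;
      float *out = dst.data() + (static_cast<size_t>(b) * plane + p) * c4;
      std::memcpy(out, in, channel * sizeof(float));
    }
  }
  return dst;
}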
int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format) {
if (dst_format == schema::Format_NHWC4) {
// convert to nhwc4
auto *src_data = tensor->Data();
auto *dst_data{src_data};
if (src_format == schema::Format_NHWC) {
dst_data = allocator_->Malloc(tensor->Size());
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;
}
dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true));
PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel());
tensor->SetData(dst_data);
allocator_->Free(src_data);
allocator_->UnmapBuffer(dst_data);
}
// copy to image2d
src_data = dst_data;
size_t C4 = UP_DIV(tensor->Channel(), C4NUM);
std::vector<size_t> img_size{tensor->Width() * C4, (size_t)tensor->Height(), CL_FLOAT};
dst_data = allocator_->CreateImageFromHost(src_data, tensor->Size(), img_size);
tensor->SetData(dst_data);
allocator_->Free(src_data);
tensor->SetFormat(schema::Format_NHWC4);
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}
}
int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format) {
if (dst_format == schema::Format_NHWC) {
auto src_data = tensor->Data();
auto dst_data = allocator_->Malloc(tensor->Size());
cl::Image2D *out_mem = reinterpret_cast<cl::Image2D *>(allocator_->GetDeviceBuffer(src_data));
std::vector<size_t> img_size;
allocator_->GetImageSize(src_data, &img_size);
auto origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(*out_mem, CL_TRUE, origin, region, 0, 0, dst_data);
tensor->SetData(dst_data);
allocator_->Free(src_data);
return RET_OK;
} else {
......@@ -128,7 +221,8 @@ int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Fo
}
}
int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format) {
int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format,
schema::Format dst_format, bool is_image) {
MS_ASSERT(nullptr != tensor);
MS_ASSERT(4 == tensor->shape().size());
// auto src_format = tensor->GetFormat();
......
......@@ -20,7 +20,7 @@
#include <vector>
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/allocator.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/executor.h"
#include "include/lite_session.h"
......@@ -38,15 +38,25 @@ class OpenCLExecutor : Executor {
const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr);
protected:
int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format);
int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format,
bool trans_dir = false);
int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format);
int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format,
bool trans_dir = false);
int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format);
int TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format,
bool trans_dir = false);
int TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format);
int TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format);
int TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format);
protected:
Context *context = nullptr;
OpenCLAllocator *allocator_;
bool is_image2d_out_{true};
};
} // namespace mindspore::lite::opencl
......
......@@ -124,8 +124,13 @@ int OpenCLRuntime::Init() {
const std::string device_name = device_->getInfo<CL_DEVICE_NAME>();
const std::string device_version = device_->getInfo<CL_DEVICE_VERSION>();
const std::string opencl_version = device_->getInfo<CL_DEVICE_OPENCL_C_VERSION>();
cl_uint align;
size_t ret;
clGetDeviceInfo((*device_)(), CL_DEVICE_IMAGE_PITCH_ALIGNMENT, sizeof(cl_uint), &align, &ret);
MS_LOG(INFO) << "Device name:\t" << device_name;
MS_LOG(INFO) << "Opencl version:\t" << device_version;
MS_LOG(INFO) << "Image alignment:\t" << align;
MS_LOG(INFO) << "Image ret:\t" << ret;
MS_LOG(INFO) << "Highest OpenCL c version:\t" << opencl_version;
MS_LOG(INFO) << "Max work item size:\t"
<< max_work_item_sizes_[0] << " : "
......@@ -133,7 +138,6 @@ int OpenCLRuntime::Init() {
<< max_work_item_sizes_[2];
gpu_info_ = ParseGpuInfo(device_name, device_version);
cl_int err;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
......@@ -164,6 +168,7 @@ int OpenCLRuntime::Init() {
support_fp16_ = CL_SUCCESS == success && fp_config > 0;
err = device_->getInfo(CL_DEVICE_SVM_CAPABILITIES, &svm_capabilities_);
svm_capabilities_ = 0;
if (err != CL_SUCCESS || svm_capabilities_ == 0) {
svm_capabilities_ = 0;
MS_LOG(INFO) << "SVM capalibilties: "
......@@ -535,7 +540,19 @@ int OpenCLRuntime::MapBuffer(void *host_ptr, int flags, size_t size, cl::Command
return command_queue->enqueueMapSVM(host_ptr, sync, flags, size);
}
int OpenCLRuntime::UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue) const {
void *OpenCLRuntime::MapBuffer(const cl::Image2D buffer, bool sync, int flags,
const std::vector<size_t>& region, cl::CommandQueue *command_queue) const {
if (command_queue == nullptr) {
command_queue = default_command_queue_.get();
}
cl::size_type row_pitch;
cl::size_type slice_pitch;
cl::array<cl::size_type, 3> origin_{0, 0, 0};
cl::array<cl::size_type, 3> region_{region[0], region[1], region[2]};
return command_queue->enqueueMapImage(buffer, sync, flags, origin_, region_, &row_pitch, &slice_pitch);
}
int OpenCLRuntime::UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue) const {
if (command_queue == nullptr) {
command_queue = default_command_queue_.get();
}
......
......@@ -75,9 +75,16 @@ class OpenCLRuntime {
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value;
return clSetKernelArgSVMPointer(kernel, index, value);
} else {
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetDeviceBuffer(value));
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value;
return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)());
MEM_TYPE mem_type = allocator_->GetMemType(value);
if (mem_type == MEM_TYPE::BUF) {
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetDeviceBuffer(value));
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value;
return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)());
} else {
cl::Image2D *buffer = reinterpret_cast<cl::Image2D *>(allocator_->GetDeviceBuffer(value));
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << value;
return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)());
}
}
}
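The kernel-side counterpart of this dispatch is a different parameter type in the OpenCL C source: a buffer argument is a __global pointer, while an image argument is an image2d_t accessed with read_imagef/write_imagef. The strings below are an illustrative pair of copy kernels, not this patch's depthwise kernels.

// Sketch only: host-side constants holding two illustrative OpenCL C kernels,
// one taking buffer arguments and one taking image2d_t arguments.
const char *kBufferCopyKernel = R"CLC(
__kernel void copy_buf(__global const float4 *src, __global float4 *dst) {
  int i = get_global_id(0);
  dst[i] = src[i];
}
)CLC";

const char *kImageCopyKernel = R"CLC(
__constant sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void copy_img(__read_only image2d_t src, __write_only image2d_t dst) {
  int2 xy = (int2)(get_global_id(0), get_global_id(1));
  write_imagef(dst, xy, read_imagef(src, smp, xy));
}
)CLC";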
......@@ -107,9 +114,11 @@ class OpenCLRuntime {
bool sync = false) const;
void *MapBuffer(const cl::Buffer buffer, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr,
bool sync = false) const;
void *MapBuffer(const cl::Image2D buffer, bool sync, int flags,
const std::vector<size_t>& region, cl::CommandQueue *command_queue = nullptr) const;
int MapBuffer(void *host_ptr, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr,
bool sync = false) const;
int UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const;
int UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const;
int UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue = nullptr) const;
bool SyncCommandQueue(cl::CommandQueue *command_queue = nullptr);
......
......@@ -35,6 +35,8 @@
a = nullptr; \
}
bool IMAGE2D_OPEN = true;
namespace mindspore {
class TestConvolutionDwOpenCL : public mindspore::Common {
public:
......@@ -95,6 +97,18 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *
std::vector<kernel::LiteKernel *> kernels{pKernel};
std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
size_t C4 = UP_DIV(inputs[0]->Channel(), C4NUM);
// if (IMAGE2D_OPEN && format == schema::Format_NHWC4) {
// std::vector<size_t> img_size{inputs[0]->Width() * C4, (size_t)inputs[0]->Height(), CL_FLOAT};
// auto in_data = allocator->Malloc(inputs[0]->Size(), img_size);
// inputs[0]->SetData(in_data);
// } else if (IMAGE2D_OPEN && format == schema::Format_NC4HW4) {
// std::vector<size_t> img_size{(size_t)inputs[0]->Width(), inputs[0]->Height() * C4, CL_FLOAT};
// auto in_data = allocator->Malloc(inputs[0]->Size(), img_size);
// inputs[0]->SetData(in_data);
// } else {
inputs[0]->MallocData(allocator);
// }
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
pGraph->Init();
......@@ -103,9 +117,9 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *
pGraph->Run();
if (is_compare) {
float* packed_output = reinterpret_cast<float *>(outputs[0]->Data());
float *packed_correct_data = new float[packed_output_size];
memset(packed_correct_data, 0, packed_output_size * sizeof(float));
float_t* packed_output = reinterpret_cast<float *>(outputs[0]->Data());
float_t *packed_correct_data = new float_t[packed_output_size];
memset(packed_correct_data, 0, packed_output_size * sizeof(float_t));
if (format == schema::Format_NC4HW4) {
PackNHWCToNC4HW4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_,
conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_);
......@@ -128,7 +142,7 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *
std::cout << std::endl;
printf("==================output data=================\n");
std::cout << std::endl;
for (int i = 0; i < packed_output_size; i++) {
for (int i = 0; i < 80/*packed_output_size*/; i++) {
std::cout << packed_output[i] << ", ";
}
std::cout << std::endl;
......@@ -142,13 +156,13 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *
SAFE_DELETE_ARRAY(packed_correct_data)
}
inputs[1]->SetData(nullptr);
inputs[2]->SetData(nullptr);
SAFE_DELETE_ARRAY(packed_input);
for (auto tensor : inputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
SAFE_DELETE_PTR(pKernel)
......@@ -477,6 +491,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
std::vector<kernel::LiteKernel *> kernels{pKernel};
std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
inputs[0]->MallocData();
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
pGraph->Init();
......@@ -516,12 +531,12 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
// compare
Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
inputs[1]->SetData(nullptr);
inputs[2]->SetData(nullptr);
for (auto tensor : inputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
SAFE_DELETE_PTR(pKernel)
......@@ -640,6 +655,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
std::vector<kernel::LiteKernel *> kernels{pKernel};
std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
inputs[0]->MallocData();
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
pGraph->Init();
......@@ -687,14 +703,14 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
// compare
Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
inputs[1]->SetData(nullptr);
inputs[2]->SetData(nullptr);
SAFE_DELETE_ARRAY(packed_input);
SAFE_DELETE_ARRAY(packed_correct_data)
for (auto tensor : inputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
SAFE_DELETE_PTR(tensor)
}
SAFE_DELETE_PTR(pKernel)
......@@ -742,35 +758,27 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
};
// nhwc
float_t *input_data = new float_t[96*112*112]{
0.5488135 , 0.3834415 , 0.77815676, 0.9446689 , 0.6120957 ,
0.71518934, 0.79172504, 0.87001216, 0.5218483 , 0.616934 ,
0.60276335, 0.5288949 , 0.9786183 , 0.41466194, 0.94374806,
0.5448832 , 0.56804454, 0.7991586 , 0.2645556 , 0.6818203 ,
0.4236548 , 0.92559665, 0.46147937, 0.7742337 , 0.3595079 ,
0.6458941 , 0.07103606, 0.7805292 , 0.45615032, 0.43703195,
0.4375872 , 0.0871293 , 0.11827443, 0.56843394, 0.6976312 ,
0.891773 , 0.0202184 , 0.639921 , 0.0187898 , 0.06022547,
0.96366274, 0.83261985, 0.14335328, 0.6176355 , 0.6667667 };
size_t in_size = 96*112*112;
float_t *input_data = new float_t[in_size];
memset(input_data, 0, in_size * sizeof(float_t));
for (size_t i = 0; i < in_size; ++i) {
input_data[i] = 1;
}
// co h w ci
float_t *weight_data = new float_t[576*3*3]{
0.67063785, 0.21038257, 0.12892629,
0.31542835, 0.36371076, 0.57019675,
0.43860152, 0.9883738 , 0.10204481,
0.20887676, 0.16130951, 0.6531083 ,
0.2532916 , 0.46631077, 0.2444256 ,
0.15896958, 0.11037514, 0.6563296 ,
0.13818295, 0.19658236, 0.36872518,
0.82099324, 0.09710128, 0.8379449 ,
0.09609841, 0.97645944, 0.4686512 ,
0.9767611 , 0.6048455 , 0.7392636 ,
0.03918779, 0.28280696, 0.12019656,
0.2961402 , 0.11872772, 0.31798318,
0.41426298, 0.06414749, 0.6924721 ,
0.56660146, 0.2653895 , 0.5232481 ,
0.09394051, 0.5759465 , 0.9292962 };
size_t wt_size = 576*3*3;
float_t *weight_data = new float_t[wt_size];
memset(weight_data, 0, wt_size * sizeof(float_t));
for (size_t i = 0; i < wt_size; ++i) {
weight_data[i] = 1;
}
size_t out_size = 96*112*112;
float_t *gnd_data = new float_t[out_size];
memset(gnd_data, 0, out_size * sizeof(float_t));
// for (auto i = 0; i < in_size; ++i) {
// gnd_data[i] = 1;
// }
for (size_t i = 0; i < src_shape.size(); ++i) {
const int MAX_RUN_TIMES = 10;
const int MAX_RUN_TIMES = 1;
for (int j = 0; j < MAX_RUN_TIMES; ++j) {
printf("========profiling depthwise, in shape(%d,%d,%d,%d), out shape(%d,%d,%d,%d), iter%d========\n",
src_shape[i][0], src_shape[i][1], src_shape[i][2], src_shape[i][3],
......@@ -794,8 +802,8 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
}
DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NC4HW4, false);
// DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false);
// DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, false);
DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false);
}
}
SAFE_DELETE_ARRAY(input_data);
......@@ -803,4 +811,54 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
lite::opencl::OpenCLRuntime::DeleteInstance();
}
TEST_F(TestConvolutionDwOpenCL, Buffer2Image) {
std::vector<int> src_shape{1, 96, 64, 64};
std::vector<int> dst_shape{1, 96, 32, 32};
std::vector<int> filter_shape{96, 3, 3, 1};
// nhwc
size_t in_size = 96*112*112;
float_t *input_data = new float_t[in_size];
memset(input_data, 0, in_size * sizeof(float_t));
for (size_t i = 0; i < in_size; ++i) {
input_data[i] = 1;
}
// co h w ci
size_t wt_size = 576*3*3;
float_t *weight_data = new float_t[wt_size];
memset(weight_data, 0, wt_size * sizeof(float_t));
for (size_t i = 0; i < wt_size; ++i) {
weight_data[i] = 1;
}
size_t out_size = 96*112*112;
float_t *gnd_data = new float_t[out_size];
memset(gnd_data, 0, out_size * sizeof(float_t));
// for (auto i = 0; i < in_size; ++i) {
// gnd_data[i] = 1;
// }
ConvParameter *conv_param = new ConvParameter();
{
conv_param->input_batch_ = 1;
conv_param->input_h_ = src_shape[2];
conv_param->input_w_ = src_shape[3];
conv_param->input_channel_ = src_shape[1];
conv_param->output_batch_ = 1;
conv_param->output_h_ = dst_shape[2];
conv_param->output_w_ = dst_shape[3];
conv_param->output_channel_ = dst_shape[1];
conv_param->kernel_h_ = filter_shape[1];
conv_param->kernel_w_ = filter_shape[2];
conv_param->stride_h_ = conv_param->input_h_/conv_param->output_h_;
conv_param->stride_w_ = conv_param->input_w_/conv_param->output_w_;
conv_param->pad_h_ = (conv_param->kernel_h_-1)/2;
conv_param->pad_w_ = (conv_param->kernel_w_-1)/2;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
}
// DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, true);
DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4, true);
SAFE_DELETE_ARRAY(input_data);
SAFE_DELETE_ARRAY(weight_data);
lite::opencl::OpenCLRuntime::DeleteInstance();
}
} // namespace mindspore