提交 722bcc23 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5529 add biasadd support in arithmetic for opencl

Merge pull request !5529 from liuchao/master
......@@ -30,6 +30,14 @@ using mindspore::lite::KernelRegistrar;
namespace mindspore::kernel {
ArithmeticOpenCLKernel::~ArithmeticOpenCLKernel() {
if (weight_ptr_ != nullptr) {
auto allocator = runtime_->GetAllocator();
allocator->Free(weight_ptr_);
weight_ptr_ = nullptr;
}
}
std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
const size_t global_x = out_tensors_[0]->Width();
const size_t global_y = out_tensors_[0]->Height();
......@@ -39,10 +47,18 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
}
void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() {
size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
size_t W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
local_size_ = {16, 16};
global_size_ = {W, H};
if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
size_t W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
global_size_ = {W, H};
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
size_t H = out_tensors_[0]->Batch();
size_t W = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
global_size_ = {W, H};
} else {
MS_LOG(ERROR) << "Unspport data format " << out_tensors_[0]->GetFormat();
}
}
void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
......@@ -51,16 +67,16 @@ void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
}
int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
int H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
int W = out_tensors_[0]->Width() * CO4;
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
im_dst_x = W;
im_dst_y = H;
if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
im_dst_y = out_tensors_[0]->Batch();
im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
} else {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
MS_LOG(ERROR) << "Unspport data format " << out_tensors_[0]->GetFormat();
return RET_ERROR;
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
......@@ -73,11 +89,26 @@ int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_si
return RET_OK;
}
int ArithmeticOpenCLKernel::InitBuffer() {
const ArithmeticParameter *arithmetic_parameter = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (!arithmetic_parameter->broadcasting_) {
if (in_tensors_[1]->TensorType() == schema::NodeType_ValueNode && in_tensors_[1]->Data() != nullptr) {
auto allocatdor = runtime_->GetAllocator();
std::vector<size_t> img_size;
GetImageSize(0, &img_size);
weight_ptr_ = allocatdor->CreateImageFromHost(in_tensors_[1]->Data(), in_tensors_[1]->ElementsNum(), img_size);
return RET_OK;
}
}
return RET_OK;
}
int ArithmeticOpenCLKernel::Init() {
runtime_ = lite::opencl::OpenCLRuntime::GetInstance();
std::string kernel_name;
if (in_tensors_[1]->TensorType() == schema::NodeType_ValueNode && in_tensors_[1]->Data() != nullptr) {
const ArithmeticParameter *arithmetic_parameter = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (arithmetic_parameter->broadcasting_ && in_tensors_[1]->TensorType() == schema::NodeType_ValueNode &&
in_tensors_[1]->Data() != nullptr) {
element_flag_ = false;
kernel_name = "BoardcastArith";
} else {
......@@ -103,7 +134,7 @@ int ArithmeticOpenCLKernel::Init() {
lite::STATUS error_code = RET_OK;
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name);
kernel_ = runtime_->GetKernelFromBinary(kernel_name);
#else
if (out_mem_type_ == OpenCLMemType::IMG) {
kernel_name += "_IMG";
......@@ -119,23 +150,31 @@ int ArithmeticOpenCLKernel::Init() {
if (error_code != RET_OK) {
return error_code;
}
auto format = schema::Format_NHWC4;
if (arithmetic_parameter->ndim_ == 2) {
format = schema::Format_NC4;
}
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
in_tensors_[0]->SetFormat(format);
if (element_flag_ && in_tensors_[1]->TensorType() != schema::NodeType_ValueNode) {
in_tensors_[1]->SetFormat(format);
}
out_tensors_[0]->SetFormat(format);
Image2dGetWorkGroupSize();
InitBuffer();
return RET_OK;
}
int ArithmeticOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
int arg_idx = 0;
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
if (element_flag_) {
runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[1]->Data());
void *weight = weight_ptr_ == nullptr ? in_tensors_[1]->Data() : weight_ptr_;
runtime_->SetKernelArg(kernel_, arg_idx++, weight);
} else {
float value = static_cast<float *>(in_tensors_[1]->Data())[0];
switch (op_parameter_->type_) {
......@@ -155,23 +194,47 @@ int ArithmeticOpenCLKernel::Run() {
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
break;
}
ocl_runtime->SetKernelArg(kernel_, arg_idx++, weight_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, bias_);
runtime_->SetKernelArg(kernel_, arg_idx++, weight_);
runtime_->SetKernelArg(kernel_, arg_idx++, bias_);
}
runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
int H = 0;
int W = 0;
if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
H = out_tensors_[0]->Batch();
W = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
} else {
MS_LOG(ERROR) << "Error output type " << out_tensors_[0]->GetFormat();
return RET_ERROR;
}
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
int H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
int W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
cl_int2 output_shape{W, H};
ocl_runtime->SetKernelArg(kernel_, arg_idx++, output_shape);
ocl_runtime->RunKernel(kernel_, global_size_, local_size_, nullptr);
runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
return RET_OK;
}
kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive);
kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
const ArithmeticParameter *arithmetic_parameter = reinterpret_cast<const ArithmeticParameter *>(opParameter);
if (arithmetic_parameter->broadcasting_) {
for (size_t i = 0; i < arithmetic_parameter->ndim_; i++) {
if (arithmetic_parameter->in_shape1_[i] != 0 && arithmetic_parameter->in_shape1_[i] != 1) {
return OpenCLBiasAddKernelCreator(inputs, outputs, opParameter, ctx, desc, primitive);
}
}
}
auto *kernel =
new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
if (kernel == nullptr) {
......
......@@ -29,7 +29,7 @@ class ArithmeticOpenCLKernel : public OpenCLKernel {
explicit ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: OpenCLKernel(parameter, inputs, outputs) {}
~ArithmeticOpenCLKernel() override{};
~ArithmeticOpenCLKernel() override;
int Init() override;
int Run() override;
......@@ -39,12 +39,14 @@ class ArithmeticOpenCLKernel : public OpenCLKernel {
std::vector<size_t> InitGlobalSize() const;
void Image2dGetWorkGroupSize();
void BufferGetWorkGroupSize();
int InitBuffer();
cl::Kernel kernel_;
lite::opencl::OpenCLRuntime *runtime_;
bool element_flag_{true};
float weight_{1.f};
float bias_{.0f};
void *weight_ptr_{nullptr};
std::vector<size_t> local_size_;
std::vector<size_t> global_size_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册