提交 46a2d163 编写于 作者: L liuzhongkai

add biasadd in opencl

上级 6eb98f28
#pragma OPENCL EXTENSION cl_arm_printf : enable
#define SLICES 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define FLT4 float4
#define READ_FLT4 read_imagef
#define WRITE_FLT4 write_imagef
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
__global float *alpha, const int dim) {
int C = input_shape.w; // channel size
int Y = get_global_id(0); // height id
int X = get_global_id(1); // weight id
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
FLT4 tmp;
int index = 0;
if (dim == 2) {
index = X * 4;
} else {
index = num * 4;
}
tmp.x = in_c4.x + alpha[index];
tmp.y = in_c4.y + alpha[index + 1];
tmp.z = in_c4.z + alpha[index + 2];
tmp.w = in_c4.w + alpha[index + 3];
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
}
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <map>
#include <set>
#include <vector>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/kernel/biasadd.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/cl/biasadd.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BiasAdd;
namespace mindspore::kernel {
void BiasAddOpenCLKernel::InitBuffer() {
int C = in_tensors_[1]->shape()[0];
int div_ci = UP_DIV(C, C4NUM);
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
BiasAdd_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t)));
BiasAdd_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true));
memset(BiasAdd_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t));
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_[1]->Data());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
BiasAdd_[i] = origin_weight[i];
}
allocator->UnmapBuffer(BiasAdd_);
}
int BiasAddOpenCLKernel::Init() {
in_size_ = in_tensors_[0]->shape().size();
out_size_ = out_tensors_[0]->shape().size();
if (in_size_ != 4 && in_size_ != 2) {
MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_;
return RET_ERROR;
}
int C = in_tensors_[0]->shape()[3];
int Bias_Size = in_tensors_[1]->shape()[0];
if (UP_DIV(Bias_Size, C4NUM) != UP_DIV(C, C4NUM)) {
MS_LOG(ERROR) << "BiasAdd weight channel size:" << Bias_Size << " must be equal with in_teneors channel size:" << C;
return RET_ERROR;
}
InitBuffer();
std::set<std::string> build_options;
std::string source = biasadd_source;
std::string program_name = "BiasAdd";
std::string kernel_name = "BiasAdd";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
std::map<int, schema::Format> format{{4, schema::Format_NHWC4}, {2, schema::Format_NC4}};
if (format.count(out_size_) == 0) {
MS_LOG(ERROR) << "Not found output tensor format";
return RET_ERROR;
}
in_tensors_[0]->SetFormat(format[in_size_]);
out_tensors_[0]->SetFormat(format[out_size_]);
if (in_size_ == 2) {
in_ori_format_ = format[in_size_];
out_ori_format_ = format[out_size_];
}
MS_LOG(DEBUG) << program_name << " Init Done!";
return RET_OK;
}
int BiasAddOpenCLKernel::Run() {
cl_int4 input_shape = GetImg2dShape();
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
int arg_idx = 0;
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, BiasAdd_);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_size_);
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {static_cast<size_t>(input_shape.s[1]), static_cast<size_t>(input_shape.s[2])};
auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
return RET_ERROR;
}
return RET_OK;
}
cl_int4 BiasAddOpenCLKernel::GetImg2dShape() {
cl_int4 img2d_shape = {0, 0, 0, 0};
for (int i = 0; i < in_size_; ++i) {
img2d_shape.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i];
}
if (in_size_ == 2) {
img2d_shape.s[1] = img2d_shape.s[2];
img2d_shape.s[2] = UP_DIV(img2d_shape.s[3], C4NUM);
img2d_shape.s[3] = C4NUM;
}
return img2d_shape;
}
int BiasAddOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
cl_int4 img_shape = GetImg2dShape();
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
img_size->push_back(img_shape.s[2] * UP_DIV(img_shape.s[3], C4NUM));
img_size->push_back(img_shape.s[1]);
img_size->push_back(img_dtype);
return RET_OK;
}
kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) {
if (inputs.size() == 0) {
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
return nullptr;
}
if (inputs[0]->shape()[0] > 1) {
MS_LOG(ERROR) << "Input data size unsupported multi-batch.";
return nullptr;
}
auto *kernel = new (std::nothrow) BiasAddOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << "is nullptr.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init BiasAdd kernel failed!";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_
#include <vector>
#include <string>
#include "src/ir/tensor.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "schema/model_generated.h"
#include "src/runtime/opencl/opencl_runtime.h"
namespace mindspore::kernel {
class BiasAddOpenCLKernel : public OpenCLKernel {
public:
explicit BiasAddOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~BiasAddOpenCLKernel() override{};
int Init() override;
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
void InitBuffer();
cl_int4 GetImg2dShape();
private:
cl::Kernel kernel_;
FLOAT_t *BiasAdd_;
int in_size_;
int out_size_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_
......@@ -158,6 +158,7 @@ if (SUPPORT_GPU)
${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/biasadd.cc
)
endif()
### minddata lite
......@@ -338,6 +339,7 @@ if (SUPPORT_GPU)
${TEST_DIR}/ut/src/runtime/kernel/opencl/caffe_prelu_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/prelu_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/reshape_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/biasadd_tests.cc
)
endif()
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h"
using mindspore::kernel::BiasAddOpenCLKernel;
using mindspore::kernel::LiteKernel;
using mindspore::kernel::SubGraphOpenCLKernel;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore {
class TestBiasAddOpenCL : public mindspore::CommonTest {};
void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) {
memset(dst, 0x00, dst_size);
} else {
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
memcpy(dst, src_data, dst_size);
}
}
void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
size_t output_size = output_tensor->ElementsNum();
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
return;
}
}
printf("compare success!\n");
printf("compare success!\n");
printf("compare success!\n\n\n");
}
void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) {
auto input_data = reinterpret_cast<float *>(in_data->Data());
for (int i = 0; i < size; ++i) {
printf("%f ", input_data[i]);
}
printf("\n");
MS_LOG(INFO) << "Print tensor done";
}
void printf_float_BiasAdd(float *data, int num = 0) {
float *temp = data;
for (int i = 0; i < num; ++i) {
std::cout << *temp << " ";
temp++;
}
std::cout << std::endl;
}
TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) {
std::string in_file = "/data/local/tmp/in_data.bin";
std::string weight_file = "/data/local/tmp/weight_data.bin";
std::string standard_answer_file = "/data/local/tmp/biasadd.bin";
MS_LOG(INFO) << "BiasAdd Begin test:";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "BiasAdd init tensors.";
std::vector<int> input_shape = {1, 9};
std::vector<int> output_shape = {1, 9};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input tensor error!";
return;
}
auto *output_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NC, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(ERROR) << "new output tensor error!";
delete input_tensor;
return;
}
auto *weight_tensor = new (std::nothrow)
lite::tensor::Tensor(data_type, std::vector<int>{input_shape[1]}, schema::Format_NHWC, tensor_type);
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "new weight tensor error!";
delete output_tensor;
delete input_tensor;
return;
}
std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
inputs[0]->MallocData(allocator);
inputs[1]->MallocData(allocator);
LoadDataBiasAdd(input_tensor->Data(), input_tensor->Size(), in_file);
MS_LOG(INFO) << "BiasAdd==================input data================";
printf_tensor_BiasAdd(inputs[0], input_tensor->ElementsNum());
LoadDataBiasAdd(weight_tensor->Data(), weight_tensor->Size(), weight_file);
MS_LOG(INFO) << "BiasAdd==================weight data================";
printf_tensor_BiasAdd(inputs[1], weight_tensor->ElementsNum());
auto *param = new (std::nothrow) OpParameter();
if (param == nullptr) {
delete input_tensor;
delete output_tensor;
delete weight_tensor;
MS_LOG(ERROR) << "new OpParameter error!";
return;
}
auto *biasadd_kernel =
new (std::nothrow) kernel::BiasAddOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (biasadd_kernel == nullptr) {
MS_LOG(ERROR) << "Create biasadd kernel error.";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
return;
}
auto ret = biasadd_kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "biasadd kernel init error.";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete biasadd_kernel;
return;
}
MS_LOG(INFO) << "initialize sub_graph";
std::vector<kernel::LiteKernel *> kernels{biasadd_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({input_tensor}, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(ERROR) << "Create sub_graph kernel error.";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete biasadd_kernel;
return;
}
ret = sub_graph->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "sub_graph init error.";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete sub_graph;
delete param;
delete biasadd_kernel;
return;
}
MS_LOG(INFO) << "Sub graph begin running!";
ret = sub_graph->Run();
if (ret != RET_OK) {
MS_LOG(ERROR) << "sub_graph run error.";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete sub_graph;
delete param;
delete biasadd_kernel;
return;
}
MS_LOG(INFO) << "BiasAdd==================output data================";
printf_tensor_BiasAdd(outputs[0], output_tensor->ElementsNum());
CompareOutBiasAdd(output_tensor, standard_answer_file);
delete input_tensor;
delete weight_tensor;
delete output_tensor;
delete sub_graph;
delete param;
delete biasadd_kernel;
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册