提交 cf4533da 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3810 update the performance of concat ops

Merge pull request !3810 from pengyongrong/master
//#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0,
const int4 input_shape1, const int4 output_shape, const int axis) {
int postion = 0, index_input_shape0 = 0, index_input_shape1 = 0;
switch (axis) {
case 1:
for (int i = 0; i < output_shape.x; i++) {
for (int j = 0; j < output_shape.y; j++) {
for (int k = 0; k < output_shape.z; k++) {
for (int w = 0; w < output_shape.w; w++) {
postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w +
k * output_shape.w + w;
if (j < input_shape0.y) {
output[postion] = input0[index_input_shape0++];
} else {
output[postion] = input1[index_input_shape1++];
}
}
}
}
}
break;
case 2:
for (int i = 0; i < output_shape.x; i++) {
for (int j = 0; j < output_shape.y; j++) {
for (int k = 0; k < output_shape.z; k++) {
for (int w = 0; w < output_shape.w; w++) {
postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w +
k * output_shape.w + w;
if (k < input_shape0.z) {
output[postion] = input0[index_input_shape0++];
} else {
output[postion] = input1[index_input_shape1++];
}
}
}
}
}
break;
case 3:
for (int i = 0; i < output_shape.x; i++) {
for (int j = 0; j < output_shape.y; j++) {
for (int k = 0; k < output_shape.z; k++) {
for (int w = 0; w < output_shape.w; w++) {
postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w +
k * output_shape.w + w;
if (w < input_shape0.w) {
output[postion] = input0[index_input_shape0++];
} else {
output[postion] = input1[index_input_shape1++];
}
}
}
}
}
break;
default:
break;
uint oh = get_global_id(0);
uint ow = get_global_id(1);
uint oc = get_global_id(2);
uint index_output;
uint input_idx;
if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) {
return;
}
if (axis == 3) {
index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc;
if (oc < input_shape0.w) {
input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc;
output[index_output] = input0[input_idx];
} else if ((input_shape0.w <= oc) && oc < (input_shape0.w + input_shape1.w)) {
input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w);
output[index_output] = input1[input_idx];
} else {
output[index_output] = 0;
}
}
}
__kernel void Concat3input(__global float *input0, __global float *input1, __global float *input2,
__global float *output, const int4 input_shape0, const int4 input_shape1,
const int4 input_shape2, const int4 output_shape, const int axis) {
uint oh = get_global_id(0);
uint ow = get_global_id(1);
uint oc = get_global_id(2);
uint index_output;
uint input_idx;
if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) {
return;
}
index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc;
if (oc < (input_shape0.w + input_shape1.w)) {
if (oc < input_shape0.w) {
input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc;
output[index_output] = input0[input_idx];
} else {
input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w);
output[index_output] = input1[input_idx];
}
} else {
if ((input_shape0.w + input_shape1.w + input_shape2.w) <= oc) {
output[index_output] = 0;
} else {
input_idx = (input_shape2.z * oh + ow) * input_shape2.w + (oc - input_shape0.w - input_shape1.w);
output[index_output] = input2[input_idx];
}
}
}
\ No newline at end of file
......@@ -13,15 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/concat.h"
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/opclib/concat_parameter.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc"
#endif
#include "src/runtime/kernel/opencl/kernel/concat.h"
#include "src/backend/opencl/cl/fp32/concat.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
......@@ -35,7 +33,7 @@ int ConcatOpenCLKernel::Init() {
}
auto param = reinterpret_cast<ConcatParameter *>(this->opParameter);
MS_LOG(DEBUG) << "concat at axis=: " << param->axis_;
MS_LOG(INFO) << "concat at axis=: " << param->axis_;
if (param->axis_ != 0 && param->axis_ != 3) {
MS_LOG(ERROR) << "only support axis=0 or axis=3";
}
......@@ -43,20 +41,26 @@ int ConcatOpenCLKernel::Init() {
if (param->axis_ == 0) {
return 0;
}
if (inputs_.size() == 2) {
std::set<std::string> build_options;
std::string source = concat_source_fp32;
std::string program_name = "Concat";
std::string kernel_name = "Concat";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
}
if (inputs_.size() == 3) {
std::set<std::string> build_options;
std::string source = concat_source_fp32;
std::string program_name = "Concat3input";
std::string kernel_name = "Concat3input";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
}
std::string kernel_name = "Concat";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
#ifdef PROGRAM_WITH_IL
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
#else
std::set<std::string> build_options;
std::string source = concat_source_fp32;
std::string program_name = "Concat";
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
outputs_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
}
......@@ -88,42 +92,113 @@ int ConcatOpenCLKernel::Run_axis0() {
}
return 0;
}
int DivideRoundUp(int n, int div) {
int q = n / div;
return n % div == 0 ? q : q + 1;
}
int GetBiggestDividerWithPriority(int number, int max_divider) {
if (number % 8 == 0 && 8 <= max_divider) {
return number / 8;
}
if (number % 4 == 0 && 4 <= max_divider) {
return number / 4;
}
if (number % 2 == 0 && 2 <= max_divider) {
return number / 2;
}
for (int i = max_divider; i != 0; i--) {
if (number % i == 0) {
return i;
}
}
return 1;
}
void ConcatGetWorkGroup(const std::vector<size_t> &global, const std::vector<size_t> &local, int max_size) {
int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4);
int yz = max_size / x;
int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8);
int z = std::min(yz / y, DivideRoundUp(global[2], 2));
local = {static_cast<unsigned int>(x), static_cast<unsigned int>(y), static_cast<unsigned int>(z)};
}
int ConcatOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->Name() << " Running!";
auto param = reinterpret_cast<ConcatParameter *>(this->opParameter);
if (param->axis_ == 0) {
return Run_axis0();
}
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {1, 1, 1};
auto input0_shape = inputs_[0]->shape();
auto input1_shape = inputs_[1]->shape();
auto output_shape = outputs_[0]->shape();
cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]};
cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]};
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]};
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
std::vector<size_t> local;
std::vector<size_t> global;
if (inputs_.size() == 2) {
auto input0_shape = inputs_[0]->shape();
auto input1_shape = inputs_[1]->shape();
auto output_shape = outputs_[0]->shape();
cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]};
cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]};
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]};
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
uint32_t OW = output_shape[2];
uint32_t OC = output_shape[3];
global = {OH, OW, OC}; // HWC
ConcatGetWorkGroup(global, local, 384);
std::cout << "local size=:" << std::endl;
for (int i = 0; i < local.size(); i++) {
std::cout << local[i] << " ";
}
std::cout << std::endl;
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
}
if (inputs_.size() == 3) {
auto input0_shape = inputs_[0]->shape();
auto input1_shape = inputs_[1]->shape();
auto input2_shape = inputs_[2]->shape();
auto output_shape = outputs_[0]->shape();
cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]};
cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]};
cl_int4 input2_shape_ = {input2_shape[0], input2_shape[1], input2_shape[2], input2_shape[3]};
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]};
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
uint32_t OW = output_shape[2];
uint32_t OC = output_shape[3];
global = {OH, OW, OC}; // HWC
ConcatGetWorkGroup(global, local, 384);
std::cout << "local size=:" << std::endl;
for (int i = 0; i < local.size(); i++) {
std::cout << local[i] << " ";
}
std::cout << std::endl;
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[2]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input2_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
}
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
return 0;
}
kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<tensor::Tensor *> &inputs,
const std::vector<tensor::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc) {
auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs);
auto ret = kernel->Init();
if (0 != ret) {
......@@ -136,4 +211,3 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Te
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Concat, OpenCLConcatKernelCreator);
} // namespace mindspore::kernel
......@@ -17,30 +17,27 @@
#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_
#include <memory.h>
#include <iostream>
#include <vector>
#include "ir/anf.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/concat_parameter.h"
#include "src/backend/arm/opclib/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/opclib/fp32/concat.h"
#include "src/runtime/kernel/arm/opclib/int8/concat_int8.h"
#include "src/backend/arm/opclib/concat.h"
namespace mindspore::kernel {
class ConcatOpenCLKernel : public LiteKernel {
public:
explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<tensor::Tensor *> &inputs,
const std::vector<tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs) {}
~ConcatOpenCLKernel() override{};
int Init() override;
int InferShape() { return {}; }
// int InferShape() { return {}; };
int InferShape() {}
int ReSize() override;
int Run_axis0();
......@@ -52,6 +49,4 @@ class ConcatOpenCLKernel : public LiteKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_
#endif
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/backend/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/backend/opencl/kernel/concat.h"
using mindspore::kernel;
using mindspore::lite;
using mindspore;
int DivideRoundUp(int n, int div) {
int q = n / div;
return n % div == 0 ? q : q + 1;
}
void printfNode(float *result, const std::vector<int> &tempNode) {
for (int i = 0; i < tempNode[0]; i++) {
for (int j = 0; j < tempNode[1]; j++) {
for (int k = 0; k < tempNode[2]; k++) {
for (int w = 0; w < tempNode[3]; w++) {
std::cout
<< result[i * tempNode[2] * tempNode[1] * tempNode[3] + j * tempNode[2] * tempNode[3] + k * tempNode[3] + w]
<< " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
std::cout << std::endl;
}
std::cout << std::endl;
}
void ConcatComputeByCPU_2input_dim4_axis3(float *input0, float *input1, float *output, std::vector<int> input_shape0,
std::vector<int> input_shape1, std::vector<int> output_shape,
const int axis) {
int postion, index0 = 0, index1 = 0;
for (int i = 0; i < output_shape[0]; i++) {
for (int j = 0; j < output_shape[1]; j++) {
for (int k = 0; k < output_shape[2]; k++) {
postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] +
k * output_shape[3];
for (int w = 0; w < output_shape[3]; w++) {
if (w < input_shape0[3] + input_shape1[3]) {
output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++];
} else {
for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) {
output[postion++] = 0;
}
}
}
}
}
}
}
void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *input2, float *output,
std::vector<int> input_shape0, std::vector<int> input_shape1,
std::vector<int> input_shape2, std::vector<int> output_shape,
const int axis) {
int postion, index0 = 0, index1 = 0, index2 = 0;
for (int i = 0; i < output_shape[0]; i++) {
for (int j = 0; j < output_shape[1]; j++) {
for (int k = 0; k < output_shape[2]; k++) {
postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] +
k * output_shape[3];
for (int w = 0; w < output_shape[3]; w++) {
if (w < input_shape0[3] + input_shape1[3]) {
output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++];
} else if ((input_shape0[3] + input_shape1[3]) <= w &&
w < (input_shape0[3] + input_shape1[3] + input_shape2[3])) {
output[postion++] = input2[index2++];
} else {
for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) {
output[postion++] = 0;
}
}
}
}
}
}
}
namespace mindspore {
class TestConcatOpenCL : public UT::Common {
public:
TestConcatOpenCL(){}
};
TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) {
MS_LOG(INFO) << "begin test";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
MS_LOG(INFO) << "init tensors";
constexpr int INPUT_NUM = 3;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
std::vector<int>{1, 240, 240, 16}, std::vector<int>{1, 240, 240, 16}, std::vector<int>{1, 240, 240, 64}};
std::vector<int> output_shape = {1, 240, 240, 96};
output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4;
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
std::vector<tensor::Tensor *> inputs;
for (auto &shape : input_shapes) {
inputs.push_back(new tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type));
}
auto *output_tensor = new tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
std::vector<tensor::Tensor *> outputs{output_tensor};
std::cout << "input_shapes size=: " << input_shapes.size() << std::endl;
MS_LOG(INFO) << "initialize tensors";
auto param = new ConcatParameter();
param->axis_ = 3;
auto *concat_kernel = new ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
concat_kernel->Init();
MS_LOG(INFO) << "initialize sub_graph";
std::vector<LiteKernel *> kernels{concat_kernel};
auto *sub_graph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
sub_graph->Init();
MS_LOG(INFO) << "initialize input data";
srand(time(NULL));
for (auto &input_tensor : inputs) {
auto input_data = reinterpret_cast<float *>(input_tensor->Data());
for (int i = 0; i < input_tensor->ElementsNum(); ++i) {
input_data[i] = static_cast<float>(rand_r() % 10 + 1);
}
printf("\n");
}
MS_LOG(INFO) << "==================output data================";
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
printf("\n");
auto *input_data0 = reinterpret_cast<float *>(inputs[0]->Data());
auto *input_data1 = reinterpret_cast<float *>(inputs[1]->Data());
std::vector<float> output_data_cpu(output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]);
if (inputs.size() == 2) {
ConcatComputeByCPU_2input_dim4_axis3(input_data0, input_data1, output_data_cpu.data(), input_shapes[0],
input_shapes[1], output_shape, param->axis_);
}
if (inputs.size() == 3) {
auto *input_data2 = reinterpret_cast<float *>(inputs[2]->Data());
ConcatComputeByCPU_3input_dim4_axis3(input_data0, input_data1, input_data2, output_data_cpu.data(), input_shapes[0],
input_shapes[1], input_shapes[2], output_shape, param->axis_);
}
printf("\n");
CompareOutputData(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001);
MS_LOG(INFO) << "Testconcat passed";
}
} // namespace mindspore
set(ANF_SRC
${ANF_SRC}
# core/abstract
#core / abstract
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_function.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/analysis_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/param_validator.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_value.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/dshape.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/utils.cc
# core/base
#core / base
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/base/base_ref.cc
# core/ir
#core / ir
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf_extends.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/meta_func_graph.cc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册