Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8c993552
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c993552
编写于
8月 15, 2020
作者:
C
Corleone
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code for opencl
上级
7e47cdc4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
76 additions
and
15 deletions
+76
-15
mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl
...te/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl
+1
-1
mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
...spore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
+12
-5
mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
...ite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
+63
-9
未找到文件。
mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl
浏览文件 @
8c993552
...
...
@@ -23,7 +23,7 @@ __kernel void ElementDiv(__global float *input_a, __global float *input_b, __glo
const
unsigned
int
n
)
{
int
idx
=
get_global_id
(
0
)
;
if
(
idx
>=
n
)
return
;
output[idx]
=
input_a[idx]
*
input_b[idx]
;
output[idx]
=
input_a[idx]
/
input_b[idx]
;
}
__kernel
void
BoardcastArith
(
__global
float
*input_a,
float
weight,
float
bias,
__global
float
*output,
...
...
mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
浏览文件 @
8c993552
...
...
@@ -102,19 +102,26 @@ int ArithmeticOpenCLKernel::Init() {
}
}
lite
::
STATUS
error_code
=
RET_OK
;
#ifdef PROGRAM_WITH_IL
runtime_
->
CreateKernelFromIL
(
kernel_
(),
kernel_name
);
bool
ret
=
runtime_
->
CreateKernelFromIL
(
kernel_
(),
kernel_name
);
if
(
!
ret
)
{
error_code
=
RET_ERROR
;
}
#else
std
::
string
program_name
=
"Arithmetic"
;
std
::
set
<
std
::
string
>
build_options
;
std
::
string
source
=
arithmetic_image2d_source_fp32
;
runtime_
->
LoadSource
(
program_name
,
source
);
runtime_
->
BuildKernel
(
kernel_
,
program_name
,
kernel_name
,
build_options
);
error_code
=
runtime_
->
BuildKernel
(
kernel_
,
program_name
,
kernel_name
,
build_options
);
#endif
if
(
error_code
!=
RET_OK
)
{
return
error_code
;
}
ori_format_
=
out_tensors_
[
0
]
->
GetFormat
();
out_tensors_
[
0
]
->
SetFormat
(
schema
::
Format_NHWC4
);
Image2dGetWorkGroupSize
();
return
0
;
return
RET_OK
;
}
int
ArithmeticOpenCLKernel
::
Run
()
{
...
...
@@ -155,7 +162,7 @@ int ArithmeticOpenCLKernel::Run() {
cl_int2
output_shape
{
W
,
H
};
ocl_runtime
->
SetKernelArg
(
kernel_
,
arg_idx
++
,
output_shape
);
ocl_runtime
->
RunKernel
(
kernel_
,
global_size_
,
local_size_
,
nullptr
);
return
0
;
return
RET_OK
;
}
kernel
::
LiteKernel
*
OpenCLArithmeticKernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
...
...
@@ -170,7 +177,7 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor
return
nullptr
;
}
auto
ret
=
kernel
->
Init
();
if
(
0
!=
ret
)
{
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: Arithmetic"
;
delete
kernel
;
return
nullptr
;
...
...
mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
浏览文件 @
8c993552
...
...
@@ -68,18 +68,37 @@ void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b)
auto
tensorType
=
schema
::
NodeType_ValueNode
;
lite
::
tensor
::
Tensor
*
tensor_a
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_a
,
schema
::
Format_NHWC4
,
tensorType
);
new
(
std
::
nothrow
)
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_a
,
schema
::
Format_NHWC4
,
tensorType
);
lite
::
tensor
::
Tensor
*
tensor_b
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_b
,
schema
::
Format_NHWC4
,
tensorType
);
new
(
std
::
nothrow
)
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_b
,
schema
::
Format_NHWC4
,
tensorType
);
lite
::
tensor
::
Tensor
*
tensor_c
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_a
,
schema
::
Format_NHWC4
,
tensorType
);
new
(
std
::
nothrow
)
lite
::
tensor
::
Tensor
(
kNumberTypeFloat32
,
shape_a
,
schema
::
Format_NHWC4
,
tensorType
);
if
(
tensor_a
==
nullptr
||
tensor_b
==
nullptr
||
tensor_c
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create tensor failed!"
;
delete
tensor_a
;
delete
tensor_b
;
delete
tensor_c
;
return
;
}
int64_t
element_num
=
tensor_a
->
ElementsC4Num
();
int64_t
element_num_b
=
is_bias_add
?
1
:
tensor_b
->
ElementsC4Num
();
float
*
data_a
=
new
float
[
element_num
];
float
*
data_b
=
new
float
[
element_num_b
];
float
*
data_c_cpu
=
new
float
[
element_num
];
float
*
data_c_ocl
=
new
float
[
element_num
];
float
*
data_a
=
new
(
std
::
nothrow
)
float
[
element_num
];
float
*
data_b
=
new
(
std
::
nothrow
)
float
[
element_num_b
];
float
*
data_c_cpu
=
new
(
std
::
nothrow
)
float
[
element_num
];
float
*
data_c_ocl
=
new
(
std
::
nothrow
)
float
[
element_num
];
if
(
data_a
==
nullptr
||
data_b
==
nullptr
||
data_c_cpu
==
nullptr
||
data_c_ocl
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create buffer failed!"
;
delete
tensor_a
;
delete
tensor_b
;
delete
tensor_c
;
delete
[]
data_a
;
delete
[]
data_b
;
delete
[]
data_c_cpu
;
delete
[]
data_c_ocl
;
return
;
}
InitData
(
data_a
,
element_num
);
InitData
(
data_b
,
element_num_b
);
...
...
@@ -100,7 +119,18 @@ void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b)
}
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs
=
{
tensor_c
};
ArithmeticParameter
*
param
=
new
ArithmeticParameter
();
ArithmeticParameter
*
param
=
new
(
std
::
nothrow
)
ArithmeticParameter
();
if
(
param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create parameter failed!"
;
delete
tensor_a
;
delete
tensor_b
;
delete
tensor_c
;
delete
[]
data_a
;
delete
[]
data_b
;
delete
[]
data_c_cpu
;
delete
[]
data_c_ocl
;
return
;
}
param
->
ndim_
=
4
;
param
->
op_parameter_
.
type_
=
PrimitiveType_Add
;
...
...
@@ -108,12 +138,36 @@ void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b)
lite
::
Context
ctx
;
auto
*
arith_kernel
=
new
kernel
::
ArithmeticOpenCLKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
arithmetic_inputs
,
outputs
,
&
ctx
);
if
(
arith_kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create ArithmeticOpenCLKernel failed!"
;
delete
tensor_a
;
delete
tensor_b
;
delete
tensor_c
;
delete
[]
data_a
;
delete
[]
data_b
;
delete
[]
data_c_cpu
;
delete
[]
data_c_ocl
;
delete
param
;
return
;
}
arith_kernel
->
Init
();
tensor_a
->
MallocData
(
allocator
);
tensor_b
->
MallocData
(
allocator
);
std
::
vector
<
kernel
::
LiteKernel
*>
kernels
{
arith_kernel
};
auto
*
kernel
=
new
kernel
::
SubGraphOpenCLKernel
(
inputs
,
outputs
,
kernels
,
kernels
,
kernels
);
auto
*
kernel
=
new
(
std
::
nothrow
)
kernel
::
SubGraphOpenCLKernel
(
inputs
,
outputs
,
kernels
,
kernels
,
kernels
);
if
(
arith_kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create SubGraphOpenCLKernel failed!"
;
delete
tensor_a
;
delete
tensor_b
;
delete
tensor_c
;
delete
[]
data_a
;
delete
[]
data_b
;
delete
[]
data_c_cpu
;
delete
[]
data_c_ocl
;
delete
arith_kernel
;
return
;
}
kernel
->
Init
();
memcpy
(
inputs
[
0
]
->
Data
(),
data_a
,
sizeof
(
float
)
*
element_num
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录