Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_44025039
mindspore
提交
1dcc34e7
M
mindspore
项目概览
weixin_44025039
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1dcc34e7
编写于
8月 10, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add GPU div kernel
上级
f21ee64f
变更
5
隐藏空白更改
内联
并排
Showing 5 changed files with 150 additions and 21 deletions (+150 / -21)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
...c/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+23
-17
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
.../backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
+1
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
.../backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+9
-3
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
...c/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+1
-1
tests/st/ops/gpu/test_div_op.py
tests/st/ops/gpu/test_div_op.py
+116
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
浏览文件 @
1dcc34e7
...
...
@@ -55,6 +55,11 @@ struct RealDivFunc {
__device__
__forceinline__
S
operator
()(
const
T
&
lhs
,
const
T
&
rhs
)
{
return
(
lhs
/
rhs
);
}
};
template
<
typename
T
,
typename
S
>
struct
DivFunc
{
__device__
__forceinline__
S
operator
()(
const
T
&
lhs
,
const
T
&
rhs
)
{
return
(
lhs
/
rhs
);
}
};
template
<
typename
T
,
typename
S
>
struct
MulFunc
{
__device__
__forceinline__
S
operator
()(
const
T
&
lhs
,
const
T
&
rhs
)
{
return
(
lhs
*
rhs
);
}
...
...
@@ -78,7 +83,7 @@ struct FloorDivFunc {
template
<
>
struct
FloorDivFunc
<
half
,
half
>
{
__device__
__forceinline__
half
operator
()(
const
half
&
lhs
,
const
half
&
rhs
)
{
return
__float2half
(
floor
(
__half2float
(
lhs
)
/
__half2float
(
rhs
)));
return
__float2half
(
floor
(
__half2float
(
lhs
)
/
__half2float
(
rhs
)));
}
};
...
...
@@ -96,7 +101,6 @@ struct AbsGradFunc {
}
};
template
<
>
struct
PowerFunc
<
half
,
bool
>
{
// invalid branch
...
...
@@ -105,7 +109,6 @@ struct PowerFunc<half, bool> {
__device__
__forceinline__
int
Index
(
const
int
&
index
,
const
int
&
dim
)
{
return
dim
==
1
?
0
:
index
;
}
template
<
typename
T
,
typename
S
,
typename
Func
>
__device__
__forceinline__
void
BroadcastOperator
(
const
int
&
l0
,
const
int
&
l1
,
const
int
&
l2
,
const
int
&
l3
,
const
int
&
l4
,
const
int
&
l5
,
const
int
&
l6
,
const
int
&
r0
,
...
...
@@ -181,6 +184,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
case
BROADCAST_TYPE_ABSGRAD
:
return
BroadcastOperator
<
T
,
S
,
AbsGradFunc
<
T
,
S
>>
(
l0
,
l1
,
l2
,
l3
,
l4
,
l5
,
l6
,
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
input0
,
input1
,
output
);
case
BROADCAST_TYPE_DIV
:
return
BroadcastOperator
<
T
,
S
,
DivFunc
<
T
,
S
>>
(
l0
,
l1
,
l2
,
l3
,
l4
,
l5
,
l6
,
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
input0
,
input1
,
output
);
}
}
...
...
@@ -192,13 +198,11 @@ void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_sh
for
(
auto
d
:
output_shape
)
{
size
*=
d
;
}
BroadcastKernel
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
stream
>>>
(
lhs_shape
[
0
],
lhs_shape
[
1
],
lhs_shape
[
2
],
lhs_shape
[
3
],
lhs_shape
[
4
],
lhs_shape
[
5
],
lhs_shape
[
6
],
rhs_shape
[
0
],
rhs_shape
[
1
],
rhs_shape
[
2
],
rhs_shape
[
3
],
rhs_shape
[
4
],
rhs_shape
[
5
],
rhs_shape
[
6
],
output_shape
[
0
],
output_shape
[
1
],
output_shape
[
2
],
output_shape
[
3
],
output_shape
[
4
],
output_shape
[
5
],
output_shape
[
6
],
op
,
input0
,
input1
,
output
);
BroadcastKernel
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
stream
>>>
(
lhs_shape
[
0
],
lhs_shape
[
1
],
lhs_shape
[
2
],
lhs_shape
[
3
],
lhs_shape
[
4
],
lhs_shape
[
5
],
lhs_shape
[
6
],
rhs_shape
[
0
],
rhs_shape
[
1
],
rhs_shape
[
2
],
rhs_shape
[
3
],
rhs_shape
[
4
],
rhs_shape
[
5
],
rhs_shape
[
6
],
output_shape
[
0
],
output_shape
[
1
],
output_shape
[
2
],
output_shape
[
3
],
output_shape
[
4
],
output_shape
[
5
],
output_shape
[
6
],
op
,
input0
,
input1
,
output
);
}
template
<
typename
T
,
typename
S
,
typename
Func
>
...
...
@@ -234,6 +238,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
return
NoBroadcastOperator
<
T
,
S
,
FloorDivFunc
<
T
,
S
>>
(
nums
,
input0
,
input1
,
output
);
case
BROADCAST_TYPE_ABSGRAD
:
return
NoBroadcastOperator
<
T
,
S
,
AbsGradFunc
<
T
,
S
>>
(
nums
,
input0
,
input1
,
output
);
case
BROADCAST_TYPE_DIV
:
return
NoBroadcastOperator
<
T
,
S
,
DivFunc
<
T
,
S
>>
(
nums
,
input0
,
input1
,
output
);
}
}
...
...
@@ -244,8 +250,8 @@ void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, cons
}
template
<
typename
T
>
__global__
void
BroadcastToKernel
(
const
int
i0
,
const
int
i1
,
const
int
i2
,
const
int
i3
,
const
int
o0
,
const
int
o
1
,
const
int
o
2
,
const
int
o3
,
const
T
*
input_addr
,
T
*
output_addr
)
{
__global__
void
BroadcastToKernel
(
const
int
i0
,
const
int
i1
,
const
int
i2
,
const
int
i3
,
const
int
o0
,
const
int
o1
,
const
int
o2
,
const
int
o3
,
const
T
*
input_addr
,
T
*
output_addr
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
o0
*
o1
*
o2
*
o3
;
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
i
=
pos
/
(
o1
*
o2
*
o3
)
%
o0
;
int
j
=
pos
/
(
o2
*
o3
)
%
o1
;
...
...
@@ -262,7 +268,7 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
const
int
&
o2
,
const
int
&
o3
,
const
T
*
input_addr
,
T
*
output_addr
,
cudaStream_t
stream
)
{
int
nums
=
o0
*
o1
*
o2
*
o3
;
BroadcastToKernel
<<<
GET_BLOCKS
(
nums
),
GET_THREADS
,
0
,
stream
>>>
(
i0
,
i1
,
i2
,
i3
,
o0
,
o1
,
o2
,
o3
,
input_addr
,
output_addr
);
output_addr
);
}
template
void
Broadcast
(
const
std
::
vector
<
int
>
&
lhs_shape
,
const
std
::
vector
<
int
>
&
rhs_shape
,
...
...
@@ -291,10 +297,10 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *
bool
*
output
,
cudaStream_t
stream
);
template
void
NoBroadcast
(
const
int
&
nums
,
enum
BroadcastOpType
op
,
const
half
*
input0
,
const
half
*
input1
,
half
*
output
,
cudaStream_t
stream
);
template
void
NoBroadcast
(
const
int
&
nums
,
enum
BroadcastOpType
op
,
const
int
*
input0
,
const
int
*
input1
,
int
*
output
,
cudaStream_t
stream
);
template
void
NoBroadcast
(
const
int
&
nums
,
enum
BroadcastOpType
op
,
const
int
*
input0
,
const
int
*
input1
,
bool
*
output
,
cudaStream_t
stream
);
template
void
NoBroadcast
(
const
int
&
nums
,
enum
BroadcastOpType
op
,
const
int
*
input0
,
const
int
*
input1
,
int
*
output
,
cudaStream_t
stream
);
template
void
NoBroadcast
(
const
int
&
nums
,
enum
BroadcastOpType
op
,
const
int
*
input0
,
const
int
*
input1
,
bool
*
output
,
cudaStream_t
stream
);
template
void
BroadcastTo
(
const
int
&
i0
,
const
int
&
i1
,
const
int
&
i2
,
const
int
&
i3
,
const
int
&
o0
,
const
int
&
o1
,
const
int
&
o2
,
const
int
&
o3
,
const
float
*
input_addr
,
float
*
output_addr
,
cudaStream_t
stream
);
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
浏览文件 @
1dcc34e7
...
...
@@ -32,6 +32,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_ADD
=
8
,
BROADCAST_TYPE_FLOORDIV
=
9
,
BROADCAST_TYPE_ABSGRAD
=
10
,
BROADCAST_TYPE_DIV
=
11
,
BROADCAST_TYPE_INVALID
=
0xffffffff
,
};
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
浏览文件 @
1dcc34e7
...
...
@@ -59,6 +59,9 @@ MS_REG_GPU_KERNEL_TWO(
AbsGrad
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
BroadcastOpGpuKernel
,
float
,
float
)
MS_REG_GPU_KERNEL_TWO
(
Div
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
BroadcastOpGpuKernel
,
float
,
float
)
// fp16
MS_REG_GPU_KERNEL_TWO
(
...
...
@@ -101,6 +104,9 @@ MS_REG_GPU_KERNEL_TWO(
AbsGrad
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
BroadcastOpGpuKernel
,
half
,
half
)
MS_REG_GPU_KERNEL_TWO
(
Div
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
BroadcastOpGpuKernel
,
half
,
half
)
// int32
MS_REG_GPU_KERNEL_TWO
(
...
...
@@ -118,14 +124,14 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO
(
Mul
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
BroadcastOpGpuKernel
,
int
,
int
)
MS_REG_GPU_KERNEL_TWO
(
RealDiv
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
BroadcastOpGpuKernel
,
int
,
int
)
MS_REG_GPU_KERNEL_TWO
(
FloorDiv
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
BroadcastOpGpuKernel
,
int
,
int
)
MS_REG_GPU_KERNEL_TWO
(
AbsGrad
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
BroadcastOpGpuKernel
,
int
,
int
)
MS_REG_GPU_KERNEL_TWO
(
Div
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
BroadcastOpGpuKernel
,
int
,
int
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
浏览文件 @
1dcc34e7
...
...
@@ -102,7 +102,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
{
"Greater"
,
BROADCAST_TYPE_GREATER
},
{
"Less"
,
BROADCAST_TYPE_LESS
},
{
"Maximum"
,
BROADCAST_TYPE_MAXIMUM
},
{
"Minimum"
,
BROADCAST_TYPE_MINIMUM
},
{
"Pow"
,
BROADCAST_TYPE_POWER
},
{
"RealDiv"
,
BROADCAST_TYPE_REALDIV
},
{
"Mul"
,
BROADCAST_TYPE_MUL
},
{
"Sub"
,
BROADCAST_TYPE_SUB
},
{
"TensorAdd"
,
BROADCAST_TYPE_ADD
},
{
"FloorDiv"
,
BROADCAST_TYPE_FLOORDIV
},
{
"AbsGrad"
,
BROADCAST_TYPE_ABSGRAD
},
{
"FloorDiv"
,
BROADCAST_TYPE_FLOORDIV
},
{
"AbsGrad"
,
BROADCAST_TYPE_ABSGRAD
},
{
"Div"
,
BROADCAST_TYPE_DIV
},
};
auto
iter
=
kBroadcastTypeMap
.
find
(
kernel_name
);
...
...
tests/st/ops/gpu/test_div_op.py
0 → 100644
浏览文件 @
1dcc34e7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
class
NetDiv
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetDiv
,
self
).
__init__
()
self
.
div
=
P
.
Div
()
def
construct
(
self
,
x
,
y
):
return
self
.
div
(
x
,
y
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_div
():
x0_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float32
)
y0_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float32
)
x1_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float32
)
y1_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
1
,
4
,
4
)).
astype
(
np
.
float32
)
x2_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
1
,
1
,
4
)).
astype
(
np
.
float32
)
y2_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float32
)
x3_np
=
np
.
random
.
randint
(
1
,
5
,
1
).
astype
(
np
.
float32
)
y3_np
=
np
.
random
.
randint
(
1
,
5
,
1
).
astype
(
np
.
float32
)
x4_np
=
np
.
array
(
768
).
astype
(
np
.
float32
)
y4_np
=
np
.
array
(
3072.5
).
astype
(
np
.
float32
)
x5_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float16
)
y5_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float16
)
x6_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
int32
)
y6_np
=
np
.
random
.
randint
(
1
,
5
,
(
2
,
1
,
4
,
4
)).
astype
(
np
.
int32
)
x0
=
Tensor
(
x0_np
)
y0
=
Tensor
(
y0_np
)
x1
=
Tensor
(
x1_np
)
y1
=
Tensor
(
y1_np
)
x2
=
Tensor
(
x2_np
)
y2
=
Tensor
(
y2_np
)
x3
=
Tensor
(
x3_np
)
y3
=
Tensor
(
y3_np
)
x4
=
Tensor
(
x4_np
)
y4
=
Tensor
(
y4_np
)
x5
=
Tensor
(
x5_np
)
y5
=
Tensor
(
y5_np
)
x6
=
Tensor
(
x6_np
)
y6
=
Tensor
(
y6_np
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
div
=
NetDiv
()
output0
=
div
(
x0
,
y0
)
expect0
=
np
.
divide
(
x0_np
,
y0_np
)
diff0
=
output0
.
asnumpy
()
-
expect0
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
output0
.
shape
==
expect0
.
shape
output1
=
div
(
x1
,
y1
)
expect1
=
np
.
divide
(
x1_np
,
y1_np
)
diff1
=
output1
.
asnumpy
()
-
expect1
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
output1
.
shape
==
expect1
.
shape
output2
=
div
(
x2
,
y2
)
expect2
=
np
.
divide
(
x2_np
,
y2_np
)
diff2
=
output2
.
asnumpy
()
-
expect2
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff2
<
error2
)
assert
output2
.
shape
==
expect2
.
shape
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
'GPU'
)
output3
=
div
(
x3
,
y3
)
expect3
=
np
.
divide
(
x3_np
,
y3_np
)
diff3
=
output3
.
asnumpy
()
-
expect3
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff3
<
error3
)
assert
output3
.
shape
==
expect3
.
shape
output4
=
div
(
x4
,
y4
)
expect4
=
np
.
divide
(
x4_np
,
y4_np
)
diff4
=
output4
.
asnumpy
()
-
expect4
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff4
<
error4
)
assert
output4
.
shape
==
expect4
.
shape
output5
=
div
(
x5
,
y5
)
expect5
=
np
.
divide
(
x5_np
,
y5_np
)
diff5
=
output5
.
asnumpy
()
-
expect5
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff5
<
error5
)
assert
output5
.
shape
==
expect5
.
shape
output6
=
div
(
x6
,
y6
)
expect6
=
np
.
divide
(
x6_np
,
y6_np
)
diff6
=
output6
.
asnumpy
()
-
expect6
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff6
<
error6
)
assert
output6
.
shape
==
expect6
.
shape
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录