Commit 0a1195dd in magicwindyyd/mindspore (fork of MindSpore/mindspore)

broadcast kernel support unequal dims & half

Authored May 08, 2020 by wilfChen
Parent: bab6e0f5

Showing 4 changed files with 89 additions and 5 deletions (+89 -5)
Files changed:
  mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu     +25 -0
  mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc    +21 -0
  mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h     +13 -5
  tests/st/ops/gpu/test_broadcast_op.py                      +30 -0
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu

@@ -42,6 +42,19 @@ struct PowerFunc {
   __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
 };
+template <>
+struct PowerFunc<half, half> {
+  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+    return __float2half(pow(__half2float(lhs), __half2float(rhs)));
+  }
+};
+template <>
+struct PowerFunc<half, bool> {
+  // invalid branch
+  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
+};
+__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
 template <typename T, typename S, typename Func>
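The Index helper added above encodes the broadcasting rule at the heart of this commit: an axis of extent 1 maps every index to 0, so that axis's single element is reused across the whole output. A minimal host-side C++ sketch of the same rule, illustrative only and not part of the commit, using a (2, 1) operand broadcast against a (2, 3) output:

    #include <cstdio>

    // Same rule as the device-side helper: a size-1 axis maps every index to 0.
    int Index(int index, int dim) { return dim == 1 ? 0 : index; }

    int main() {
      // lhs shaped (2, 1) broadcast against a (2, 3) output: each row's single
      // element is reused for all three columns.
      const int lhs_d0 = 2, lhs_d1 = 1;
      for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 3; ++j) {
          int l = Index(i, lhs_d0) * lhs_d1 + Index(j, lhs_d1);
          printf("output(%d,%d) reads lhs[%d]\n", i, j, l);
        }
      }
      return 0;
    }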
@@ -131,8 +144,20 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int &
                         const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                         enum BroadcastOpType op, const float *input0, const float *input1, float *output,
                         cudaStream_t stream);
+template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                        enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
+                        cudaStream_t stream);
+template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                        enum BroadcastOpType op, const half *input0, const half *input1, half *output,
+                        cudaStream_t stream);
 template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
                           bool *output, cudaStream_t stream);
 template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
                           float *output, cudaStream_t stream);
+template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
+                          bool *output, cudaStream_t stream);
+template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
+                          half *output, cudaStream_t stream);
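These explicit instantiations follow the usual CUDA split-compilation pattern: the Broadcast and NoBroadcast template bodies live only in this .cu file, so every input/output type pair the rest of the build links against (float/float, float/bool, half/half, half/bool) must be instantiated here. A minimal sketch of the pattern, with hypothetical names:

    // impl.cc: the template body is visible only in this translation unit.
    template <typename T, typename S>
    void Apply(const T *in, S *out, int n) {
      for (int i = 0; i < n; ++i) out[i] = static_cast<S>(in[i]);
    }
    // Explicit instantiations: without these lines, code in other translation
    // units calling Apply<float, bool> or Apply<float, float> fails to link.
    template void Apply<float, bool>(const float *, bool *, int);
    template void Apply<float, float>(const float *, float *, int);

    // caller.cc: sees only a declaration; the linker finds the instantiations.
    // template <typename T, typename S> void Apply(const T *in, S *out, int n);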
mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc

@@ -18,6 +18,7 @@
 namespace mindspore {
 namespace kernel {
+// fp32
 MS_REG_GPU_KERNEL_TWO(
   Greater,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),

@@ -36,5 +37,25 @@ MS_REG_GPU_KERNEL_TWO(
 MS_REG_GPU_KERNEL_TWO(
   Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   BroadcastOpGpuKernel, float, float)
+
+// fp16
+MS_REG_GPU_KERNEL_TWO(
+  Greater, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, half, bool)
+MS_REG_GPU_KERNEL_TWO(
+  Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
+  BroadcastOpGpuKernel, half, bool)
+MS_REG_GPU_KERNEL_TWO(
+  Maximum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BroadcastOpGpuKernel, half, half)
+MS_REG_GPU_KERNEL_TWO(
+  Minimum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BroadcastOpGpuKernel, half, half)
+MS_REG_GPU_KERNEL_TWO(
+  Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BroadcastOpGpuKernel, half, half)
 }  // namespace kernel
 }  // namespace mindspore
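As the registrations above show, the macro's two trailing arguments are the BroadcastOpGpuKernel template parameters: the input element type and the output element type. That is why the fp16 comparison ops register as (half, bool) while Maximum, Minimum and Pow register as (half, half). A toy C++ sketch of that type pairing, with assumed semantics and hypothetical names rather than the real macro:

    #include <cstdio>

    // Toy stand-in for a kernel templated on input type T and output type S.
    template <typename T, typename S>
    struct ToyBroadcastKernel {
      S Greater(T lhs, T rhs) const { return static_cast<S>(lhs > rhs); }
      T Maximum(T lhs, T rhs) const { return lhs > rhs ? lhs : rhs; }
    };

    int main() {
      ToyBroadcastKernel<float, bool> k;  // comparison: output type is bool, like (half, bool)
      printf("Greater(2, 1) = %d\n", k.Greater(2.0f, 1.0f));  // prints 1
      printf("Maximum(2, 1) = %g\n", k.Maximum(2.0f, 1.0f));  // prints 2
      return 0;
    }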
mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h

@@ -65,14 +65,19 @@ class BroadcastOpGpuKernel : public GpuKernel {
       MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
     }
-    for (size_t i = 0; i < shape1.size(); i++) {
-      lhs_shape_[i] = shape1[i];
-      rhs_shape_[i] = shape2[i];
-      output_num_ *= shape3[i];
-    }
+    for (size_t i = 0; i < shape3.size(); i++) {
+      output_shape_[i] = shape3[i];
+      output_num_ *= shape3[i];
+    }
+    int offset = shape3.size() - shape1.size();
+    for (size_t i = 0; i < shape1.size(); i++) {
+      lhs_shape_[i + offset] = shape1[i];
+      input1_num_ *= shape1[i];
+    }
+    offset = shape3.size() - shape2.size();
+    for (size_t i = 0; i < shape2.size(); i++) {
+      rhs_shape_[i + offset] = shape2[i];
+      input2_num_ *= shape2[i];
+    }
     InitSizeLists();
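The reworked loops right-align each input shape against the output shape, so a lower-rank input gets implicit leading 1s, numpy-style; this is what lets the kernel accept operands of unequal rank. A standalone C++ sketch of just the alignment step, assuming (as the broadcast semantics require) that the shape array is pre-filled with 1s:

    #include <cstdio>
    #include <vector>

    int main() {
      // Mirrors the new test case below: lhs (2,) against rhs (2, 1), output (2, 2).
      std::vector<size_t> out_shape = {2, 2};
      std::vector<size_t> in_shape = {2};
      std::vector<size_t> aligned(out_shape.size(), 1);  // leading dims default to 1
      size_t offset = out_shape.size() - in_shape.size();
      for (size_t i = 0; i < in_shape.size(); i++) {
        aligned[i + offset] = in_shape[i];
      }
      // Prints "1 2": (2,) is treated as (1, 2) and broadcast along axis 0.
      for (size_t d : aligned) printf("%zu ", d);
      printf("\n");
      return 0;
    }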
@@ -105,6 +110,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
   }

   bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
+    if (lhs.size() != rhs.size()) {
+      return true;
+    }
     for (size_t i = 0; i < lhs.size(); i++) {
       if (lhs[i] != rhs[i]) {
         return true;
tests/st/ops/gpu/test_broadcast_op.py

@@ -79,3 +79,33 @@ def test_broadcast():
     output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
     output_np = np.power(x1_np, x2_np)
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_diff_dims():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    x1_np = np.random.rand(2).astype(np.float32)
+    x2_np = np.random.rand(2, 1).astype(np.float32)
+
+    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.minimum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.maximum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np > x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np < x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.power(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)