magicwindyyd / mindspore (forked from MindSpore / mindspore)
6ebe132c
编写于
9月 07, 2020
作者:
W
wilfChen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast refactor
上级
4499d126
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
438 addition
and
247 deletion
+438
-247
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu    +268 -155
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh   +12  -8
mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h         +2   -2
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc   +64  -64
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h    +50  -18
tests/st/ops/gpu/test_broadcast_op.py                                      +42  -0

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu

@@ -15,110 +15,216 @@
  */
 #include <vector>
 #include <iostream>
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
-template <typename T, typename S>
+// Basic function
+template <typename T>
 struct GreaterFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
+  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
 };

-template <typename T, typename S>
+template <typename T>
 struct LessFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
+  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
 };

-template <typename T, typename S>
+template <typename T>
 struct MinimumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
 };

-template <typename T, typename S>
+template <typename T>
 struct MaximumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
 };

-template <typename T, typename S>
+template <typename T>
 struct PowerFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
 };

 template <>
-struct PowerFunc<half, half> {
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+struct PowerFunc<half> {
+  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
     return __float2half(pow(__half2float(lhs), __half2float(rhs)));
   }
 };

-template <typename T, typename S>
+template <>
+struct PowerFunc<half2> {
+  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    float2 base = __half22float2(lhs);
+    float2 index = __half22float2(rhs);
+    base.x = pow(base.x, index.x);
+    base.y = pow(base.y, index.y);
+    return __float22half2_rn(base);
+  }
+};
+
+template <typename T>
 struct RealDivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
 };

-template <typename T, typename S>
+template <typename T>
 struct DivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
 };

-template <typename T, typename S>
+template <typename T>
 struct MulFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
 };

-template <typename T, typename S>
+template <typename T>
 struct SubFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
 };

-template <typename T, typename S>
+template <typename T>
 struct AddFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
 };

-template <typename T, typename S>
+// convert to float to fix accuracy issue
+template <typename T>
 struct FloorDivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
-    return floor(static_cast<float>(lhs) / static_cast<float>(rhs));
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
+    return floorf(static_cast<float>(lhs) / static_cast<float>(rhs));
   }
 };

 template <>
-struct FloorDivFunc<half, half> {
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
-    return __float2half(floor(__half2float(lhs) / __half2float(rhs)));
+struct FloorDivFunc<half> {
+  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+    return floorf(__half2float(lhs) / __half2float(rhs));
   }
 };

 template <>
-struct FloorDivFunc<half, bool> {
-  // invalid branch
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
+struct FloorDivFunc<half2> {
+  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    float2 l = __half22float2(lhs);
+    float2 r = __half22float2(rhs);
+    l.x = floorf(l.x / r.x);
+    l.y = floorf(l.y / r.y);
+    return __float22half2_rn(l);
+  }
 };

-template <typename T, typename S>
+template <typename T>
 struct AbsGradFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
+  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
     T zero = 0.0;
     return lhs < zero ? -rhs : rhs;
   }
 };

 template <>
-struct PowerFunc<half, bool> {
-  // invalid branch
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
+struct AbsGradFunc<half2> {
+  __device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    half2 zero(0.0, 0.0);
+    return lhs < zero ? -rhs : rhs;
+  }
 };
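
The half2 specializations added above are what let the arithmetic path below process two fp16 values per thread. A minimal standalone sketch of the same packing idea, with illustrative names that are not part of this commit:

    #include <cuda_fp16.h>

    // Toy grid-stride kernel: add two fp16 buffers two packed elements at a time.
    __global__ void AddHalf2(size_t n2, const half2 *a, const half2 *b, half2 *c) {
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n2; i += blockDim.x * gridDim.x) {
        c[i] = __hadd2(a[i], b[i]);  // one instruction adds both packed halves
      }
    }

    // The caller reinterprets half* as half2*, as ElewiseArith<half> does below;
    // an odd element count needs the scalar fallback the commit keeps around.
    void AddHalfVectorized(size_t n, const half *a, const half *b, half *c, cudaStream_t stream) {
      AddHalf2<<<(n / 2 + 255) / 256, 256, 0, stream>>>(n / 2, reinterpret_cast<const half2 *>(a),
                                                        reinterpret_cast<const half2 *>(b),
                                                        reinterpret_cast<half2 *>(c));
    }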
+// Element-wise Comparation
+template <typename T, typename Func>
+__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x0[pos], x1[pos]);
+  }
+}
+
+template <typename T>
+void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
+  switch (op) {
+    case BROADCAST_TYPE_GREATER:
+      return ElewiseCmpKernel<T, GreaterFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_LESS:
+      return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    default:
+      break;
+  }
+}
+
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, bool *y,
+                         cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, bool *y,
+                         cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
+                         cudaStream_t stream);
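
ElewiseCmp dispatches only the two comparison opcodes and falls through to a silent no-op for anything else, so callers are expected to have classified the op beforehand (the kernel wrapper later in this commit does that in GetOpType). A caller-side sketch, with illustrative buffer names:

    // Write a boolean mask d_mask[i] = d_a[i] > d_b[i] over n same-shaped elements.
    void GreaterMask(const float *d_a, const float *d_b, bool *d_mask, int n, cudaStream_t stream) {
      ElewiseCmp(n, BROADCAST_TYPE_GREATER, d_a, d_b, d_mask, stream);
    }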
+// Element-wise ArithMetic
+template <typename T, typename Func>
+__global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x0[pos], x1[pos]);
+  }
+}
+
+template <typename T>
+void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
+                        cudaStream_t stream) {
+  switch (op) {
+    case BROADCAST_TYPE_MINIMUM:
+      return ElewiseArithKernel<T, MinimumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MAXIMUM:
+      return ElewiseArithKernel<T, MaximumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_POWER:
+      return ElewiseArithKernel<T, PowerFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_REALDIV:
+      return ElewiseArithKernel<T, RealDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MUL:
+      return ElewiseArithKernel<T, MulFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_SUB:
+      return ElewiseArithKernel<T, SubFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_ADD:
+      return ElewiseArithKernel<T, AddFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_FLOORDIV:
+      return ElewiseArithKernel<T, FloorDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_ABSGRAD:
+      return ElewiseArithKernel<T, AbsGradFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_DIV:
+      return ElewiseArithKernel<T, DivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    default:
+      break;
+  }
+}
+
+template <typename T>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
+  return ElewiseArithKernel(nums, op, x0, x1, y, stream);
+}
+
+template <>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
+                  cudaStream_t stream) {
+  if (nums % 2 == 0) {
+    ElewiseArithKernel<half2>(nums / 2, op, reinterpret_cast<const half2 *>(x0), reinterpret_cast<const half2 *>(x1),
+                              reinterpret_cast<half2 *>(y), stream);
+  } else {
+    return ElewiseArithKernel(nums, op, x0, x1, y, stream);
+  }
+}
+
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
+                           cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
+                           cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
+                           cudaStream_t stream);
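
For same-shaped operands the flat entry point above is all a caller needs; only the element count, opcode, and stream vary. A hedged caller-side sketch (device pointers assumed already allocated and filled):

    #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

    // Multiply two same-shaped float buffers of n elements on the given stream.
    void MulSameShape(const float *d_x0, const float *d_x1, float *d_y, int n, cudaStream_t stream) {
      ElewiseArith(n, BROADCAST_TYPE_MUL, d_x0, d_x1, d_y, stream);  // no broadcasting needed
    }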
 // Broadcast comparation
 __device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }

-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
-                                                  const int &l4, const int &l5, const int &l6, const int &r0,
-                                                  const int &r1, const int &r2, const int &r3, const int &r4,
-                                                  const int &r5, const int &r6, const int &d0, const int &d1,
-                                                  const int &d2, const int &d3, const int &d4, const int &d5,
-                                                  const int &d6, const T *input0, const T *input1, S *output) {
+template <typename T, typename Func>
+__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
+                                   const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
+                                   const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
+                                   const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
        pos += blockDim.x * gridDim.x) {
     int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
...
@@ -143,115 +249,152 @@ __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1,
     r_index += Index(m, r4) * r5 * r6;
     r_index += Index(n, r5) * r6;
     r_index += Index(o, r6);
-    output[pos] = Func()(input0[l_index], input1[r_index]);
+    y[pos] = Func()(x0[l_index], x1[r_index]);
   }
 }
-template <typename T, typename S>
-__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
-                                const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
-                                const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
-                                const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
-                                const T *input1, S *output) {
-  switch (op) {
-    case BROADCAST_TYPE_GREATER:
-      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                        d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_LESS:
-      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                     d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MINIMUM:
-      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                        d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MAXIMUM:
-      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                        d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_POWER:
-      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                      d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_REALDIV:
-      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                        d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MUL:
-      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                    d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_SUB:
-      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                    d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_ADD:
-      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                    d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_FLOORDIV:
-      return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                         d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_ABSGRAD:
-      return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                        d1, d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_DIV:
-      return BroadcastOperator<T, S, DivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0,
-                                                    d1, d2, d3, d4, d5, d6, input0, input1, output);
-    default:
-      break;
-  }
-}
+template <typename T>
+void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
+  int size = 1;
+  for (auto d : y_dims) {
+    size *= d;
+  }
+  switch (op) {
+    case BROADCAST_TYPE_GREATER:
+      return BroadcastCmpKernel<T, GreaterFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_LESS:
+      return BroadcastCmpKernel<T, LessFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    default:
+      break;
+  }
+}
-template <typename T, typename S>
-void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
-               S *output, cudaStream_t stream) {
-  int size = 1;
-  for (auto d : output_shape) {
-    size *= d;
-  }
-  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
-    lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3], lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
-    rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4], rhs_shape[5], rhs_shape[6], output_shape[0],
-    output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5], output_shape[6], op, input0,
-    input1, output);
-}
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
+                           bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+                           bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+                           bool *y, cudaStream_t stream);
-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
-    output[pos] = Func()(input0[pos], input1[pos]);
-  }
-}
+// Broadcast Arithmetic
+template <typename T, typename Func>
+__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4,
+                                     const int l5, const int l6, const int r0, const int r1, const int r2,
+                                     const int r3, const int r4, const int r5, const int r6, const int d0,
+                                     const int d1, const int d2, const int d3, const int d4, const int d5,
+                                     const int d6, const T *x0, const T *x1, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
+       pos += blockDim.x * gridDim.x) {
+    int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    int k = pos / (d3 * d4 * d5 * d6) % d2;
+    int l = pos / (d4 * d5 * d6) % d3;
+    int m = pos / (d5 * d6) % d4;
+    int n = pos / d6 % d5;
+    int o = pos % d6;
+
+    int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(k, l2) * l3 * l4 * l5 * l6;
+    l_index += Index(l, l3) * l4 * l5 * l6;
+    l_index += Index(m, l4) * l5 * l6;
+    l_index += Index(n, l5) * l6;
+    l_index += Index(o, l6);
+    int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(k, r2) * r3 * r4 * r5 * r6;
+    r_index += Index(l, r3) * r4 * r5 * r6;
+    r_index += Index(m, r4) * r5 * r6;
+    r_index += Index(n, r5) * r6;
+    r_index += Index(o, r6);
+    y[pos] = Func()(x0[l_index], x1[r_index]);
+  }
+}
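
The kernel linearizes the 7-D output position into per-axis indices, and Index(idx, dim) collapses any size-1 (broadcast) axis to offset 0 when rebuilding each operand's flat index. A host-side sketch of the same arithmetic reduced to two dimensions, to make the mapping concrete (the helper names here are mine):

    // Output shape {2, 3}; left operand shape {2, 1}, i.e. broadcast along axis 1.
    inline int Index(int index, int dim) { return dim == 1 ? 0 : index; }

    int LeftOffset(int pos) {                  // pos: linear output position in [0, 6)
      int i = pos / 3 % 2;                     // output row
      int j = pos % 3;                         // output column
      return Index(i, 2) * 1 + Index(j, 1);    // column collapses to 0 for the {2, 1} operand
    }
    // Positions 0,1,2 map to offset 0 and 3,4,5 to offset 1: every column of an
    // output row reads the single element stored for that input row.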
-template <typename T, typename S>
-__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
-                                  S *output) {
-  switch (op) {
-    case BROADCAST_TYPE_GREATER:
-      return NoBroadcastOperator<T, S, GreaterFunc<T, bool>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_LESS:
-      return NoBroadcastOperator<T, S, LessFunc<T, bool>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MINIMUM:
-      return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MAXIMUM:
-      return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_POWER:
-      return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_REALDIV:
-      return NoBroadcastOperator<T, S, RealDivFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MUL:
-      return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_SUB:
-      return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_ADD:
-      return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_FLOORDIV:
-      return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_ABSGRAD:
-      return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_DIV:
-      return NoBroadcastOperator<T, S, DivFunc<T, S>>(nums, input0, input1, output);
-    default:
-      break;
-  }
-}
+template <typename T>
+void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
+  int size = 1;
+  for (auto d : y_dims) {
+    size *= d;
+  }
+  switch (op) {
+    case BROADCAST_TYPE_MAXIMUM:
+      return BroadcastArithKernel<T, MaximumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_MINIMUM:
+      return BroadcastArithKernel<T, MinimumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_POWER:
+      return BroadcastArithKernel<T, PowerFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_REALDIV:
+      return BroadcastArithKernel<T, RealDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_MUL:
+      return BroadcastArithKernel<T, MulFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_SUB:
+      return BroadcastArithKernel<T, SubFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_ADD:
+      return BroadcastArithKernel<T, AddFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_FLOORDIV:
+      return BroadcastArithKernel<T, FloorDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_ABSGRAD:
+      return BroadcastArithKernel<T, AbsGradFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_DIV:
+      return BroadcastArithKernel<T, DivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    default:
+      break;
+  }
+}
-template <typename T, typename S>
-void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream) {
-  NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
-}
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0,
+                             const float *x1, float *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+                             half *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+                             int *y, cudaStream_t stream);
 // BroadcastTo
 template <typename T>
 __global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
                                   const int o2, const int o3, const T *input_addr, T *output_addr) {
...

@@ -274,36 +417,6 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
                                                                 output_addr);
 }
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
-                        const float *input1, bool *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
-                        const float *input1, float *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
-                        const half *input1, bool *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
-                        const half *input1, half *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
-                        const int *input1, int *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
-                        const int *input1, bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          float *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          half *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
-                          int *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
-                          bool *output, cudaStream_t stream);
 template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                           const int &o2, const int &o3, const float *input_addr, float *output_addr,
                           cudaStream_t stream);
...

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh

@@ -36,17 +36,21 @@ enum BroadcastOpType {
   BROADCAST_TYPE_INVALID = 0xffffffff,
 };

-template <typename T, typename S>
-void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
-               S *output, cudaStream_t stream);
+template <typename T>
+void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
+
+template <typename T>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);

-template <typename T, typename S>
-void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream);
+template <typename T>
+void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
+
+template <typename T>
+void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);

 template <typename T>
 void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                  const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);

 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
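
The broadcast entry points take exactly seven dimensions per operand (MAX_DIMS in the kernel header), so shorter shapes get padded. A hedged end-to-end sketch; the leading-1 padding here is an assumption consistent with NumPy-style alignment, and the real padding logic lives in BroadcastOpGpuKernel:

    #include <vector>
    #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

    // Add a {3, 1, 5, 1} tensor to a {1, 4, 1, 6} tensor, yielding {3, 4, 5, 6}.
    void AddWithBroadcast(const float *d_x0, const float *d_x1, float *d_y, cudaStream_t stream) {
      std::vector<int> x0_dims = {1, 1, 1, 3, 1, 5, 1};  // padded to 7 dims (assumed convention)
      std::vector<int> x1_dims = {1, 1, 1, 1, 4, 1, 6};
      std::vector<int> y_dims = {1, 1, 1, 3, 4, 5, 6};
      BroadcastArith(x0_dims, x1_dims, y_dims, BROADCAST_TYPE_ADD, d_x0, d_x1, d_y, stream);
    }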

mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h

@@ -58,7 +58,7 @@ class AddNGpuFwdKernel : public GpuKernel {
     for (size_t i = 0; i < IntToSize(num_input_); i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
       if (cudnn_data_type_ == CUDNN_DATA_INT32) {
-        NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
+        ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
       } else {
         CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
...
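
The only change here is the call site name. The int32 branch of AddN bypasses cudnnAddTensor and instead folds each addend into the output through the broadcast module's element-wise add, passing output_addr as both an operand and the destination. Schematically (a sketch of the loop's effect, not the kernel's actual code):

    #include <vector>
    #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

    // Repeated in-place accumulation: y := x_i + y, one launch per addend.
    // Assumes `output` holds a valid starting value beforehand; how the real
    // kernel initializes it is outside this hunk.
    void AccumulateInt32(const std::vector<int *> &inputs, int *output, int n, cudaStream_t stream) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        ElewiseArith(n, BROADCAST_TYPE_ADD, inputs[i], output, output, stream);
      }
    }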

mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc

@@ -19,119 +19,119 @@
 namespace mindspore {
 namespace kernel {
 // fp32
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Greater, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, float, bool)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, float, bool)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Maximum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Minimum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   RealDiv, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
 // fp16
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Greater, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, half, bool)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, half, bool)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Maximum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Minimum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   RealDiv, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
 // int32
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, int, bool)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h

@@ -28,11 +28,16 @@
 namespace mindspore {
 namespace kernel {
 constexpr int MAX_DIMS = 7;
-template <typename T, typename S>
+template <typename T>
 class BroadcastOpGpuKernel : public GpuKernel {
  public:
   BroadcastOpGpuKernel()
-      : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
+      : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), is_comp_op_(false), input1_num_(1), input2_num_(1),
+        output_num_(1) {}
   ~BroadcastOpGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
...

@@ -43,13 +48,23 @@ class BroadcastOpGpuKernel : public GpuKernel {
             const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *lhs = GetDeviceAddress<T>(inputs, 0);
     T *rhs = GetDeviceAddress<T>(inputs, 1);
-    S *output = GetDeviceAddress<S>(outputs, 0);
-    if (need_broadcast_) {
-      Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
-                reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    if (is_comp_op_) {
+      bool *output = GetDeviceAddress<bool>(outputs, 0);
+      if (need_broadcast_) {
+        BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        ElewiseCmp(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
+    } else {
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      if (need_broadcast_) {
+        BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
     }
     return true;
...

@@ -91,26 +106,42 @@ class BroadcastOpGpuKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input1_num_ * sizeof(T));
     input_size_list_.push_back(input2_num_ * sizeof(T));
-    output_size_list_.push_back(output_num_ * sizeof(S));
+
+    auto unit_size = is_comp_op_ ? sizeof(bool) : sizeof(T);
+    output_size_list_.push_back(output_num_ * unit_size);
   }

  private:
   void GetOpType(const CNodePtr &kernel_node) {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-    static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-      {"Greater", BROADCAST_TYPE_GREATER},   {"Less", BROADCAST_TYPE_LESS},
-      {"Maximum", BROADCAST_TYPE_MAXIMUM},   {"Minimum", BROADCAST_TYPE_MINIMUM},
-      {"Pow", BROADCAST_TYPE_POWER},         {"RealDiv", BROADCAST_TYPE_REALDIV},
-      {"Mul", BROADCAST_TYPE_MUL},           {"Sub", BROADCAST_TYPE_SUB},
-      {"TensorAdd", BROADCAST_TYPE_ADD},     {"FloorDiv", BROADCAST_TYPE_FLOORDIV},
-      {"AbsGrad", BROADCAST_TYPE_ABSGRAD},   {"Div", BROADCAST_TYPE_DIV},
+
+    static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
+      {"Greater", BROADCAST_TYPE_GREATER},
+      {"Less", BROADCAST_TYPE_LESS},
     };
-    auto iter = kBroadcastTypeMap.find(kernel_name);
-    if (iter == kBroadcastTypeMap.end()) {
-      MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
-    } else {
-      op_type_ = iter->second;
+
+    auto iter = kBroadcastCmpTypeMap.find(kernel_name);
+    if (iter != kBroadcastCmpTypeMap.end()) {
+      op_type_ = iter->second;
+      is_comp_op_ = true;
+      return;
     }
+
+    static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
+      {"Maximum", BROADCAST_TYPE_MAXIMUM},   {"Minimum", BROADCAST_TYPE_MINIMUM},
+      {"Pow", BROADCAST_TYPE_POWER},         {"RealDiv", BROADCAST_TYPE_REALDIV},
+      {"Mul", BROADCAST_TYPE_MUL},           {"Sub", BROADCAST_TYPE_SUB},
+      {"TensorAdd", BROADCAST_TYPE_ADD},     {"FloorDiv", BROADCAST_TYPE_FLOORDIV},
+      {"AbsGrad", BROADCAST_TYPE_ABSGRAD},   {"Div", BROADCAST_TYPE_DIV},
+    };
+
+    iter = kBroadcastArithmetricTypeMap.find(kernel_name);
+    if (iter != kBroadcastArithmetricTypeMap.end()) {
+      op_type_ = iter->second;
+      is_comp_op_ = false;
+      return;
+    }
+
+    MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
   }

   bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
...

@@ -127,6 +158,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
   BroadcastOpType op_type_;
   bool need_broadcast_;
+  bool is_comp_op_;
   int input1_num_;
   int input2_num_;
   int output_num_;
...

@@ -137,7 +169,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 };
-};  // namespace kernel
+}  // namespace kernel
 }  // namespace mindspore

tests/st/ops/gpu/test_broadcast_op.py

@@ -160,3 +160,45 @@ def test_broadcast_diff_dims():
     output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
     output_np = x1_np - x2_np
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
+    x2_np = np.random.rand(1, 4, 1, 6).astype(np.float16)
+
+    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.minimum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.maximum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np > x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np < x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.power(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np / x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np * x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np - x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)