8f4cd765
编写于
6月 18, 2020
作者:
W
wilfChen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
gpu Gelu kernel support fp16
上级
971f10d2
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
139 addition
and
16 deletion
+139
-16
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu    +86  -15
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc     +7   -0
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc          +2   -0
tests/st/ops/gpu/test_gelu_grad_op.py                +31   -1
tests/st/ops/gpu/test_gelu_op.py                     +13   -0
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu
@@ -14,32 +14,62 @@
 * limitations under the License.
 */

#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
#include "device/gpu/cuda_common.h"

template <typename T>
__global__ void GeluKernel(size_t size, T* input_addr, T* output_addr) {
  // formula:
  // gelu(x) = 0.5 * x * (1.0 + tanh(y))
  // tanh(y) = 2 / (1 + exp(-2y)) - 1)
  // y = sqrt(2/pi) * (x + 0.044715 * x^3)
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    float x = input_addr[pos];
    float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
    output_addr[pos] = 0.5 * x * (1.0 + tanh_res);
  }
}

template <>
__global__ void GeluKernel(size_t size, half* input_addr, half* output_addr) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    half x = input_addr[pos];
    float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x)));
    output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res));
  }
}

template <>
__global__ void GeluKernel(size_t size, half2* input_addr, half2* output_addr) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    half2 x = input_addr[pos];
    float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x));
    float2 tanh_res;
    tanh_res.x = tanh(tanh_param.x);
    tanh_res.y = tanh(tanh_param.y);
    output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res));
  }
}

template <typename T>
void Gelu(size_t size, T* input_addr, T* output_addr, cudaStream_t cuda_stream) {
  GeluKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
  return;
}

template <>
void Gelu(size_t size, half* input_addr, half* output_addr, cudaStream_t cuda_stream) {
  if (size % 2 == 0) {
    GeluKernel<half2><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
      size / 2, reinterpret_cast<half2*>(input_addr), reinterpret_cast<half2*>(output_addr));
  } else {
    GeluKernel<half><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
  }
  return;
}

template <typename T>
__global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) {
  // formula:
  // dx = dy * y'
  // y' = 0.5 * (1 + tanh(tanh_para)) +
@@ -48,18 +78,59 @@ __global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) {
  // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2))
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    T x = x_addr[pos];
    T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
    T mul_right = 0.7978845608 + 0.1070322244 * x * x;
    T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right;
    dx_addr[pos] = dy_addr[pos] * y_res;
  }
}

template <typename T>
__global__ void GeluGradKernel(size_t size, half2* dy_addr, half2* x_addr, half2* dx_addr) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    half2 x = x_addr[pos];
    float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x));
    float2 tanh_res;
    tanh_res.x = tanh(tanh_param.x);
    tanh_res.y = tanh(tanh_param.y);
    half2 tanh_res_half = __float22half2_rn(tanh_res);
    half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x;
    half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) +
                  half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right;
    dx_addr[pos] = dy_addr[pos] * y_res;
  }
}

template <typename T>
__global__ void GeluGradKernel(size_t size, half* dy_addr, half* x_addr, half* dx_addr) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    half x = x_addr[pos];
    half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x);
    half tanh_res = __float2half_rn(tanh(__half2float(tanh_param)));
    half mul_right = half(0.7978845608) + half(0.1070322244) * x * x;
    half y_res = half(0.5) * (half(1.0) + tanh_res) +
                 half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right;
    dx_addr[pos] = dy_addr[pos] * y_res;
  }
}

template <typename T>
void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream) {
  GeluGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
}

template <>
void GeluGradKernel(size_t size, half* dy_addr, half* x_addr, half* dx_addr, cudaStream_t cuda_stream) {
  if (size % 2 == 0) {
    GeluGradKernel<half2><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
      size / 2, reinterpret_cast<half2*>(dy_addr), reinterpret_cast<half2*>(x_addr), reinterpret_cast<half2*>(dx_addr));
  } else {
    GeluGradKernel<half><<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
  }
  return;
}

template void Gelu(size_t size, float* input_addr, float* output_addr, cudaStream_t cuda_stream);
template void Gelu(size_t size, half* input_addr, half* output_addr, cudaStream_t cuda_stream);
template void GeluGradKernel(size_t size, float* dy_addr, float* x_addr, float* dx_addr, cudaStream_t cuda_stream);
template void GeluGradKernel(size_t size, half* dy_addr, half* x_addr, half* dx_addr, cudaStream_t cuda_stream);
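For reference, the hard-coded constants in both kernels follow from the tanh approximation of GELU spelled out in the comments; writing the forward and backward formulas out makes the magic numbers explicit (a restatement of the code above, not part of the diff):

\[
\mathrm{gelu}(x) = 0.5\,x\,\bigl(1 + \tanh(y)\bigr), \qquad
y = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr), \qquad
\sqrt{2/\pi} \approx 0.7978845608
\]

\[
\frac{d}{dx}\,\mathrm{gelu}(x) = 0.5\,\bigl(1 + \tanh y\bigr)
  + 0.5\,x\,\bigl(1 - \tanh^{2} y\bigr)\,\sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr)
\]

Since \(3 \cdot 0.044715 \cdot \sqrt{2/\pi} \approx 0.1070322244\), the factor \(\sqrt{2/\pi}\,(1 + 3 \cdot 0.044715\,x^{2})\) is exactly the mul_right = 0.7978845608 + 0.1070322244 * x * x term in GeluGradKernel.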
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc
@@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_ONE(GeluGrad,
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      GeLUGpuGradKernel, float)
MS_REG_GPU_KERNEL_ONE(GeluGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      GeLUGpuGradKernel, half)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc
@@ -20,5 +20,7 @@ namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      GeluGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      GeluGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
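With the float16 registration in place, the same Gelu primitive dispatches to the half/half2 kernels whenever it receives float16 input. A minimal sketch of how that path could be exercised on GPU, mirroring the tests below (GeluNet is an illustrative wrapper, and the primitive is assumed to be exposed as mindspore.ops.operations.Gelu at this revision):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class GeluNet(nn.Cell):
    """Illustrative wrapper so the Gelu primitive can be run as a cell."""
    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()  # assumed primitive name at this revision

    def construct(self, x):
        return self.gelu(x)

# A float16 input selects the half/half2 kernel registered above.
x = Tensor(np.random.randn(4, 8).astype(np.float16))
print(GeluNet()(x))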
tests/st/ops/gpu/test_gelu_grad_op.py
@@ -58,7 +58,37 @@ def test_gelugrad():
    grad = Grad(net)
    output = grad(x_ms, dy_ms)
    print(output)
    expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032,
              0.0352667, 0.34266686, 0.57757664, 0.04707306, 0.51536125]
    assert np.allclose(output[0].asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelugrad_fp16():
    np.random.seed(42)
    x_np = np.random.randn(5, 3, 6).astype(np.float16)
    dy_np = np.random.randn(5, 3, 6).astype(np.float16)
    net = GeluNet()
    grad = Grad(net)
    output = grad(Tensor(x_np), Tensor(dy_np))
    expect = [[[8.4045e-02, 3.7817e-01, -6.6748e-01, -3.6914e-01, -1.2415e-01, -4.6362e-01],
               [3.3301e-01, 2.6270e-01, 7.7534e-04, -2.0947e-01, -2.2021e-01, -6.4880e-02],
               [-2.3633e-01, 7.6538e-02, 1.8280e-02, 3.8635e-02, -1.6235e-01, 1.2964e-01]],
              [[-1.4801e-02, 9.6130e-03, -2.1660e+00, -8.5602e-03, 3.3356e-02, -3.1885e-01],
               [-2.0355e-02, 1.7737e-01, 3.8719e-03, -9.1895e-01, 8.4717e-02, 2.0593e-01],
               [5.8350e-02, -1.0020e+00, 6.8652e-01, 1.3428e-01, 6.0352e-01, -2.6270e-01]],
              [[-6.5820e-01, 5.1147e-02, -1.2650e-02, -3.2983e-01, -1.5410e+00, 4.3518e-02],
               [-4.3359e-01, 1.2659e-01, 1.1792e-01, 2.2705e-02, -1.2329e-01, -3.5278e-01],
               [6.2109e-01, 1.3611e-01, 1.7041e-01, 2.7124e-01, -5.5908e-02, 1.7212e-01]],
              [[2.8320e-01, 8.3252e-01, 4.2480e-02, -3.4473e-01, 3.9429e-01, 3.1958e-01],
               [3.6499e-02, 1.2250e-01, 7.1350e-02, -2.7267e-02, 3.0029e-01, -8.0566e-01],
               [8.2617e-01, 5.1367e-01, -9.2480e-01, 3.3203e-02, -7.5684e-01, 8.8623e-01]],
              [[5.4590e-01, -9.2383e-01, -2.8107e-02, 4.2432e-01, 4.6826e-01, 5.0879e-01],
               [-1.4062e-01, 6.6284e-02, -2.9126e-01, -6.3086e-01, -8.6975e-02, 4.1504e-02],
               [-6.3171e-03, 1.0852e-01, 1.3779e-02, 1.0947e+00, -3.0334e-02, 2.3828e+00]]]
    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-2)
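The expected tensor above can be sanity-checked against the closed-form derivative implemented in gelu_impl.cu. A NumPy sketch of such a reference check (gelu_grad_reference is an illustrative helper, not part of the test file):

import numpy as np

def gelu_grad_reference(dy, x):
    # Same tanh approximation and constants as GeluGradKernel in gelu_impl.cu.
    tanh_res = np.tanh(0.7978845608 * (x + 0.044715 * x ** 3))
    mul_right = 0.7978845608 + 0.1070322244 * x * x
    y_grad = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right
    return dy * y_grad

# Computed in float32, this should agree with the fp16 expected values above
# to within the test's rtol=1e-2 tolerance.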
tests/st/ops/gpu/test_gelu_op.py
@@ -91,3 +91,16 @@ def test_gelu_neg():
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_4d_fp16():
    x_np = np.random.random((32, 3, 224, 224)).astype(np.float16)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy(), rtol=1e-3)
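GeluCompute is defined earlier in test_gelu_op.py and is not part of this diff; a forward reference consistent with the kernel formula would look roughly like the following (illustrative, not the repository's definition):

import numpy as np

def gelu_compute_reference(x):
    # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))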