magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4d600e70
Authored Jul 30, 2020 by wilfChen
Parent: 699e616b

gpu layernorm

Showing 3 changed files with 62 additions and 8 deletions (+62 -8):

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu   +14 -8
tests/st/ops/gpu/test_layer_norm_grad_op.py                                     +26 -0
tests/st/ops/gpu/test_layer_norm_op.py                                          +22 -0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -35,8 +35,8 @@ inline __device__ half my_pow(half a, double b) {
 template <typename T>
-inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
-                                                const T &epsilon, const T *dy, const T *x, const T *mean,
-                                                const T *var, T *dg, T *db) {
+inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
+                                                const int &mean_dim, const T &epsilon, const T *dy, const T *x,
+                                                const T *mean, const T *var, T *dg, T *db) {
   int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -46,7 +46,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
       }
       int pos = row * col_dim + col;
-      dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+      int mean_offset = pos / mean_dim;
+      dg[0] += dy[pos] * my_pow(var[mean_offset] + epsilon, -0.5) * (x[pos] - mean[mean_offset]);
       db[0] += dy[pos];
     }
   }
@@ -89,8 +90,9 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
 }
 template <typename T>
-__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x,
-                                       const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) {
+__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
+                                       const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr,
+                                       T *db_addr) {
   // row: [0:param_axis]
   // col: [param_axis:]
   // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
@@ -98,7 +100,7 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
   for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
     T dg = 0;
     T db = 0;
-    GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
+    GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
     GammaAndBetaWarpReduce(&dg, &db);
     GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr);
   }
@@ -239,8 +241,12 @@ void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim,
                                                                mean, var, gamma, dx);
   share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
-                                                                                var, dg, db);
+  // GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x,
+  // mean,
+  // var, dg, db);
+  int param_reduce_dim = row_dim * col_dim / param_dim;
+  GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim,
+                                                                                  epsilon, dy, x, mean, var, dg, db);
 }
 template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
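With this change the gamma/beta gradient kernel is launched over param_dim columns and reduces over param_reduce_dim = row_dim * col_dim / param_dim rows, using the new mean_dim argument to map each flattened element back to its normalization row. Below is a rough NumPy sketch of that reduction for reference; it is not part of the commit, and the helper name and flattened-layout assumptions are illustrative only:

import numpy as np

def gamma_beta_grad_ref(dy, x, mean, var, epsilon, param_reduce_dim, param_dim, mean_dim):
    """Illustrative reference for the dg/db reduction (assumed layout, not the kernel itself)."""
    # dy and x are viewed as (param_reduce_dim, param_dim); mean/var hold one value per
    # normalization row. mean_dim corresponds to the original col_dim, so pos // mean_dim
    # recovers the normalization row of a flattened element, like mean_offset in the kernel.
    dy2 = dy.reshape(param_reduce_dim, param_dim)
    x2 = x.reshape(param_reduce_dim, param_dim)
    mean = mean.ravel()
    var = var.ravel()
    dg = np.zeros(param_dim, dtype=x.dtype)
    db = np.zeros(param_dim, dtype=x.dtype)
    for col in range(param_dim):
        for row in range(param_reduce_dim):
            pos = row * param_dim + col
            mean_offset = pos // mean_dim
            dg[col] += dy2[row, col] * (var[mean_offset] + epsilon) ** -0.5 * (x2[row, col] - mean[mean_offset])
            db[col] += dy2[row, col]
    return dg, db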
tests/st/ops/gpu/test_layer_norm_grad_op.py
@@ -193,3 +193,29 @@ def test_layernormgrad4():
     assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
     assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad5():
+    begin_norm_axis = 2
+    begin_params_axis = 1
+    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
+                                                                  begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
+
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
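As a quick sanity check on the launch dimensions exercised by test_layernormgrad5, assuming row_dim, col_dim and param_dim are the flattened products of the axes split at begin_norm_axis and begin_params_axis (as the kernel comments suggest; this arithmetic is ours, not from the diff):

import numpy as np

x_shape = (128, 2, 16, 32)
begin_norm_axis, begin_params_axis = 2, 1

row_dim = int(np.prod(x_shape[:begin_norm_axis]))      # 128 * 2      = 256
col_dim = int(np.prod(x_shape[begin_norm_axis:]))      # 16 * 32      = 512
param_dim = int(np.prod(x_shape[begin_params_axis:]))  # 2 * 16 * 32  = 1024
param_reduce_dim = row_dim * col_dim // param_dim      # 256*512/1024 = 128

# GammaAndBetaPropKernel<<<param_dim, ...>>>(param_reduce_dim, param_dim, col_dim, ...)
# i.e. the gamma/beta reduction sees a (128, 1024) view, and col_dim = 512 is the mean_dim
# used to look up the 256 per-row mean/var values.
print(row_dim, col_dim, param_dim, param_reduce_dim)   # 256 512 1024 128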
tests/st/ops/gpu/test_layer_norm_op.py
@@ -175,3 +175,25 @@ def test_layernorm2d_3():
     assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm2d_4():
+    begin_norm_axis = 2
+    begin_params_axis = 1
+    np.random.seed(42)
+    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
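Both new ST cases follow the pattern of the existing ones and, judging by their pytest markers, should run on a GPU build of MindSpore with a standard pytest invocation such as `pytest -sv tests/st/ops/gpu/test_layer_norm_op.py::test_layernorm2d_4` (command form assumed; it is not stated in the commit).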