Commit da71a914 (magicwindyyd/mindspore, fork of MindSpore/mindspore)
Author: VectorSL
Authored: June 22, 2020
Parent: 932b7649

    gpu momentum layernorm layernormgrad support fp16

Showing 9 changed files with 177 additions and 49 deletions (+177 -49)
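In short: the commit teaches three GPU kernels (LayerNorm, LayerNormGrad, and ApplyMomentum) to run in float16. Device code gains a `my_pow` helper with a `half` specialization, a `DynamicSharedMem<T>` wrapper so kernels templated on the element type can use dynamic shared memory, and `half` specializations of the reduce/propagate device functions. The momentum kernel additionally gains a second template parameter so the learning-rate and momentum scalars can stay float32 while variables and gradients are float16.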
Changed files:

  mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu   +72 -18
  mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu        +22 -7
  mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh       +17 -0
  mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu          +25 -12
  mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh         +3 -3
  mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc         +9 -0
  mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc    +11 -0
  mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc           +13 -4
  mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h            +5 -5
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -18,10 +18,21 @@
 #include <stdint.h>
 #include <cuda_runtime.h>
 #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
+#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
 
 constexpr int NUM_PER_THREAD_REDUCE = 4;
 constexpr int WARP_SIZE = 32;
 
+template <typename T>
+inline __device__ T my_pow(T a, double b) {
+  return pow(a, static_cast<float>(b));
+}
+
+template <>
+inline __device__ half my_pow(half a, double b) {
+  return __float2half(pow(__half2float(a), static_cast<float>(b)));
+}
+
 template <typename T>
 inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
                                                 const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
@@ -35,7 +46,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim,
     }
     int pos = row * col_dim + col;
-    dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+    dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
     db[0] += dy[pos];
   }
 }
@@ -58,26 +69,26 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim,
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
-  extern __shared__ T share_mem[];
+  DynamicSharedMem<T> share_mem;
   if (threadIdx.x % WARP_SIZE == 0) {
     int offset = threadIdx.x / WARP_SIZE * 2;
-    share_mem[offset] = dg[0];
-    share_mem[offset + 1] = db[0];
+    share_mem.addr()[offset] = dg[0];
+    share_mem.addr()[offset + 1] = db[0];
   }
   __syncthreads();
   for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
     if (threadIdx.x < stride) {
       int offset = (threadIdx.x + stride) * 2;
-      share_mem[threadIdx.x * 2] += share_mem[offset];
-      share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1];
+      share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset];
+      share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1];
     }
   }
   __syncthreads();
   if (threadIdx.x == 0) {
-    dg_addr[col] = share_mem[0];
-    db_addr[col] = share_mem[1];
+    dg_addr[col] = share_mem.addr()[0];
+    db_addr[col] = share_mem.addr()[1];
   }
 }
@@ -114,13 +125,37 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim,
       T v1 = dy[pos] * gamma[gamma_offset];
       T v2 = x[pos] - mean[row];
-      sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5);
+      sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
       sum2[0] += v1;
       sum3[0] += -2.0 * v2;
     }
   }
 }
 
+template <>
+inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim,
+                                         const half& epsilon, half* sum1, half* sum2, half* sum3, const half* dy,
+                                         const half* x, const half* mean, const half* var, const half* gamma) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int col = NUM_PER_THREAD_REDUCE * i + j;
+      if (col >= col_dim) {
+        return;
+      }
+      int pos = row * col_dim + col;
+      int gamma_offset = pos % param_dim;
+      half v1 = dy[pos] * gamma[gamma_offset];
+      half v2 = x[pos] - mean[row];
+      sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
+      sum2[0] += v1;
+      sum3[0] += __float2half(-2.0) * v2;
+    }
+  }
+}
+
 template <typename T>
 inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
@@ -166,12 +201,28 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim,
     int gamma_offset = pos % param_dim;
     T v1 = dy[pos] * gamma[gamma_offset];
     T v2 = x[pos] - mean[row];
-    T v3 = pow(var[row] + epsilon, -0.5);
+    T v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 +
               (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim);
   }
 }
 
+template <>
+inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
+                                 const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                                 half* dx, const half* share_mem) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = (row * col_dim + col);
+    int gamma_offset = pos % param_dim;
+    half v1 = dy[pos] * gamma[gamma_offset];
+    half v2 = x[pos] - mean[row];
+    half v3 = my_pow(var[row] + epsilon, -0.5);
+    dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) \
+              * __float2half(1.0 / col_dim);
+  }
+}
+
 template <typename T>
 __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                 const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
@@ -179,27 +230,30 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim,
     T sum1 = 0;
     T sum2 = 0;
     T sum3 = 0;
-    extern __shared__ T share_mem[];
+    DynamicSharedMem<T> share_mem;
     InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
     InputWarpReduce(&sum1, &sum2, &sum3);
-    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem);
-    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem);
+    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr());
+    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr());
   }
 }
 
 template <typename T>
 void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
                    const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
-  int share_mem =
+  int share_mem_size =
     ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx);
+  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx);
 
-  share_mem =
+  share_mem_size =
     ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
 }
 
 template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
                             const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
                             float* dx, float* dg, float* db, cudaStream_t stream);
+template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon,
+                            const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                            half* dx, half* dg, half* db, cudaStream_t stream);
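Why the `my_pow` indirection: device-side `pow` has `float` and `double` overloads but none for `half`, so calling it from the `half` instantiation would not compile. The specialization widens to `float`, computes, and narrows back. A minimal sketch of the same pattern outside this file (the `InvStddevKernel` below is illustrative, not part of the commit):

```cpp
#include <cuda_fp16.h>
#include <math.h>

// Generic case: exponent narrowed to float, built-in pow overload resolution.
template <typename T>
inline __device__ T my_pow(T a, double b) {
  return pow(a, static_cast<float>(b));
}

// half has no pow overload: widen to float, compute, narrow back.
template <>
inline __device__ half my_pow(half a, double b) {
  return __float2half(pow(__half2float(a), static_cast<float>(b)));
}

// Illustrative use, mirroring how the grad kernels compute 1/sqrt(var + eps)
// via my_pow(var[row] + epsilon, -0.5).
__global__ void InvStddevKernel(const half* var, half epsilon, half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = my_pow(var[i] + epsilon, -0.5);
  }
}
```

(The `half` operators require a GPU with native fp16 arithmetic, i.e. compute capability 5.3 or newer.)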
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
@@ -35,7 +35,8 @@ inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val)
 template <typename T>
 inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) {
-  if (n2 == 0) {
+  T zero = 0;
+  if (n2 == zero) {
     return;
   }
@@ -112,6 +113,17 @@ inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim,
   }
 }
 
+template <>
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const half *x,
+                                 const half *share_mem, const half *gamma, const half *beta, const half epsilon,
+                                 half *y) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = row * col_dim + col;
+    int i = pos % param_dim;
+    y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
+  }
+}
+
 template <typename T>
 __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                 const T *x, const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
@@ -120,14 +132,14 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim,
     T var = 0;
     T num = 0;
     const T *block_addr = x + row * col_dim;
-    extern __shared__ T share_mem[];
+    DynamicSharedMem<T> share_mem;
     ThreadReduce(col_dim, block_addr, &mean, &var, &num);
     WarpReduce(&mean, &var, &num);
-    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem);
+    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr());
     __syncthreads();
-    LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y);
+    LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y);
   }
 }
@@ -137,12 +149,15 @@ void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon,
   const dim3 block(row_dim);
   const dim3 thread(256);
   // keep the mean/var/num after warp reduce
-  int share_mem =
+  int share_mem_size =
     ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, var);
+  LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, var);
 }
 
 template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
                         const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
                         cudaStream_t stream);
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon,
+                        const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var,
+                        cudaStream_t stream);
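Two fp16 quirks surface in this file: `half` does not get the implicit arithmetic conversions that `float` enjoys, which is why `MeanAndVarMerge` now compares against an explicitly constructed `T zero = 0` instead of the literal `0`, and the `half` specialization of `LayerNorm` uses the fp16 intrinsic `hsqrt` rather than `sqrt`. A small sketch of both idioms (the function names here are illustrative):

```cpp
#include <cuda_fp16.h>

// Comparing a templated value against zero in a way that compiles for half:
// construct a typed zero rather than comparing against the int literal 0.
template <typename T>
inline __device__ bool IsZero(const T &n) {
  T zero = 0;
  return n == zero;
}

// fp16 normalization step: hsqrt is the half-precision sqrt intrinsic
// (compute capability >= 5.3), matching the LayerNorm<half> specialization.
inline __device__ half Normalize(half x, half mean, half var, half epsilon) {
  return (x - mean) / hsqrt(var + epsilon);
}
```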
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
@@ -19,6 +19,23 @@
 #include "device/gpu/cuda_common.h"
 
+template <typename T>
+struct DynamicSharedMem;
+template <>
+struct DynamicSharedMem<float> {
+  __device__ float *addr() {
+    extern __shared__ float addr_float[];
+    return addr_float;
+  }
+};
+template <>
+struct DynamicSharedMem<half> {
+  __device__ half *addr() {
+    extern __shared__ half addr_half[];
+    return addr_half;
+  }
+};
+
 template <typename T>
 void LayerNorm(const int &outer, const int &inner, const int &param_dim, const T &epsilon, const T *x, const T *gamma,
                const T *beta, T *y, T *mean, T *var, cudaStream_t stream);
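The `DynamicSharedMem` wrapper exists because `extern __shared__ T share_mem[];` inside a template refers to the same dynamic shared-memory symbol in every instantiation; once a kernel is instantiated for both `float` and `half`, the compiler sees that symbol redeclared with conflicting types. Explicit specializations holding differently named `extern __shared__` arrays sidestep the conflict. A hedged sketch of a kernel using the wrapper (the `StageAndDoubleKernel` here is hypothetical, not from the commit):

```cpp
#include <cuda_fp16.h>

// From the header above: one distinctly named extern __shared__ array per type.
template <typename T>
struct DynamicSharedMem;
template <>
struct DynamicSharedMem<float> {
  __device__ float *addr() {
    extern __shared__ float addr_float[];
    return addr_float;
  }
};
template <>
struct DynamicSharedMem<half> {
  __device__ half *addr() {
    extern __shared__ half addr_half[];
    return addr_half;
  }
};

// Hypothetical kernel: stages a block of data in dynamic shared memory and
// reads it back. Instantiating this for both float and half now compiles.
template <typename T>
__global__ void StageAndDoubleKernel(const T *in, T *out, int n) {
  DynamicSharedMem<T> share_mem;
  T *buf = share_mem.addr();
  int i = threadIdx.x;
  if (i < n) {
    buf[i] = in[i];
  }
  __syncthreads();
  if (i < n) {
    out[i] = buf[i] + buf[i];
  }
}
```

As with the `LayerNormKernel` launches above, the caller passes the buffer size as the third launch parameter, e.g. `StageAndDoubleKernel<<<1, 256, n * sizeof(T), stream>>>(in, out, n);`.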
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu
@@ -15,25 +15,38 @@
  */
 #include "momentum_impl.cuh"
-template <typename T>
-__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation,
-                                             const T *learning_rate, const T *gradient, const T *momentum) {
+template <typename T, typename S>
+__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation,
+                                             const S *learning_rate, const T *gradient, const S *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
     variable[i] -= learning_rate[0] * accumulation[i];
   }
   return;
 }
-template <typename T>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate,
-                            const T *gradient, const T *momentum, cudaStream_t cuda_stream) {
+template <>
+__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
+                                             const float *learning_rate, const half *gradient,
+                                             const float *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
+    accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
+    variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
+  }
+  return;
+}
+template <typename T, typename S>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate,
+                            const T *gradient, const S *momentum, cudaStream_t cuda_stream) {
   MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation,
                                                                                   learning_rate, gradient, momentum);
   return;
 }
-template void MomentumUpdateVariable<float>(const size_t size, float *variable, float *accumulation,
-                                            const float *learning_rate, const float *gradient, const float *momentum,
-                                            cudaStream_t cuda_stream);
-template void MomentumUpdateVariable<half>(const size_t size, half *variable, half *accumulation,
-                                           const half *learning_rate, const half *gradient, const half *momentum,
-                                           cudaStream_t cuda_stream);
+template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation,
+                                                   const float *learning_rate, const float *gradient,
+                                                   const float *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation,
+                                                 const half *learning_rate, const half *gradient,
+                                                 const half *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation,
+                                                  const float *learning_rate, const half *gradient,
+                                                  const float *momentum, cudaStream_t cuda_stream);
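The `<half, float>` specialization is the interesting combination: parameters, accumulation, and gradients stay fp16, while the single-element learning-rate and momentum tensors stay fp32 and are narrowed per use with `__float2half`. A hedged host-side sketch of driving that instantiation (the wrapper name and buffer setup are illustrative, not part of the commit):

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "momentum_impl.cuh"

// Mixed-precision momentum step: fp16 tensors, fp32 scalar hyper-parameters.
void ApplyMomentumFp16(size_t size, half *variable, half *accumulation,
                       const half *gradient, float lr, float mom,
                       cudaStream_t stream) {
  float *d_lr = nullptr;
  float *d_mom = nullptr;
  cudaMalloc(&d_lr, sizeof(float));
  cudaMalloc(&d_mom, sizeof(float));
  cudaMemcpyAsync(d_lr, &lr, sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(d_mom, &mom, sizeof(float), cudaMemcpyHostToDevice, stream);

  // Deduces MomentumUpdateVariable<half, float>:
  //   accumulation[i] = mom * accumulation[i] + gradient[i]
  //   variable[i]    -= lr * accumulation[i]
  MomentumUpdateVariable(size, variable, accumulation, d_lr, gradient, d_mom, stream);

  cudaStreamSynchronize(stream);
  cudaFree(d_lr);
  cudaFree(d_mom);
}
```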
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh
@@ -18,8 +18,8 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
 #include "device/gpu/cuda_common.h"
-template <typename T>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate,
-                            const T *gradient, const T *momentum, cudaStream_t cuda_stream);
+template <typename T, typename S>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate,
+                            const T *gradient, const S *momentum, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
--- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
@@ -27,5 +27,14 @@ MS_REG_GPU_KERNEL_ONE(LayerNorm,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       LayerNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
--- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
@@ -29,5 +29,16 @@ MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       LayerNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
--- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc
@@ -18,7 +18,7 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
@@ -26,8 +26,8 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
-                      MomentumGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
+                      MomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
@@ -35,6 +35,15 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
-                      MomentumGpuKernel, half)
+                      MomentumGpuKernel, half, half)
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      MomentumGpuKernel, half, float)
 }  // namespace kernel
 }  // namespace mindspore
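The third registration is what enables mixed precision end to end: `MS_REG_GPU_KERNEL_TWO` binds `MomentumGpuKernel<half, float>` to an ApplyMomentum node whose variable, accumulation, and gradient inputs (indices 0, 1, 3) are float16 while the learning-rate and momentum inputs (indices 2, 4) remain float32, so fp16 training does not have to round its optimizer scalars down to half precision.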
--- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h
@@ -23,7 +23,7 @@
 #include "kernel/gpu/cuda_impl/momentum_impl.cuh"
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S>
 class MomentumGpuKernel : public GpuKernel {
  public:
   MomentumGpuKernel()
@@ -37,9 +37,9 @@ class MomentumGpuKernel : public GpuKernel {
               void *stream_ptr) override {
     T *variable = GetDeviceAddress<T>(inputs, 0);
     T *accumulation = GetDeviceAddress<T>(inputs, 1);
-    T *learning_rate = GetDeviceAddress<T>(inputs, 2);
+    S *learning_rate = GetDeviceAddress<S>(inputs, 2);
     T *gradient = GetDeviceAddress<T>(inputs, 3);
-    T *momentum = GetDeviceAddress<T>(inputs, 4);
+    S *momentum = GetDeviceAddress<S>(inputs, 4);
     MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
@@ -53,9 +53,9 @@ class MomentumGpuKernel : public GpuKernel {
     variable_size_ = sizeof(T);
     accumulation_size_ = sizeof(T);
-    learning_rate_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(S);
     gradient_size_ = sizeof(T);
-    momentum_size_ = sizeof(T);
+    momentum_size_ = sizeof(S);
     auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < variable_shape.size(); i++) {