Oneflow-Inc / oneflow

Commit 9c0d25ec
Authored Oct 01, 2020 by simonJJJ

Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into dev_hsv

Former-commit-id: 2f787462d13393508c9a20b74b3aa51aa0292325

Parents: c3818919, 6c35718c
Showing 3 changed files with 516 additions and 40 deletions (+516, -40)
oneflow/python/ops/nn_ops.py (+114, -0)
oneflow/python/test/ops/test_layer_norm.py (+141, -12)
oneflow/user/kernels/layer_norm_gpu_kernel.cu (+261, -28)
oneflow/python/ops/nn_ops.py
...
...
@@ -926,6 +926,120 @@ def batch_normalization(
```python
    raise NotImplementedError


@oneflow_export("nn.layer_norm")
def layer_norm(
    inputs: remote_blob_util.BlobDef,
    gamma: Optional[remote_blob_util.BlobDef] = None,
    beta: Optional[remote_blob_util.BlobDef] = None,
    begin_norm_axis: int = 1,
    begin_params_axis: int = -1,
    epsilon: float = 1e-5,
    name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    r"""Layer Normalization.

    Args:
        inputs (remote_blob_util.BlobDef): Input `Blob`.
        gamma (Optional[remote_blob_util.BlobDef]).
        beta (Optional[remote_blob_util.BlobDef]).
        begin_norm_axis (int, optional): An integer specifies which axis to normalize at first. Defaults to 1.
        begin_params_axis (int, optional): An integer specifies which axis params at . Defaults to -1.
        epsilon (float, optional): A small float is added to avoid division by zero. Defaults to 1e-5.
        name (Optional[str], optional): This operator's name. Defaults to None.

    Returns:
        remote_blob_util.BlobDef: A normalized `Blob` with same shape of input.

    For example:

    .. code-block:: python

        import oneflow as flow
        import numpy as np
        import oneflow.typing as tp


        @flow.global_function()
        def layer_norm_Job(x: tp.Numpy.Placeholder((1, 64, 128, 128))
        ) -> tp.Numpy:
            layer_norm = flow.nn.layer_norm(
                x,
                name="LayerNorm1"
            )
            return layer_norm


        x = np.random.randn(1, 64, 128, 128).astype(np.float32)
        out = layer_norm_Job(x)

        # out.shape (1, 64, 128, 128)

    """
    param_shape = inputs.shape[begin_params_axis:]
    if name is None:
        name = id_util.UniqueStr("LayerNorm_")

    if flow.current_scope().device_parallel_desc_symbol.device_tag == "cpu":
        if begin_norm_axis < 0:
            begin_norm_axis = begin_norm_axis + len(inputs.shape)

        reduce_axis = []
        for dim in range(len(inputs.shape)):
            if dim >= begin_norm_axis:
                reduce_axis.append(dim)
        mean, variance = flow.nn.moments(inputs, reduce_axis, keepdims=True)

        axis = begin_norm_axis
        normalized = flow.nn.batch_normalization(
            x=inputs,
            mean=mean,
            variance=variance,
            variance_epsilon=epsilon,
            axis=axis,
            name=name,
        )
        nd_params_shape = [1] * (len(inputs.shape) - len(param_shape)) + list(
            param_shape
        )
        affined = normalized
        if gamma:
            gamma = flow.reshape(gamma, nd_params_shape)
            affined *= gamma
        if beta:
            beta = flow.reshape(beta, nd_params_shape)
            affined += beta
        return affined
    elif flow.current_scope().device_parallel_desc_symbol.device_tag == "gpu":
        op_builder = (
            flow.user_op_builder(name)
            .Op("layer_norm")
            .Input("x", [inputs])
            .Output("y")
            .Output("mean")
            .Output("inv_variance")
        )
        scale = False
        center = False
        if beta is not None:
            center = True
            op_builder.Input("beta", [beta])
        if gamma is not None:
            scale = True
            op_builder.Input("gamma", [gamma])
            op_builder.Output("normalized")
        op_builder.Attr("center", center)
        op_builder.Attr("scale", scale)
        op_builder.Attr("begin_norm_axis", begin_norm_axis)
        op_builder.Attr("begin_params_axis", begin_params_axis)
        op_builder.Attr("epsilon", epsilon)
        y = op_builder.Build().InferAndTryRun().RemoteBlobList()[0]
        return y
    else:
        raise NotImplementedError


@oneflow_export("nn.compat_conv2d")
def tf_conv2d(
    input: remote_blob_util.BlobDef,
```
...
...
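In the GPU branch above, `gamma` and `beta` are ordinary input blobs rather than variables created inside the op, so a caller has to build them itself. Below is a minimal sketch of that wiring; the job name, shapes, and variable setup are illustrative only, and the updated test further down does essentially the same thing:

```python
import numpy as np
import oneflow as flow
import oneflow.typing as tp


@flow.global_function()
def layer_norm_affine_job(x: tp.Numpy.Placeholder((4, 5, 2, 6))) -> tp.Numpy:
    # Scale/shift parameters cover the dims from begin_params_axis onward.
    param_shape = x.shape[-1:]
    gamma = flow.get_variable(
        name="gamma",
        shape=param_shape,
        dtype=flow.float,
        initializer=flow.constant_initializer(1.0),
    )
    beta = flow.get_variable(
        name="beta",
        shape=param_shape,
        dtype=flow.float,
        initializer=flow.constant_initializer(0.0),
    )
    # Normalize over the last axis, then apply the affine transform.
    return flow.nn.layer_norm(
        x, gamma=gamma, beta=beta, begin_norm_axis=-1, begin_params_axis=-1
    )


out = layer_norm_affine_job(np.random.randn(4, 5, 2, 6).astype(np.float32))
# out.shape == (4, 5, 2, 6)
```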
oneflow/python/test/ops/test_layer_norm.py
...
...
@@ -20,6 +20,8 @@ from collections import OrderedDict
```python
import numpy as np
import oneflow as flow
import tensorflow as tf
import test_global_storage
from test_util import GenArgList, type_name_to_flow_type, type_name_to_np_type
import oneflow.typing as oft
```
...
...
@@ -30,12 +32,12 @@ for gpu in gpus:
```diff
 def test_layer_norm(_):
     confs = [
-        {"x_shape": (4, 5, 2, 6), "begin_norm_axis": -1, "begin_params_axis": -1},
+        {"x_shape": (40, 50), "begin_norm_axis": -1, "begin_params_axis": -1},
     ]
     arg_dict = OrderedDict()
     arg_dict["device_type"] = ["cpu", "gpu"]
     arg_dict["confs"] = confs
-    arg_dict["data_type"] = ["float32"]
+    arg_dict["data_type"] = ["float32", "float16"]
     arg_dict["trainable"] = [True, False]
     arg_dict["center"] = [True, False]
     arg_dict["scale"] = [True, False]
```
...
...
@@ -43,6 +45,8 @@ def test_layer_norm(_):
```diff
     for case in GenArgList(arg_dict):
         (device_type, confs, data_type, trainable, center, scale, epsilon) = case
+        if device_type == "cpu" and data_type == "float16":
+            continue
         x_shape = confs["x_shape"]
         begin_norm_axis = confs["begin_norm_axis"]
         begin_params_axis = confs["begin_params_axis"]
```
...
...
@@ -51,13 +55,26 @@ def test_layer_norm(_):
```diff
             begin_norm_axis == begin_params_axis
         ), "tf doesn't support a dedicated begin_params_axis"
         # Random inputs
-        x = np.random.randn(*x_shape).astype(type_name_to_np_type[data_type])
+        if data_type == "float16":
+            x = (
+                np.random.uniform(low=-1, high=1, size=x_shape)
+                .astype(np.float16)
+                .astype(np.float32)
+            )
+        else:
+            x = np.random.uniform(low=-1, high=1, size=x_shape).astype(
+                type_name_to_np_type[data_type]
+            )
         dim = len(x.shape) - 2
         # TF results
         with tf.GradientTape(persistent=True) as tape:
             x_tf = tf.Variable(x)
-            y_tf = tf.keras.layers.LayerNormalization(
+            if data_type == "float16":
+                x_tf = tf.cast(x_tf, dtype=tf.float16)
+                tf.keras.backend.set_floatx("float16")
+            layer = tf.keras.layers.LayerNormalization(
                 axis=begin_norm_axis,
                 epsilon=epsilon,
                 center=center,
```
...
...
@@ -69,20 +86,65 @@ def test_layer_norm(_):
```diff
                 beta_constraint=None,
                 gamma_constraint=None,
                 trainable=trainable,
-            )(x_tf)
-            dx_tf = tape.gradient(y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape))
+            )
+            y_tf = layer(x_tf)
+        if data_type == "float16":
+            dx_tf = tape.gradient(
+                y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape, dtype=tf.float16)
+            )
+        else:
+            dx_tf = tape.gradient(y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape))
+        grad = tape.gradient(y_tf, layer.trainable_variables)
+        if trainable:
+            if scale and center:
+                tf_gamma_diff = grad[0]
+                tf_beta_diff = grad[1]
+            elif scale and not center:
+                tf_gamma_diff = grad[0]
+            elif not scale and center:
+                tf_beta_diff = grad[0]
+            else:
+                pass
+        else:
+            pass

         def assert_grad(b):
             diff = dx_tf.numpy() - b.numpy()
             max_diff = np.max(np.abs(diff))
-            assert np.allclose(dx_tf.numpy(), b.numpy(), rtol=1e-5, atol=1e-5), (
+            if data_type == "float16":
+                tolerance = 2e-3
+            else:
+                tolerance = 1e-5
+            assert np.allclose(dx_tf.numpy(), b.numpy(), rtol=tolerance, atol=tolerance), (
                 case,
                 max_diff,
             )

+        def assert_grad_gamma(b):
+            diff = tf_gamma_diff.numpy() - b.numpy()
+            max_diff = np.max(np.abs(diff))
+            assert np.allclose(tf_gamma_diff.numpy(), b.numpy(), rtol=1e-4, atol=1e-4), (
+                case,
+                max_diff,
+            )
+
+        def assert_grad_beta(b):
+            diff = tf_beta_diff.numpy() - b.numpy()
+            max_diff = np.max(np.abs(diff))
+            assert np.allclose(tf_beta_diff.numpy(), b.numpy(), rtol=1e-5, atol=1e-5), (
+                case,
+                max_diff,
+            )
+
         # 1F results
-        dtype = type_name_to_flow_type[data_type]
+        if data_type == "float16":
+            dtype = flow.float
+        else:
+            dtype = type_name_to_flow_type[data_type]

         func_config = flow.FunctionConfig()
         func_config.default_data_type(flow.float)
```
...
...
@@ -98,14 +160,60 @@ def test_layer_norm(_):
```diff
             )
             flow.watch_diff(v, assert_grad)
             x += v
+            if data_type == "float16":
+                x = flow.cast(x, dtype=flow.float16)
             with flow.scope.placement(device_type, "0:0"):
-                y = flow.layers.layer_norm(
+                param_shape = x.shape[begin_params_axis:]
+                gamma = None
+                beta = None
+                if center:
+                    with flow.scope.namespace("LayerNorm"):
+                        beta = flow.get_variable(
+                            name="beta",
+                            shape=param_shape,
+                            dtype=flow.float,
+                            initializer=flow.constant_initializer(0.0),
+                            trainable=trainable,
+                            model_name="beta",
+                            reuse=False,
+                        )
+                    if trainable:
+                        flow.watch_diff(beta, assert_grad_beta)
+                    if data_type == "float16":
+                        beta = flow.cast(beta, dtype=flow.float16)
+                if scale:
+                    with flow.scope.namespace("LayerNorm"):
+                        gamma = flow.get_variable(
+                            name="gamma",
+                            shape=param_shape,
+                            dtype=flow.float,
+                            initializer=flow.constant_initializer(1.0),
+                            trainable=trainable,
+                            model_name="gamma",
+                            reuse=False,
+                        )
+                    if trainable:
+                        if data_type == "float16":
+                            flow.watch_diff(
+                                gamma, test_global_storage.Setter("gamma_diff")
+                            )
+                        else:
+                            flow.watch_diff(gamma, assert_grad_gamma)
+                    if data_type == "float16":
+                        gamma = flow.cast(gamma, dtype=flow.float16)
+                y = flow.nn.layer_norm(
                     x,
+                    gamma=gamma,
+                    beta=beta,
                     begin_norm_axis=begin_norm_axis,
                     begin_params_axis=begin_params_axis,
-                    center=center,
-                    scale=scale,
                     epsilon=epsilon,
                 )
+                if data_type == "float16":
+                    y = flow.cast(y, dtype=flow.float)
             flow.optimizer.SGD(
                 flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
             ).minimize(y)
```
...
...
@@ -114,6 +222,7 @@ def test_layer_norm(_):
```python
        check_point = flow.train.CheckPoint()
        check_point.init()
        y = test_job(x).get()
        assert y.numpy().shape == y_tf.numpy().shape, (
            y.numpy().shape,
            y_tf.numpy().shape,
```
...
...
@@ -124,3 +233,23 @@ def test_layer_norm(_):
```python
            case,
            max_diff,
        )

        if data_type == "float16" and trainable and scale:
            np_dy = np.ones(x.shape).astype(np.float32)
            np_gamma_diff = np.sum(
                np_dy * y.numpy().astype(np.float32), axis=0
            ).astype(np.float16)
            max_diff = np.max(
                np.abs(
                    np_gamma_diff
                    - test_global_storage.Get("gamma_diff").astype(np.float16)
                )
            )
            assert np.allclose(
                np_gamma_diff,
                test_global_storage.Get("gamma_diff").astype(np.float16),
                rtol=1e-2,
                atol=1e-2,
            ), (
                case,
                max_diff,
            )
```
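The float16 branch above only cross-checks the gamma gradient, and it relies on the test's own assumption of an all-ones upstream gradient (`np_dy`). As a rough, hand-rolled NumPy reference for the per-parameter reductions being compared (illustrative only, not code from the repository):

```python
import numpy as np


def layer_norm_param_grads(dy, x_hat, gamma):
    # Illustrative reference only: dy is the upstream gradient, x_hat the
    # normalized activations, both flattened to (instances, features).
    gamma_diff = np.sum(dy * x_hat, axis=0)  # same reduction as np_gamma_diff above
    beta_diff = np.sum(dy, axis=0)
    normalized_diff = dy * gamma             # gradient flowing back into x_hat
    return gamma_diff, beta_diff, normalized_diff


dy = np.ones((40, 50), dtype=np.float32)     # the test assumes dy == 1 everywhere
x_hat = np.random.randn(40, 50).astype(np.float32)
gamma = np.ones(50, dtype=np.float32)
g, b, nd = layer_norm_param_grads(dy, x_hat, gamma)
# With dy == 1: g == x_hat.sum(axis=0), b == np.full(50, 40.0), nd == np.ones((40, 50))
```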
oneflow/user/kernels/layer_norm_gpu_kernel.cu
...
...
@@ -17,6 +17,7 @@ limitations under the License.
```cpp
#include "oneflow/core/device/cudnn_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/kernel/kernel_util.cuh"

namespace oneflow {
```
...
...
@@ -142,6 +143,86 @@ void InstanceScaleCenter<float16>(DeviceCtx* ctx, const int64_t batch_size,
```cpp
}
}

constexpr int64_t kLayerNormGpuBlockSize = 512;

int64_t GetLayerNormBlockSize() { return kLayerNormGpuBlockSize; }

int64_t GetLayerNormNumBlocks(const int64_t elem_cnt) {
  return std::min(
      static_cast<int>((elem_cnt + kLayerNormGpuBlockSize - 1) / kLayerNormGpuBlockSize), 256);
}

template<typename T>
int64_t GetDynamicSharedMemorySize(const int64_t instance_size) {
  return 2 * instance_size * sizeof(T);
}

template<>
int64_t GetDynamicSharedMemorySize<float16>(const int64_t instance_size) {
  return 2 * instance_size * sizeof(float);
}

template<typename T, typename I>
__global__ void LayerNormParamGradImpl(const I n, const I instance_size, const T* dy,
                                       const T* normalized, const T* gamma, T* gamma_diff,
                                       T* beta_diff, T* normalized_diff) {
  extern __shared__ __align__(sizeof(T)) unsigned char bw_shared_buf[];
  auto* gamma_diff_sum_buf = reinterpret_cast<T*>(bw_shared_buf);
  auto* beta_diff_sum_buf = gamma_diff_sum_buf + instance_size;
  const I tid = threadIdx.x;
  for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) {
    gamma_diff_sum_buf[elem_id] = 0;
    beta_diff_sum_buf[elem_id] = 0;
  }
  __syncthreads();
  CUDA_1D_KERNEL_LOOP_T(I, i, n) {
    const I elem_id = i % instance_size;
    T dy_val = dy[i];
    T normalized_val = normalized[i];
    gpu_atomic_add(&gamma_diff_sum_buf[elem_id], dy_val * normalized_val);
    gpu_atomic_add(&beta_diff_sum_buf[elem_id], dy_val);
    T gamma_val = gamma[elem_id];
    normalized_diff[i] = gamma_val * dy_val;
  }
  __syncthreads();
  for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) {
    gpu_atomic_add(gamma_diff + elem_id, gamma_diff_sum_buf[elem_id]);
    gpu_atomic_add(beta_diff + elem_id, beta_diff_sum_buf[elem_id]);
  }
}

template<typename I>
__global__ void LayerNormParamGradHalfImpl(const I n, const I instance_size, const half* dy,
                                           const half* normalized, const half* gamma,
                                           half* tmp_gamma_diff, half* tmp_beta_diff,
                                           half* normalized_diff) {
  extern __shared__ __align__(sizeof(float)) unsigned char bw_shared_buf[];
  auto* gamma_diff_sum_buf = reinterpret_cast<float*>(bw_shared_buf);
  auto* beta_diff_sum_buf = gamma_diff_sum_buf + instance_size;
  const I tid = threadIdx.x;
  for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) {
    gamma_diff_sum_buf[elem_id] = 0;
    beta_diff_sum_buf[elem_id] = 0;
  }
  __syncthreads();
  CUDA_1D_KERNEL_LOOP_T(I, i, n) {
    const I elem_id = i % instance_size;
    half dy_val = dy[i];
    half normalized_val = normalized[i];
    gpu_atomic_add(&gamma_diff_sum_buf[elem_id],
                   __half2float(dy_val) * __half2float(normalized_val));
    gpu_atomic_add(&beta_diff_sum_buf[elem_id], __half2float(dy_val));
    half gamma_val = gamma[elem_id];
    normalized_diff[i] = __hmul(gamma_val, dy_val);
  }
  __syncthreads();
  for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) {
    const I offset = blockIdx.x * instance_size + elem_id;
    tmp_gamma_diff[offset] = __float2half(gamma_diff_sum_buf[elem_id]);
    tmp_beta_diff[offset] = __float2half(beta_diff_sum_buf[elem_id]);
  }
}

}  // namespace

template<typename T, typename BNParamT>
```
...
...
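The helpers above pin the launch configuration: 512-thread blocks, at most 256 blocks, and a dynamic shared-memory carve-out holding two per-feature accumulator rows per block (float-sized rows even for the half kernel). A quick back-of-the-envelope of those formulas, in plain Python with illustrative sizes:

```python
BLOCK_SIZE = 512  # kLayerNormGpuBlockSize
MAX_BLOCKS = 256


def num_blocks(elem_cnt):
    # Mirrors GetLayerNormNumBlocks: ceil-divide, capped at 256 blocks.
    return min((elem_cnt + BLOCK_SIZE - 1) // BLOCK_SIZE, MAX_BLOCKS)


def dyn_shared_mem_bytes(instance_size, accum_bytes):
    # Mirrors GetDynamicSharedMemorySize: gamma_diff_sum_buf + beta_diff_sum_buf.
    return 2 * instance_size * accum_bytes


# Example: a (40, 50) tensor normalized over the last axis -> 2000 elements,
# instance_size 50. The float16 kernel accumulates in float (4 bytes).
print(num_blocks(40 * 50))          # 4
print(dyn_shared_mem_bytes(50, 4))  # 400 bytes of shared memory per block
```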
@@ -298,36 +379,67 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel {
```diff
     const bool has_gamma_diff = gamma_diff != nullptr;
     const bool has_normalized_diff = normalized_diff != nullptr;
     const bool has_gamma = gamma != nullptr;
-    if (has_beta_diff) {
-      user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
-      const int64_t m = beta_diff->shape().elem_cnt();
-      CHECK_EQ(dy->shape().elem_cnt() % m, 0);
-      const int64_t n = dy->shape().elem_cnt() / m;
-      NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr<T>()),
-                        Val({n, m}, dy->dptr<T>()), Var({n, m}, reduce_buf->mut_dptr<T>()));
-    }
-    if (has_gamma_diff) {
-      const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0);
-      user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
-      const int64_t m = gamma_diff->shape().elem_cnt();
-      CHECK_EQ(dy->shape().elem_cnt() % m, 0);
-      const int64_t n = dy->shape().elem_cnt() / m;
-      NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr<T>()),
-                           Val({n, m}, normalized->dptr<T>()), Val({n, m}, dy->dptr<T>()));
-      NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr<T>()),
-                        Val({n, m}, reduce_buf->dptr<T>()), Var({n, m}, reduce_buf->mut_dptr<T>()));
-    }
-    if (has_normalized_diff) {
-      if (has_gamma) {
-        const int64_t m = gamma->shape().elem_cnt();
-        CHECK_EQ(dy->shape().elem_cnt() % m, 0);
-        const int64_t n = dy->shape().elem_cnt() / m;
-        NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, normalized_diff->mut_dptr<T>()),
-                             Val({n, m}, dy->dptr<T>()), Val({1, m}, gamma->dptr<T>()));
-      } else {
-        Memcpy<DeviceType::kGPU>(ctx->device_ctx(), normalized_diff->mut_dptr<void>(),
-                                 dy->dptr<void>(),
-                                 dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type()));
-      }
-    }
+    const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+    const int64_t elem_cnt = dy->shape().elem_cnt();
+    const int64_t m = dy->shape().Count(begin_params_axis);
+    int max_active_blocks;
+    OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks, LayerNormParamGradImpl<T, int64_t>, GetLayerNormBlockSize(),
+        GetDynamicSharedMemorySize<T>(m)));
+    if (has_gamma_diff && has_beta_diff && has_normalized_diff && max_active_blocks > 0) {
+      const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0);
+      Memset<DeviceType::kGPU>(ctx->device_ctx(), gamma_diff->mut_dptr<T>(), 0,
+                               gamma_diff->shape().elem_cnt() * sizeof(T));
+      Memset<DeviceType::kGPU>(ctx->device_ctx(), beta_diff->mut_dptr<T>(), 0,
+                               beta_diff->shape().elem_cnt() * sizeof(T));
+      if (elem_cnt > static_cast<int64_t>(GetMaxVal<int32_t>() / 2)) {
+        LayerNormParamGradImpl<T, int64_t>
+            <<<GetLayerNormNumBlocks(elem_cnt), GetLayerNormBlockSize(),
+               GetDynamicSharedMemorySize<T>(m), ctx->device_ctx()->cuda_stream()>>>(
+                elem_cnt, m, dy->dptr<T>(), normalized->dptr<T>(), gamma->dptr<T>(),
+                gamma_diff->mut_dptr<T>(), beta_diff->mut_dptr<T>(),
+                normalized_diff->mut_dptr<T>());
+      } else {
+        LayerNormParamGradImpl<T, int32_t>
+            <<<GetLayerNormNumBlocks(elem_cnt), GetLayerNormBlockSize(),
+               GetDynamicSharedMemorySize<T>(m), ctx->device_ctx()->cuda_stream()>>>(
+                static_cast<int32_t>(elem_cnt), static_cast<int32_t>(m), dy->dptr<T>(),
+                normalized->dptr<T>(), gamma->dptr<T>(), gamma_diff->mut_dptr<T>(),
+                beta_diff->mut_dptr<T>(), normalized_diff->mut_dptr<T>());
+      }
+    } else {
+      if (has_beta_diff) {
+        user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
+        CHECK_EQ(m, beta_diff->shape().elem_cnt());
+        CHECK_EQ(dy->shape().elem_cnt() % m, 0);
+        const int64_t n = dy->shape().elem_cnt() / m;
+        NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr<T>()),
+                          Val({n, m}, dy->dptr<T>()), Var({n, m}, reduce_buf->mut_dptr<T>()));
+      }
+      if (has_gamma_diff) {
+        const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0);
+        user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
+        CHECK_EQ(m, gamma_diff->shape().elem_cnt());
+        CHECK_EQ(dy->shape().elem_cnt() % m, 0);
+        const int64_t n = dy->shape().elem_cnt() / m;
+        NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr<T>()),
+                             Val({n, m}, normalized->dptr<T>()), Val({n, m}, dy->dptr<T>()));
+        NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr<T>()),
+                          Val({n, m}, reduce_buf->dptr<T>()),
+                          Var({n, m}, reduce_buf->mut_dptr<T>()));
+      }
+      if (has_normalized_diff) {
+        if (has_gamma) {
+          CHECK_EQ(m, gamma->shape().elem_cnt());
+          CHECK_EQ(dy->shape().elem_cnt() % m, 0);
+          const int64_t n = dy->shape().elem_cnt() / m;
+          NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, normalized_diff->mut_dptr<T>()),
+                               Val({n, m}, dy->dptr<T>()), Val({1, m}, gamma->dptr<T>()));
+        } else {
+          Memcpy<DeviceType::kGPU>(ctx->device_ctx(), normalized_diff->mut_dptr<void>(),
+                                   dy->dptr<void>(),
+                                   dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type()));
+        }
+      }
+    }
   }
 };
```
...
...
@@ -341,6 +453,127 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel {
```cpp
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float16)

class LayerNormParamGradGpuHalfKernel final : public user_op::OpKernel {
 public:
  LayerNormParamGradGpuHalfKernel() = default;
  ~LayerNormParamGradGpuHalfKernel() = default;

 private:
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    using NdUtil = NdarrayUtil<DeviceType::kGPU, float16>;
    auto Val = NdUtil::GetValNdarrayBuilder();
    auto Var = NdUtil::GetVarNdarrayBuilder();
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
    user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
    user_op::Tensor* normalized_diff = ctx->Tensor4ArgNameAndIndex("normalized_diff", 0);
    user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
    const bool has_beta_diff = beta_diff != nullptr;
    const bool has_gamma_diff = gamma_diff != nullptr;
    const bool has_normalized_diff = normalized_diff != nullptr;
    const bool has_gamma = gamma != nullptr;
    const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
    const int64_t elem_cnt = dy->shape().elem_cnt();
    const int64_t m = dy->shape().Count(begin_params_axis);
    int max_active_blocks;
    OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, LayerNormParamGradHalfImpl<int64_t>, GetLayerNormBlockSize(),
        GetDynamicSharedMemorySize<float16>(m)));
    if (has_gamma_diff && has_beta_diff && has_normalized_diff && max_active_blocks > 0) {
      const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0);
      user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
      const int64_t num_blocks = GetLayerNormNumBlocks(dy->shape().elem_cnt());
      const size_t tmp_diff_size = GetCudaAlignedSize(num_blocks * m * sizeof(float16));
      float16* tmp_gamma_diff = tmp_buffer->mut_dptr<float16>();
      float16* tmp_beta_diff =
          reinterpret_cast<float16*>(tmp_buffer->mut_dptr<char>() + tmp_diff_size);
      float16* tmp_reduce_buf =
          reinterpret_cast<float16*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_diff_size);
      CHECK_GE(tmp_buffer->shape().elem_cnt(), 3 * tmp_diff_size);
      if (elem_cnt > static_cast<int64_t>(GetMaxVal<int32_t>() / 2)) {
        LayerNormParamGradHalfImpl<int64_t>
            <<<GetLayerNormNumBlocks(elem_cnt), GetLayerNormBlockSize(),
               GetDynamicSharedMemorySize<float16>(m), ctx->device_ctx()->cuda_stream()>>>(
                elem_cnt, m, dy->dptr<half>(), normalized->dptr<half>(), gamma->dptr<half>(),
                reinterpret_cast<half*>(tmp_gamma_diff), reinterpret_cast<half*>(tmp_beta_diff),
                normalized_diff->mut_dptr<half>());
      } else {
        LayerNormParamGradHalfImpl<int32_t>
            <<<GetLayerNormNumBlocks(elem_cnt), GetLayerNormBlockSize(),
               GetDynamicSharedMemorySize<float16>(m), ctx->device_ctx()->cuda_stream()>>>(
                static_cast<int32_t>(elem_cnt), static_cast<int32_t>(m), dy->dptr<half>(),
                normalized->dptr<half>(), gamma->dptr<half>(),
                reinterpret_cast<half*>(tmp_gamma_diff), reinterpret_cast<half*>(tmp_beta_diff),
                normalized_diff->mut_dptr<half>());
      }
      NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr<float16>()),
                        Val({num_blocks, m}, tmp_gamma_diff),
                        Var({num_blocks, m}, tmp_reduce_buf));
      NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr<float16>()),
                        Val({num_blocks, m}, tmp_beta_diff),
                        Var({num_blocks, m}, tmp_reduce_buf));
    } else {
      if (has_beta_diff) {
        user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
        CHECK_EQ(m, beta_diff->shape().elem_cnt());
        CHECK_EQ(dy->shape().elem_cnt() % m, 0);
        const int64_t n = dy->shape().elem_cnt() / m;
        NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr<float16>()),
                          Val({n, m}, dy->dptr<float16>()),
                          Var({n, m}, reduce_buf->mut_dptr<float16>()));
      }
      if (has_gamma_diff) {
        const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0);
        user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0);
        CHECK_EQ(m, gamma_diff->shape().elem_cnt());
        CHECK_EQ(dy->shape().elem_cnt() % m, 0);
        const int64_t n = dy->shape().elem_cnt() / m;
        NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr<float16>()),
                             Val({n, m}, normalized->dptr<float16>()),
                             Val({n, m}, dy->dptr<float16>()));
        NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr<float16>()),
                          Val({n, m}, reduce_buf->dptr<float16>()),
                          Var({n, m}, reduce_buf->mut_dptr<float16>()));
      }
      if (has_normalized_diff) {
        if (has_gamma) {
          CHECK_EQ(m, gamma->shape().elem_cnt());
          CHECK_EQ(dy->shape().elem_cnt() % m, 0);
          const int64_t n = dy->shape().elem_cnt() / m;
          NdUtil::BroadcastMul(ctx->device_ctx(),
                               Var({n, m}, normalized_diff->mut_dptr<float16>()),
                               Val({n, m}, dy->dptr<float16>()),
                               Val({1, m}, gamma->dptr<float16>()));
        } else {
          Memcpy<DeviceType::kGPU>(ctx->device_ctx(), normalized_diff->mut_dptr<void>(),
                                   dy->dptr<void>(),
                                   dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type()));
        }
      }
    }
  }
};

REGISTER_USER_KERNEL("layer_norm_param_grad")
    .SetCreateFn<LayerNormParamGradGpuHalfKernel>()
    .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")
                     & (user_op::HobDataType("dy", 0) == DataType::kFloat16))
    .SetInferTmpSizeFn([](user_op::InferContext* ctx) {
      const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
      const bool has_gamma_diff = ctx->user_op_conf().has_output("gamma_diff", 0);
      const bool has_beta_diff = ctx->user_op_conf().has_output("beta_diff", 0);
      const bool has_normalized_diff = ctx->user_op_conf().has_output("normalized_diff", 0);
      const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0);
      const int64_t instance_size = dy->shape().Count(begin_params_axis);
      size_t tmp_buffer_size = 0;
      if (has_gamma_diff && has_beta_diff && has_normalized_diff) {
        const size_t tmp_gamma_diff = GetCudaAlignedSize(
            GetLayerNormNumBlocks(dy->shape().elem_cnt()) * instance_size * sizeof(float16));
        const size_t tmp_beta_diff = tmp_gamma_diff;
        const size_t tmp_reduce_buf = tmp_gamma_diff;
        tmp_buffer_size = tmp_gamma_diff + tmp_beta_diff + tmp_reduce_buf;
      } else {
        tmp_buffer_size = 0;
      }
      return tmp_buffer_size;
    });

}  // namespace oneflow
```
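The `SetInferTmpSizeFn` above carves the temporary buffer into three equally sized, CUDA-aligned regions: per-block partial gamma diffs, per-block partial beta diffs, and scratch space for the final reduction. A small sketch of that arithmetic (illustrative only; `GetCudaAlignedSize` is assumed here to round up to a 512-byte multiple):

```python
def cuda_aligned(num_bytes, align=512):
    # Assumption: GetCudaAlignedSize rounds up to a 512-byte boundary.
    return (num_bytes + align - 1) // align * align


def half_kernel_tmp_buffer_bytes(elem_cnt, instance_size):
    blocks = min((elem_cnt + 512 - 1) // 512, 256)             # GetLayerNormNumBlocks
    tmp_gamma_diff = cuda_aligned(blocks * instance_size * 2)  # sizeof(float16) == 2
    tmp_beta_diff = tmp_gamma_diff                             # same layout
    tmp_reduce_buf = tmp_gamma_diff
    return tmp_gamma_diff + tmp_beta_diff + tmp_reduce_buf


# Example: dy of shape (40, 50) in float16 with params over the last axis.
print(half_kernel_tmp_buffer_bytes(40 * 50, 50))   # 3 * 512 = 1536 bytes
```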