Oneflow-Inc / oneflow
Commit e177ae88
Authored on Oct 05, 2020 by oneflow-bot; committed via GitHub on Oct 05, 2020

Merge branch 'master' into tfrecord_dataloader

Former-commit-id: b611e30be88931ebca4e901246f8d16a10719214

Parents: 06600c66, 5c0d9865
Showing 2 changed files with 95 additions and 24 deletions (+95 −24)
oneflow/user/kernels/normalization_kernel.cu  (+91 −22)
oneflow/user/ops/normalization_op.cpp         (+4 −2)
oneflow/user/kernels/normalization_kernel.cu
@@ -229,38 +229,106 @@ REGISTER_BN_INFERENCE_KERNEL(double)

 #undef REGISTER_BN_INFERENCE_KERNEL

+constexpr int64_t kCudaWarpSize = 32;
+
+template<typename T>
+__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {
+  const int32_t lane_id = threadIdx.x % kCudaWarpSize;
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    const T x_val = x[i];
+    const bool is_positive = (x_val > 0);
+    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    y[i] = is_positive ? x_val : 0;
+  }
+}
+
+template<>
+__global__ void ReluGpu<half>(int64_t n, const half* x, half* y, int32_t* mask) {
+  const int32_t lane_id = threadIdx.x % kCudaWarpSize;
+  const half zero = __float2half(0.0f);
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    const half x_val = x[i];
+    const bool is_positive = __hgt(x_val, zero);
+    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    y[i] = is_positive ? x_val : zero;
+  }
+}
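For context, here is a minimal standalone sketch (an editor's addition, not part of this commit) of the warp-ballot pattern these kernels use: each lane votes on whether its element is positive, `__ballot_sync` packs the warp's 32 votes into one 32-bit word, and lane 0 stores that word, so a backward pass can reconstruct the ReLU predicate from one bit per element instead of re-reading `x` or `y`. It assumes `n` is a multiple of 32 and the block size is a multiple of the warp size:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__global__ void ReluWithMask(int n, const float* x, float* y, int* mask) {
  // Grid-stride loop: consecutive threads handle consecutive elements.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    const int lane_id = threadIdx.x % kWarpSize;
    const bool is_positive = x[i] > 0.0f;
    // All 32 lanes of the warp contribute one bit each; lane 0 writes the word.
    const int warp_mask = static_cast<int>(__ballot_sync(__activemask(), is_positive));
    if (lane_id == 0) { mask[i / kWarpSize] = warp_mask; }
    y[i] = is_positive ? x[i] : 0.0f;
  }
}

int main() {
  constexpr int n = 64;
  float hx[n], hy[n];
  int hmask[n / kWarpSize];
  for (int i = 0; i < n; ++i) { hx[i] = (i % 2 == 0) ? 1.0f : -1.0f; }
  float *dx, *dy;
  int* dmask;
  cudaMalloc(&dx, sizeof(hx));
  cudaMalloc(&dy, sizeof(hy));
  cudaMalloc(&dmask, sizeof(hmask));
  cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
  ReluWithMask<<<1, n>>>(n, dx, dy, dmask);
  cudaMemcpy(hy, dy, sizeof(hy), cudaMemcpyDeviceToHost);
  cudaMemcpy(hmask, dmask, sizeof(hmask), cudaMemcpyDeviceToHost);
  // Even-indexed inputs are positive, so every other bit is set.
  printf("mask[0] = 0x%08x, mask[1] = 0x%08x\n", static_cast<unsigned>(hmask[0]),
         static_cast<unsigned>(hmask[1]));
  cudaFree(dx);
  cudaFree(dy);
  cudaFree(dmask);
  return 0;
}
```

Compiled with `nvcc`, this should print `0x55555555` for both words, since the even lanes vote positive.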
 template<typename T>
-__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y) {
+__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
+  const int32_t lane_id = threadIdx.x % kCudaWarpSize;
   CUDA_1D_KERNEL_LOOP(i, n) {
-    T sum = x[i] + addend[i];
-    y[i] = sum > 0 ? sum : 0;
+    const T sum = x[i] + addend[i];
+    const bool is_positive = (sum > 0);
+    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    y[i] = is_positive ? sum : 0;
   }
 }

 template<>
-__global__ void AddReluGpu<half>(int64_t n, const half* x, const half* addend, half* y) {
+__global__ void AddReluGpu<half>(int64_t n, const half* x, const half* addend, half* y,
+                                 int32_t* mask) {
+  const int32_t lane_id = threadIdx.x % kCudaWarpSize;
   const half zero = __float2half(0.0f);
   CUDA_1D_KERNEL_LOOP(i, n) {
     const half sum = __hadd(x[i], addend[i]);
-    if (__hgt(sum, zero)) {
-      y[i] = sum;
-    } else {
-      y[i] = zero;
-    }
+    const bool is_positive = __hgt(sum, zero);
+    int32_t warp_mask = __ballot_sync(__activemask(), static_cast<int>(is_positive));
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    y[i] = is_positive ? sum : zero;
   }
 }
 template<typename T>
-void AddRelu(DeviceCtx* device_ctx, int64_t n, const T* x, const T* addend, T* y) {
+void Relu(DeviceCtx* device_ctx, int64_t n, const T* x, T* y, int32_t* mask) {
+  ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, device_ctx->cuda_stream()>>>(
+      n, x, y, mask);
+}
+
+template<>
+void Relu<float16>(DeviceCtx* device_ctx, int64_t n, const float16* x, float16* y, int32_t* mask) {
+  Relu<half>(device_ctx, n, reinterpret_cast<const half*>(x), reinterpret_cast<half*>(y), mask);
+}
+
+template<typename T>
+void AddRelu(DeviceCtx* device_ctx, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
   AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, device_ctx->cuda_stream()>>>(
-      n, x, addend, y);
+      n, x, addend, y, mask);
 }

 template<>
 void AddRelu<float16>(DeviceCtx* device_ctx, int64_t n, const float16* x, const float16* addend,
-                      float16* y) {
+                      float16* y, int32_t* mask) {
   AddRelu<half>(device_ctx, n, reinterpret_cast<const half*>(x),
-                reinterpret_cast<const half*>(addend), reinterpret_cast<half*>(y));
+                reinterpret_cast<const half*>(addend), reinterpret_cast<half*>(y), mask);
 }
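These host wrappers launch one thread per element up to a capped grid, relying on the grid-stride `CUDA_1D_KERNEL_LOOP` to cover the remainder. The launch-size helpers come from OneFlow's CUDA utilities; a sketch of plausible definitions follows (the exact constants are assumptions matching the names, not taken from this diff):

```cpp
#include <algorithm>
#include <cstdint>

// Assumed to mirror OneFlow's cuda_util helpers: 512-thread blocks, grid
// capped at 8192 blocks, with the grid-stride loop in CUDA_1D_KERNEL_LOOP
// picking up whatever the capped grid cannot cover in one pass.
constexpr int32_t kCudaThreadsNumPerBlock = 512;
constexpr int32_t kCudaMaxBlocksNum = 8192;

inline int32_t BlocksNum4ThreadsNum(const int32_t n) {
  // Ceiling division, then cap the block count.
  return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,
                  kCudaMaxBlocksNum);
}
```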
+template<typename T>
+__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
+  int32_t lane_id = threadIdx.x % kCudaWarpSize;
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    int32_t mask_val = mask[i / kCudaWarpSize];
+    bool is_positive = mask_val & (1 << lane_id);
+    addend_diff[i] = static_cast<T>(is_positive) * dy[i];
+  }
+}
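Note that the backward kernel never touches `x` or `y`: it rebuilds the forward ReLU predicate from the mask alone. This works because `CUDA_1D_KERNEL_LOOP` hands consecutive elements to consecutive threads and the block size is a multiple of the warp size, so for element `i` the voting lane satisfies `lane_id == i % 32`, and its vote sits at bit `i % 32` of word `i / 32`. A hypothetical host-side decoder (not in the commit) spells out the layout:

```cpp
#include <cstdint>

// Hypothetical decoder for the bitmask written by ReluGpu/AddReluGpu:
// element i's "was positive" vote is bit (i % 32) of the int32_t word at
// index (i / 32) -- exactly the bit ReluBackwardGpu tests with
// mask_val & (1 << lane_id).
inline bool WasPositive(const int32_t* mask, int64_t i) {
  return (mask[i / 32] >> (i % 32)) & 1;
}
```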
+template<typename T>
+void ReluBackward(DeviceCtx* device_ctx, int64_t n, const int32_t* mask, const T* dy,
+                  T* addend_diff) {
+  ReluBackwardGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, device_ctx->cuda_stream()>>>(
+          n, mask, dy, addend_diff);
+}
+
+template<>
+void ReluBackward<float16>(DeviceCtx* device_ctx, int64_t n, const int32_t* mask,
+                           const float16* dy, float16* addend_diff) {
+  ReluBackward<half>(device_ctx, n, mask, reinterpret_cast<const half*>(dy),
+                     reinterpret_cast<half*>(addend_diff));
+}
+
 template<typename T>
@@ -354,12 +422,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel {
     if (ctx->user_op_conf().op_type_name() == "normalization_add_relu") {
       CHECK(!ctx->user_op_conf().has_input("_add_to_output", 0));
       const int64_t elem_cnt = x->shape().elem_cnt();
+      auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
       if (ctx->user_op_conf().has_input("addend", 0)) {
         const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0);
-        AddRelu(ctx->device_ctx(), elem_cnt, y->dptr<T>(), addend->dptr<T>(), y->mut_dptr<T>());
+        AddRelu(ctx->device_ctx(), elem_cnt, y->dptr<T>(), addend->dptr<T>(), y->mut_dptr<T>(),
+                mask->mut_dptr<int32_t>());
       } else {
-        NewKernelUtil<DeviceType::kGPU>::Relu(ctx->device_ctx(), elem_cnt, y->dptr<T>(),
-                                              y->mut_dptr<T>());
+        Relu(ctx->device_ctx(), elem_cnt, y->dptr<T>(), y->mut_dptr<T>(),
+             mask->mut_dptr<int32_t>());
       }
     }
   }
@@ -443,12 +513,12 @@ class NormalizationGradUserKernel final : public user_op::OpKernel {
       bn_dy_ptr = dy->dptr();
     } else if (ctx->user_op_conf().op_type_name() == "normalization_add_relu_grad") {
       const int64_t elem_cnt = dy->shape().elem_cnt();
+      const auto* mask = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
       user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
       if (ctx->user_op_conf().has_output("addend_diff", 0)) {
         user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0);
-        NewKernelUtil<DeviceType::kGPU>::ReluBackward(ctx->device_ctx(), elem_cnt, nullptr,
-                                                      y->dptr<T>(), dy->dptr<T>(),
-                                                      addend_diff->mut_dptr<T>());
+        ReluBackward(ctx->device_ctx(), elem_cnt, mask->dptr<int32_t>(), dy->dptr<T>(),
+                     addend_diff->mut_dptr<T>());
         bn_workspace_ptr = tmp_buffer->mut_dptr();
         bn_workspace_size = tmp_buffer->shape().elem_cnt();
         bn_dy_ptr = addend_diff->dptr();
@@ -457,9 +527,8 @@ class NormalizationGradUserKernel final : public user_op::OpKernel {
         const size_t relu_dx_size =
             GetCudaAlignedSize(dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type()));
         CHECK_GE(tmp_buffer_size, relu_dx_size);
-        NewKernelUtil<DeviceType::kGPU>::ReluBackward(ctx->device_ctx(), elem_cnt, nullptr,
-                                                      y->dptr<T>(), dy->dptr<T>(),
-                                                      reinterpret_cast<T*>(tmp_buffer->mut_dptr()));
+        ReluBackward(ctx->device_ctx(), elem_cnt, mask->dptr<int32_t>(), dy->dptr<T>(),
+                     reinterpret_cast<T*>(tmp_buffer->mut_dptr()));
         bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;
         bn_workspace_size = tmp_buffer_size - relu_dx_size;
         bn_dy_ptr = tmp_buffer->dptr();
oneflow/user/ops/normalization_op.cpp
@@ -191,8 +191,10 @@ REGISTER_USER_OP("normalization_add_relu")
     .SetTensorDescInferFn(MakeFwTensorDescInferFn(
         [](user_op::InferContext* ctx, const user_op::TensorDesc* x,
           user_op::TensorDesc* reserve_space) -> Maybe<void> {
-          *reserve_space->mut_data_type() = DataType::kChar;
-          *reserve_space->mut_shape() = Shape({1});
+          const auto* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+          *reserve_space->mut_data_type() = DataType::kInt32;
+          *reserve_space->mut_shape() =
+              Shape({static_cast<int64_t>(RoundUp(x_desc->shape().elem_cnt(), 32) / 32)});
           return Maybe<void>::Ok();
         }))
     .SetBatchAxisInferFn(FwBatchAxisInferFn)
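With this change, `reserve_space` holds one bit per element of `x`, stored as `RoundUp(elem_cnt, 32) / 32` int32 words, replacing the old one-byte `DataType::kChar` placeholder of `Shape({1})`. A small sketch of the arithmetic, assuming `RoundUp(a, b)` rounds `a` up to the nearest multiple of `b`:

```cpp
#include <cstdint>

// Sketch of the new reserve_space sizing: RoundUp(elem_cnt, 32) / 32 is a
// ceiling division by 32, i.e. one mask bit per element packed into int32
// words.
int64_t MaskWordCount(int64_t elem_cnt) {
  return (elem_cnt + 31) / 32;  // == RoundUp(elem_cnt, 32) / 32
}
// e.g. elem_cnt = 1000 -> 32 words = 128 bytes of mask storage.
```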