Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3fc0d192
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3fc0d192
编写于
3月 08, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
a8e02ef1
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
863 addition
and
89 deletion
+863
-89
paddle/fluid/operators/temporal_shift_op.h
paddle/fluid/operators/temporal_shift_op.h
+1
-89
paddle/phi/kernels/clip_by_norm_kernel.h
paddle/phi/kernels/clip_by_norm_kernel.h
+34
-0
paddle/phi/kernels/cpu/clip_by_norm_kernel.cc
paddle/phi/kernels/cpu/clip_by_norm_kernel.cc
+24
-0
paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc
paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc
+136
-0
paddle/phi/kernels/cpu/temporal_shift_kernel.cc
paddle/phi/kernels/cpu/temporal_shift_kernel.cc
+131
-0
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
+112
-0
paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu
paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu
+149
-0
paddle/phi/kernels/gpu/temporal_shift_kernel.cu
paddle/phi/kernels/gpu/temporal_shift_kernel.cu
+148
-0
paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h
paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h
+70
-0
paddle/phi/kernels/temporal_shift_grad_kernel.h
paddle/phi/kernels/temporal_shift_grad_kernel.h
+29
-0
paddle/phi/kernels/temporal_shift_kernel.h
paddle/phi/kernels/temporal_shift_kernel.h
+29
-0
未找到文件。
paddle/fluid/operators/temporal_shift_op.h
浏览文件 @
3fc0d192
...
...
@@ -19,56 +19,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
DataLayout
=
framework
::
DataLayout
;
template
<
typename
T
>
void
TemporalShiftFwNCHW
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
ntchw
;
i
++
)
{
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
chw
];
}
}
}
template
<
typename
T
>
void
TemporalShiftFwNHWC
(
const
T
*
input
,
T
*
output
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
nthwc
;
i
++
)
{
int
it
=
(
i
%
thwc
)
/
hwc
;
int
ic
=
i
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
hwc
];
}
}
}
template
<
typename
T
>
void
TemporalShiftBwNCHW
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
...
...
@@ -122,45 +72,7 @@ void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
template
<
typename
T
>
class
TemporalShiftKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
std
::
string
data_format_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
1
]
:
input
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
2
]
:
input
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
3
]
:
input
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
framework
::
DDim
out_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
phi
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
phi
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
TemporalShiftFwNCHW
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
TemporalShiftFwNHWC
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
};
template
<
typename
T
>
...
...
paddle/phi/kernels/clip_by_norm_kernel.h
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
ClipByNormKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
float
max_norm
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
ClipByNormSparseKernel
(
const
Context
&
ctx
,
const
SelectedRows
&
x
,
float
max_norm
,
SelectedRows
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/clip_by_norm_kernel.cc
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_by_norm_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h"
PD_REGISTER_KERNEL
(
clip_by_norm
,
CPU
,
ALL_LAYOUT
,
phi
::
ClipByNormKernel
,
float
)
{}
PD_REGISTER_KERNEL
(
clip_by_norm_sparse
,
CPU
,
ALL_LAYOUT
,
phi
::
ClipByNormSparseKernel
,
float
)
{}
paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
>
void
TemporalShiftBwNCHW
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
ntchw
;
i
++
)
{
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
i
]
=
output_grad
[
i
+
(
src_it
-
it
)
*
chw
];
}
else
{
input_grad
[
i
]
=
0
;
}
}
}
template
<
typename
T
>
void
TemporalShiftBwNHWC
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
nthwc
;
i
++
)
{
int
it
=
(
i
%
thwc
)
/
hwc
;
int
ic
=
i
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
i
]
=
output_grad
[
i
+
(
src_it
-
it
)
*
hwc
];
}
else
{
input_grad
[
i
]
=
0
;
}
}
}
template
<
typename
T
,
typename
Context
>
void
TemporalShiftGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out_grad
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format_str
,
DenseTensor
*
x_grad
)
{
auto
*
input_grad
=
x_grad
;
auto
*
output_grad
=
&
out_grad
;
int
t
=
seg_num
;
const
DataLayout
data_layout
=
paddle
::
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
1
]
:
output_grad
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
2
]
:
output_grad
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
3
]
:
output_grad
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
DDim
in_grad_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
phi
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
phi
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
in_grad_dims
,
dev_ctx
.
GetPlace
());
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
TemporalShiftBwNCHW
<
T
>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
TemporalShiftBwNHWC
<
T
>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
temporal_shift_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
TemporalShiftGradKernel
,
float
,
double
)
{}
paddle/phi/kernels/cpu/temporal_shift_kernel.cc
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
>
void
TemporalShiftFwNCHW
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
ntchw
;
i
++
)
{
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
chw
];
}
}
}
template
<
typename
T
>
void
TemporalShiftFwNHWC
(
const
T
*
input
,
T
*
output
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
nthwc
;
i
++
)
{
int
it
=
(
i
%
thwc
)
/
hwc
;
int
ic
=
i
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
hwc
];
}
}
}
template
<
typename
T
,
typename
Context
>
void
TemporalShiftKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format_str
,
DenseTensor
*
out
)
{
auto
*
input
=
&
x
;
auto
*
output
=
out
;
int
t
=
seg_num
;
const
DataLayout
data_layout
=
paddle
::
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
1
]
:
input
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
2
]
:
input
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
3
]
:
input
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
DDim
out_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
phi
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
phi
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
out_dims
,
dev_ctx
.
GetPlace
());
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
TemporalShiftFwNCHW
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
TemporalShiftFwNHWC
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
temporal_shift
,
CPU
,
ALL_LAYOUT
,
phi
::
TemporalShiftKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/clip_by_norm_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace
phi
{
template
<
>
void
ClipByNormKernel
<
phi
::
dtype
::
float16
,
phi
::
GPUContext
>
(
const
GPUContext
&
dev_ctx
,
const
DenseTensor
&
x_in
,
float
max_norm
,
DenseTensor
*
out_p
)
{
dev_ctx
.
template
Alloc
<
dtype
::
float16
>(
out_p
);
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
resize
(
x_in
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
reduce_dims
.
size
();
++
i
)
{
reduce_dims
[
i
]
=
i
;
}
DenseTensor
tmp
;
tmp
.
Resize
({
1
});
dev_ctx
.
template
Alloc
<
float
>(
&
tmp
);
kernels
::
TensorReduceImpl
<
dtype
::
float16
,
float
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>>
(
dev_ctx
,
x_in
,
&
tmp
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>
(),
reduce_dims
,
dev_ctx
.
stream
());
auto
tmp_eigen
=
EigenVector
<
float
>::
Flatten
(
tmp
);
auto
x_norm
=
tmp_eigen
.
sqrt
();
auto
x
=
EigenVector
<
dtype
::
float16
>::
Flatten
(
x_in
);
auto
out
=
EigenVector
<
dtype
::
float16
>::
Flatten
(
*
out_p
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
temp
=
(
x_norm
<=
max_norm
).
template
cast
<
float
>();
auto
epsilon
=
((
x_norm
<=
static_cast
<
float
>
(
1e-30
)).
all
().
template
cast
<
float
>())
*
static_cast
<
float
>
(
1e-6
);
auto
scaling
=
(
temp
+
(
static_cast
<
float
>
(
1
)
-
temp
)
*
max_norm
/
(
x_norm
+
epsilon
))
.
template
cast
<
dtype
::
float16
>();
Eigen
::
array
<
int
,
1
>
one_dim
{{
1
}};
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
x_in
.
numel
());
out
.
device
(
place
)
=
x
*
scaling
.
reshape
(
one_dim
).
broadcast
(
m_dsize
);
}
template
<
>
void
ClipByNormSparseKernel
<
phi
::
dtype
::
float16
,
phi
::
GPUContext
>
(
const
phi
::
GPUContext
&
ctx
,
const
SelectedRows
&
x
,
float
max_norm
,
SelectedRows
*
out
)
{
// merge ids in selected rows first
paddle
::
operators
::
math
::
scatter
::
MergeAdd
<
GPUContext
,
dtype
::
float16
>
merge_func
;
phi
::
SelectedRows
merged_input
;
merge_func
(
ctx
,
x
,
&
merged_input
);
auto
input
=
merged_input
.
value
();
phi
::
SelectedRows
*
output_selected_rows
=
out
;
output_selected_rows
->
set_rows
(
merged_input
.
rows
());
output_selected_rows
->
set_height
(
merged_input
.
height
());
auto
output
=
output_selected_rows
->
mutable_value
();
output
->
Resize
(
merged_input
.
value
().
dims
());
output
->
mutable_data
<
dtype
::
float16
>
(
ctx
.
GetPlace
());
ClipByNormKernel
<
dtype
::
float16
>
(
ctx
,
input
,
max_norm
,
output
);
}
}
// namespace phi
// PD_REGISTER_KERNEL(
// clip_by_norm, GPU, ALL_LAYOUT, phi::ClipByNormKernel, float,
// phi::dtype::float16) {}
// PD_REGISTER_KERNEL(
// clip_by_norm_sparse, GPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float,
// phi::dtype::float16) {}
PD_REGISTER_KERNEL
(
clip_by_norm
,
GPU
,
ALL_LAYOUT
,
phi
::
ClipByNormKernel
,
float
)
{}
PD_REGISTER_KERNEL
(
clip_by_norm_sparse
,
GPU
,
ALL_LAYOUT
,
phi
::
ClipByNormSparseKernel
,
float
)
{}
paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/temporal_shift_grad_kernel.h"
namespace
phi
{
template
<
typename
T
>
__global__
void
KeTemporalShiftBwNCHW
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
tid
]
=
output_grad
[
tid
+
(
src_it
-
it
)
*
chw
];
}
else
{
input_grad
[
tid
]
=
0
;
}
}
}
template
<
typename
T
>
__global__
void
KeTemporalShiftBwNHWC
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
nthwc
;
tid
+=
stride
)
{
int
it
=
(
tid
%
thwc
)
/
hwc
;
int
ic
=
tid
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
tid
]
=
output_grad
[
tid
+
(
src_it
-
it
)
*
hwc
];
}
else
{
input_grad
[
tid
]
=
0
;
}
}
}
template
<
typename
T
,
typename
Context
>
void
TemporalShiftGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out_grad
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format_str
,
DenseTensor
*
x_grad
)
{
auto
*
input_grad
=
x_grad
;
auto
*
output_grad
=
&
out_grad
;
int
t
=
seg_num
;
const
DataLayout
data_layout
=
paddle
::
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
1
]
:
output_grad
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
2
]
:
output_grad
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
3
]
:
output_grad
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
DDim
in_grad_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
phi
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
phi
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
in_grad_dims
,
dev_ctx
.
GetPlace
());
int
pixelNum
=
nt
*
chw
;
int
threads
=
1024
;
int
grid
=
(
pixelNum
+
threads
-
1
)
/
threads
;
int
blocks_per_sm
=
dev_ctx
.
GetMaxPhysicalThreadCount
()
/
threads
;
grid
=
std
::
min
(
dev_ctx
.
GetSMCount
()
*
blocks_per_sm
,
grid
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
KeTemporalShiftBwNCHW
<
T
><<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
KeTemporalShiftBwNHWC
<
T
><<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
temporal_shift_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
TemporalShiftGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/temporal_shift_kernel.cu
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/temporal_shift_kernel.h"
namespace
phi
{
template
<
typename
T
>
__global__
void
KeTemporalShiftFwNCHW
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
tid
]
=
0
;
}
else
{
output
[
tid
]
=
input
[
tid
+
(
src_it
-
it
)
*
chw
];
}
}
}
template
<
typename
T
>
__global__
void
KeTemporalShiftFwNHWC
(
const
T
*
input
,
T
*
output
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
nthwc
;
tid
+=
stride
)
{
int
it
=
(
tid
%
thwc
)
/
hwc
;
int
ic
=
tid
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
tid
]
=
0
;
}
else
{
output
[
tid
]
=
input
[
tid
+
(
src_it
-
it
)
*
hwc
];
}
}
}
template
<
typename
T
,
typename
Context
>
void
TemporalShiftKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format_str
,
DenseTensor
*
out
)
{
auto
*
input
=
&
x
;
auto
*
output
=
out
;
int
t
=
seg_num
;
const
DataLayout
data_layout
=
paddle
::
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
1
]
:
input
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
2
]
:
input
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
3
]
:
input
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
DDim
out_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
phi
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
phi
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
out_dims
,
dev_ctx
.
GetPlace
());
int
pixelNum
=
nt
*
chw
;
int
threads
=
1024
;
int
grid
=
(
pixelNum
+
threads
-
1
)
/
threads
;
int
blocks_per_sm
=
dev_ctx
.
GetMaxPhysicalThreadCount
()
/
threads
;
grid
=
std
::
min
(
dev_ctx
.
GetSMCount
()
*
blocks_per_sm
,
grid
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
KeTemporalShiftFwNCHW
<
T
><<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
KeTemporalShiftFwNHWC
<
T
><<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
temporal_shift
,
GPU
,
ALL_LAYOUT
,
phi
::
TemporalShiftKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
ClipByNormKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x_in
,
float
max_norm
,
DenseTensor
*
out_p
)
{
ctx
.
template
Alloc
<
T
>(
out_p
);
auto
x
=
EigenVector
<
T
>::
Flatten
(
x_in
);
auto
out
=
EigenVector
<
T
>::
Flatten
(
*
out_p
);
auto
x_norm
=
x
.
square
().
sum
().
sqrt
();
auto
&
place
=
*
ctx
.
eigen_device
();
auto
temp
=
(
x_norm
<=
max_norm
).
template
cast
<
T
>();
auto
epsilon
=
((
x_norm
<=
static_cast
<
T
>
(
1e-30
)).
all
().
template
cast
<
T
>())
*
static_cast
<
T
>
(
1e-6
);
auto
scaling
=
temp
+
(
static_cast
<
T
>
(
1
)
-
temp
)
*
max_norm
/
(
x_norm
+
epsilon
);
Eigen
::
array
<
int
,
1
>
one_dim
{{
1
}};
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
x_in
.
numel
());
if
(
ctx
.
GetPlace
()
==
phi
::
CPUPlace
())
{
out
.
device
(
place
)
=
x
*
scaling
.
reshape
(
one_dim
).
eval
().
broadcast
(
m_dsize
);
}
else
{
out
.
device
(
place
)
=
x
*
scaling
.
reshape
(
one_dim
).
broadcast
(
m_dsize
);
}
}
template
<
typename
T
,
typename
Context
>
void
ClipByNormSparseKernel
(
const
Context
&
ctx
,
const
SelectedRows
&
x
,
float
max_norm
,
SelectedRows
*
out
)
{
// merge ids in selected rows first
paddle
::
operators
::
math
::
scatter
::
MergeAdd
<
Context
,
T
>
merge_func
;
phi
::
SelectedRows
merged_input
;
merge_func
(
ctx
,
x
,
&
merged_input
);
auto
input
=
merged_input
.
value
();
phi
::
SelectedRows
*
output_selected_rows
=
out
;
output_selected_rows
->
set_rows
(
merged_input
.
rows
());
output_selected_rows
->
set_height
(
merged_input
.
height
());
auto
output
=
output_selected_rows
->
mutable_value
();
output
->
Resize
(
merged_input
.
value
().
dims
());
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipByNormKernel
<
T
>
(
ctx
,
input
,
max_norm
,
output
);
}
}
// namespace phi
paddle/phi/kernels/temporal_shift_grad_kernel.h
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
TemporalShiftGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
out_grad
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format
,
DenseTensor
*
x_grad
);
}
// namespace phi
paddle/phi/kernels/temporal_shift_kernel.h
0 → 100644
浏览文件 @
3fc0d192
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
TemporalShiftKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
int
seg_num
,
float
shift_ratio
,
const
std
::
string
&
data_format
,
DenseTensor
*
out
);
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录