Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e93e8a3f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e93e8a3f
编写于
4月 14, 2023
作者:
H
huangjiyi
提交者:
GitHub
4月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update (#52878)
上级
aac8da90
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
101 addition
and
1288 deletion
+101
-1288
paddle/fluid/operators/amp/get_float_status_op.cc
paddle/fluid/operators/amp/get_float_status_op.cc
+3
-2
paddle/fluid/operators/collective/global_gather_op.cc
paddle/fluid/operators/collective/global_gather_op.cc
+9
-6
paddle/fluid/operators/collective/global_gather_op.cu.cc
paddle/fluid/operators/collective/global_gather_op.cu.cc
+10
-7
paddle/fluid/operators/collective/global_gather_op.h
paddle/fluid/operators/collective/global_gather_op.h
+1
-1
paddle/fluid/operators/collective/global_scatter_op.cc
paddle/fluid/operators/collective/global_scatter_op.cc
+9
-6
paddle/fluid/operators/collective/global_scatter_op.cu.cc
paddle/fluid/operators/collective/global_scatter_op.cu.cc
+10
-7
paddle/fluid/operators/collective/global_scatter_op.h
paddle/fluid/operators/collective/global_scatter_op.h
+1
-1
paddle/fluid/operators/detection/generate_mask_labels_op.cc
paddle/fluid/operators/detection/generate_mask_labels_op.cc
+7
-3
paddle/fluid/operators/detection/generate_proposal_labels_op.cc
.../fluid/operators/detection/generate_proposal_labels_op.cc
+7
-4
paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
+8
-5
paddle/fluid/operators/gaussian_random_batch_size_like_op.cu
paddle/fluid/operators/gaussian_random_batch_size_like_op.cu
+10
-7
paddle/fluid/operators/graph_khop_sampler_op.cc
paddle/fluid/operators/graph_khop_sampler_op.cc
+7
-3
paddle/fluid/operators/graph_khop_sampler_op.cu
paddle/fluid/operators/graph_khop_sampler_op.cu
+7
-4
paddle/fluid/operators/graph_khop_sampler_op.h
paddle/fluid/operators/graph_khop_sampler_op.h
+1
-1
paddle/fluid/operators/group_norm_op.cc
paddle/fluid/operators/group_norm_op.cc
+0
-2
paddle/fluid/operators/group_norm_op.cu
paddle/fluid/operators/group_norm_op.cu
+0
-834
paddle/fluid/operators/group_norm_op.h
paddle/fluid/operators/group_norm_op.h
+0
-387
paddle/fluid/operators/l1_norm_op.cc
paddle/fluid/operators/l1_norm_op.cc
+9
-6
paddle/fluid/operators/l1_norm_op.h
paddle/fluid/operators/l1_norm_op.h
+2
-2
未找到文件。
paddle/fluid/operators/amp/get_float_status_op.cc
浏览文件 @
e93e8a3f
...
@@ -53,7 +53,7 @@ class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -53,7 +53,7 @@ class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GetFloatStatusKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GetFloatStatusKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -75,4 +75,5 @@ REGISTER_OPERATOR(
...
@@ -75,4 +75,5 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
get_float_status
,
ops
::
GetFloatStatusKernel
<
CPU
,
float
>
);
PD_REGISTER_STRUCT_KERNEL
(
get_float_status
,
CPU
,
ALL_LAYOUT
,
ops
::
GetFloatStatusKernel
,
float
)
{}
paddle/fluid/operators/collective/global_gather_op.cc
浏览文件 @
e93e8a3f
...
@@ -111,9 +111,12 @@ REGISTER_OPERATOR(global_gather,
...
@@ -111,9 +111,12 @@ REGISTER_OPERATOR(global_gather,
ops
::
GlobalGatherOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
GlobalGatherOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
GlobalGatherOpGradMaker
<
paddle
::
imperative
::
OpBase
>
)
ops
::
GlobalGatherOpGradMaker
<
paddle
::
imperative
::
OpBase
>
)
REGISTER_OP_CPU_KERNEL
(
global_gather
,
PD_REGISTER_STRUCT_KERNEL
(
global_gather
,
ops
::
GlobalGatherOpCPUKernel
<
float
>
,
CPU
,
ops
::
GlobalGatherOpCPUKernel
<
double
>
,
ALL_LAYOUT
,
ops
::
GlobalGatherOpCPUKernel
<
int
>
,
ops
::
GlobalGatherOpCPUKernel
,
ops
::
GlobalGatherOpCPUKernel
<
int64_t
>
,
float
,
ops
::
GlobalGatherOpCPUKernel
<
plat
::
float16
>
);
double
,
int
,
int64_t
,
plat
::
float16
)
{}
paddle/fluid/operators/collective/global_gather_op.cu.cc
浏览文件 @
e93e8a3f
...
@@ -261,7 +261,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
...
@@ -261,7 +261,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
}
}
};
};
template
<
typename
T
>
template
<
typename
T
,
typename
DeivceContext
>
class
GlobalGatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GlobalGatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -283,9 +283,12 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -283,9 +283,12 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
global_gather
,
PD_REGISTER_STRUCT_KERNEL
(
global_gather
,
ops
::
GlobalGatherOpCUDAKernel
<
float
>
,
GPU
,
ops
::
GlobalGatherOpCUDAKernel
<
double
>
,
ALL_LAYOUT
,
ops
::
GlobalGatherOpCUDAKernel
<
int
>
,
ops
::
GlobalGatherOpCUDAKernel
,
ops
::
GlobalGatherOpCUDAKernel
<
int64_t
>
,
float
,
ops
::
GlobalGatherOpCUDAKernel
<
plat
::
float16
>
);
double
,
int
,
int64_t
,
plat
::
float16
)
{}
paddle/fluid/operators/collective/global_gather_op.h
浏览文件 @
e93e8a3f
...
@@ -25,7 +25,7 @@ limitations under the License. */
...
@@ -25,7 +25,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GlobalGatherOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GlobalGatherOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/fluid/operators/collective/global_scatter_op.cc
浏览文件 @
e93e8a3f
...
@@ -115,9 +115,12 @@ REGISTER_OPERATOR(global_scatter,
...
@@ -115,9 +115,12 @@ REGISTER_OPERATOR(global_scatter,
ops
::
GlobalScatterOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
GlobalScatterOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
GlobalScatterOpGradMaker
<
paddle
::
imperative
::
OpBase
>
)
ops
::
GlobalScatterOpGradMaker
<
paddle
::
imperative
::
OpBase
>
)
REGISTER_OP_CPU_KERNEL
(
global_scatter
,
PD_REGISTER_STRUCT_KERNEL
(
global_scatter
,
ops
::
GlobalScatterOpCPUKernel
<
float
>
,
CPU
,
ops
::
GlobalScatterOpCPUKernel
<
double
>
,
ALL_LAYOUT
,
ops
::
GlobalScatterOpCPUKernel
<
int
>
,
ops
::
GlobalScatterOpCPUKernel
,
ops
::
GlobalScatterOpCPUKernel
<
int64_t
>
,
float
,
ops
::
GlobalScatterOpCPUKernel
<
plat
::
float16
>
);
double
,
int
,
int64_t
,
plat
::
float16
)
{}
paddle/fluid/operators/collective/global_scatter_op.cu.cc
浏览文件 @
e93e8a3f
...
@@ -259,7 +259,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
...
@@ -259,7 +259,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
}
}
};
};
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GlobalScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GlobalScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -281,9 +281,12 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -281,9 +281,12 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
global_scatter
,
PD_REGISTER_STRUCT_KERNEL
(
global_scatter
,
ops
::
GlobalScatterOpCUDAKernel
<
float
>
,
GPU
,
ops
::
GlobalScatterOpCUDAKernel
<
double
>
,
ALL_LAYOUT
,
ops
::
GlobalScatterOpCUDAKernel
<
int
>
,
ops
::
GlobalScatterOpCUDAKernel
,
ops
::
GlobalScatterOpCUDAKernel
<
int64_t
>
,
float
,
ops
::
GlobalScatterOpCUDAKernel
<
plat
::
float16
>
);
double
,
int
,
int64_t
,
plat
::
float16
)
{}
paddle/fluid/operators/collective/global_scatter_op.h
浏览文件 @
e93e8a3f
...
@@ -25,7 +25,7 @@ limitations under the License. */
...
@@ -25,7 +25,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GlobalScatterOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GlobalScatterOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/fluid/operators/detection/generate_mask_labels_op.cc
浏览文件 @
e93e8a3f
...
@@ -328,7 +328,7 @@ std::vector<phi::DenseTensor> SampleMaskForOneImage(
...
@@ -328,7 +328,7 @@ std::vector<phi::DenseTensor> SampleMaskForOneImage(
return
res
;
return
res
;
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GenerateMaskLabelsKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GenerateMaskLabelsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -533,5 +533,9 @@ REGISTER_OPERATOR(
...
@@ -533,5 +533,9 @@ REGISTER_OPERATOR(
ops
::
GenerateMaskLabelsOpMaker
,
ops
::
GenerateMaskLabelsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
generate_mask_labels
,
ops
::
GenerateMaskLabelsKernel
<
float
>
);
PD_REGISTER_STRUCT_KERNEL
(
generate_mask_labels
,
CPU
,
ALL_LAYOUT
,
ops
::
GenerateMaskLabelsKernel
,
float
)
{}
paddle/fluid/operators/detection/generate_proposal_labels_op.cc
浏览文件 @
e93e8a3f
...
@@ -510,7 +510,7 @@ std::vector<phi::DenseTensor> SampleRoisForOneImage(
...
@@ -510,7 +510,7 @@ std::vector<phi::DenseTensor> SampleRoisForOneImage(
return
res
;
return
res
;
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GenerateProposalLabelsKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GenerateProposalLabelsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
@@ -811,9 +811,12 @@ REGISTER_OPERATOR(
...
@@ -811,9 +811,12 @@ REGISTER_OPERATOR(
ops
::
GenerateProposalLabelsOpMaker
,
ops
::
GenerateProposalLabelsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
generate_proposal_labels
,
PD_REGISTER_STRUCT_KERNEL
(
generate_proposal_labels
,
ops
::
GenerateProposalLabelsKernel
<
float
>
,
CPU
,
ops
::
GenerateProposalLabelsKernel
<
double
>
);
ALL_LAYOUT
,
ops
::
GenerateProposalLabelsKernel
,
float
,
double
)
{}
REGISTER_OP_VERSION
(
generate_proposal_labels
)
REGISTER_OP_VERSION
(
generate_proposal_labels
)
.
AddCheckpoint
(
.
AddCheckpoint
(
...
...
paddle/fluid/operators/gaussian_random_batch_size_like_op.cc
浏览文件 @
e93e8a3f
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
CPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
@@ -99,7 +99,10 @@ REGISTER_OPERATOR(
...
@@ -99,7 +99,10 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
paddle
::
operators
::
BatchSizeLikeNoNeedBufferVarsInferer
);
paddle
::
operators
::
BatchSizeLikeNoNeedBufferVarsInferer
);
REGISTER_OP_CPU_KERNEL
(
namespace
ops
=
paddle
::
operators
;
gaussian_random_batch_size_like
,
PD_REGISTER_STRUCT_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
operators
::
CPUGaussianRandomBatchSizeLikeKernel
<
float
>
,
CPU
,
paddle
::
operators
::
CPUGaussianRandomBatchSizeLikeKernel
<
double
>
);
ALL_LAYOUT
,
ops
::
CPUGaussianRandomBatchSizeLikeKernel
,
float
,
double
)
{}
paddle/fluid/operators/gaussian_random_batch_size_like_op.cu
浏览文件 @
e93e8a3f
...
@@ -47,7 +47,7 @@ struct GaussianGenerator {
...
@@ -47,7 +47,7 @@ struct GaussianGenerator {
}
}
};
};
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
@@ -78,9 +78,12 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
...
@@ -78,9 +78,12 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
namespace
ops
=
paddle
::
operators
;
gaussian_random_batch_size_like
,
namespace
plat
=
paddle
::
platform
;
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
PD_REGISTER_STRUCT_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
platform
::
float16
>
,
GPU
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
float
>
,
ALL_LAYOUT
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
double
>
);
ops
::
GPUGaussianRandomBatchSizeLikeKernel
,
float
,
double
,
plat
::
float16
)
{}
paddle/fluid/operators/graph_khop_sampler_op.cc
浏览文件 @
e93e8a3f
...
@@ -136,6 +136,10 @@ using CPU = phi::CPUContext;
...
@@ -136,6 +136,10 @@ using CPU = phi::CPUContext;
REGISTER_OPERATOR
(
graph_khop_sampler
,
REGISTER_OPERATOR
(
graph_khop_sampler
,
ops
::
GraphKhopSamplerOP
,
ops
::
GraphKhopSamplerOP
,
ops
::
GraphKhopSamplerOpMaker
);
ops
::
GraphKhopSamplerOpMaker
);
REGISTER_OP_CPU_KERNEL
(
graph_khop_sampler
,
ops
::
GraphKhopSamplerOpKernel
<
CPU
,
int32_t
>
,
PD_REGISTER_STRUCT_KERNEL
(
graph_khop_sampler
,
ops
::
GraphKhopSamplerOpKernel
<
CPU
,
int64_t
>
);
CPU
,
ALL_LAYOUT
,
ops
::
GraphKhopSamplerOpKernel
,
int32_t
,
int64_t
)
{}
paddle/fluid/operators/graph_khop_sampler_op.cu
浏览文件 @
e93e8a3f
...
@@ -412,7 +412,7 @@ void ReindexFunc(const framework::ExecutionContext& ctx,
...
@@ -412,7 +412,7 @@ void ReindexFunc(const framework::ExecutionContext& ctx,
thrust
::
raw_pointer_cast
(
values
.
data
()));
thrust
::
raw_pointer_cast
(
values
.
data
()));
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GraphKhopSamplerOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GraphKhopSamplerOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -668,6 +668,9 @@ class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -668,6 +668,9 @@ class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel<T> {
using
CUDA
=
phi
::
GPUContext
;
using
CUDA
=
phi
::
GPUContext
;
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
graph_khop_sampler
,
PD_REGISTER_STRUCT_KERNEL
(
graph_khop_sampler
,
ops
::
GraphKhopSamplerOpCUDAKernel
<
CUDA
,
int32_t
>
,
GPU
,
ops
::
GraphKhopSamplerOpCUDAKernel
<
CUDA
,
int64_t
>
);
ALL_LAYOUT
,
ops
::
GraphKhopSamplerOpCUDAKernel
,
int32_t
,
int64_t
)
{}
paddle/fluid/operators/graph_khop_sampler_op.h
浏览文件 @
e93e8a3f
...
@@ -191,7 +191,7 @@ void SampleNeighbors(const T* src,
...
@@ -191,7 +191,7 @@ void SampleNeighbors(const T* src,
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
GraphKhopSamplerOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GraphKhopSamplerOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
paddle/fluid/operators/group_norm_op.cc
浏览文件 @
e93e8a3f
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/group_norm_op.h"
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
...
paddle/fluid/operators/group_norm_op.cu
已删除
100644 → 0
浏览文件 @
aac8da90
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace
paddle
{
namespace
operators
{
using
DataLayout
=
phi
::
DataLayout
;
enum
GroupNormKernelFlags
{
kHasScale
=
1
,
kHasBias
=
2
};
#define ALIGN_BYTES 16
#define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
}
// 0 for no scale, no bias
// 1 for has scale, no bias
// 2 for no scale, has bias
// 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
template
<
typename
T
>
__device__
__inline__
void
CudaAtomicAddWithWarp
(
T
*
sum
,
T
value
)
{
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
typename
WarpReduce
::
TempStorage
temp_storage
;
value
=
WarpReduce
(
temp_storage
).
Sum
(
value
);
if
(
cub
::
LaneId
()
==
0
)
phi
::
CudaAtomicAdd
(
sum
,
value
);
}
template
<
typename
T
>
__global__
void
GroupNormForwardGetMeanAndVar
(
const
T
*
x
,
int
N
,
int
C
,
int
W
,
int
imsize
,
int
groups
,
int
group_size
,
T
*
mean
,
T
*
var
)
{
int
gid
=
blockIdx
.
y
;
int
cid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
z
;
int
H
=
imsize
/
W
;
int
number
=
min
(
group_size
,
static_cast
<
int
>
(
C
-
gid
*
group_size
));
int
ccid
=
gid
*
group_size
+
cid
;
if
(
ccid
>=
C
)
return
;
T
x_mean
=
0
,
x_var
=
0
;
for
(
int
imid
=
threadIdx
.
x
;
imid
<
imsize
;
imid
+=
blockDim
.
x
)
{
T
val
;
int
hid
=
imid
/
W
;
int
wid
=
imid
%
W
;
val
=
x
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
];
x_mean
+=
val
;
x_var
+=
val
*
val
;
}
x_mean
/=
number
*
imsize
;
x_var
/=
number
*
imsize
;
CudaAtomicAddWithWarp
(
&
mean
[
bid
*
groups
+
gid
],
x_mean
);
CudaAtomicAddWithWarp
(
&
var
[
bid
*
groups
+
gid
],
x_var
);
}
template
<
typename
T
,
typename
AccT
,
int
VecSize
,
int
Num
>
__device__
__forceinline__
void
ThreadReduce
(
phi
::
Array
<
const
T
*
,
Num
>
arrs
,
int
size
,
const
int
offset
,
AccT
*
out_mean
,
AccT
*
out_var
)
{
const
T
*
x
=
arrs
[
0
];
const
T
*
y
;
if
(
Num
==
2
)
{
y
=
arrs
[
1
];
}
using
VecT
=
kps
::
details
::
VectorType
<
T
,
VecSize
>
;
int
tid
=
threadIdx
.
x
;
if
(
offset
>
0
)
{
x
-=
offset
;
if
(
Num
==
2
)
{
y
-=
offset
;
}
size
+=
offset
;
if
(
tid
>=
offset
)
{
if
(
Num
==
1
)
{
*
out_mean
+=
x
[
tid
];
*
out_var
+=
x
[
tid
]
*
x
[
tid
];
}
else
if
(
Num
==
2
)
{
*
out_mean
+=
y
[
tid
];
*
out_var
+=
y
[
tid
]
*
x
[
tid
];
}
}
size
-=
blockDim
.
x
;
x
+=
blockDim
.
x
;
if
(
Num
==
2
)
{
y
+=
blockDim
.
x
;
}
}
int
remain
=
size
%
(
VecSize
*
blockDim
.
x
);
T
ins_x
[
VecSize
];
T
ins_y
[
VecSize
];
VecT
*
ins_vec_x
=
reinterpret_cast
<
VecT
*>
(
&
ins_x
);
VecT
*
ins_vec_y
=
reinterpret_cast
<
VecT
*>
(
&
ins_y
);
// vector part
for
(;
VecSize
*
tid
<
(
size
-
remain
);
tid
+=
blockDim
.
x
)
{
*
ins_vec_x
=
reinterpret_cast
<
const
VecT
*>
(
x
)[
tid
];
if
(
Num
==
2
)
{
*
ins_vec_y
=
reinterpret_cast
<
const
VecT
*>
(
y
)[
tid
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
if
(
Num
==
1
)
{
*
out_mean
+=
ins_x
[
i
];
*
out_var
+=
ins_x
[
i
]
*
ins_x
[
i
];
}
else
if
(
Num
==
2
)
{
*
out_mean
+=
ins_y
[
i
];
*
out_var
+=
ins_y
[
i
]
*
ins_x
[
i
];
}
}
}
// scalar part
tid
=
size
-
remain
+
threadIdx
.
x
;
for
(;
tid
<
size
;
tid
+=
blockDim
.
x
)
{
if
(
Num
==
1
)
{
*
out_mean
+=
x
[
tid
];
*
out_var
+=
x
[
tid
]
*
x
[
tid
];
}
else
if
(
Num
==
2
)
{
*
out_mean
+=
y
[
tid
];
*
out_var
+=
y
[
tid
]
*
x
[
tid
];
}
}
}
template
<
typename
T
>
__device__
__forceinline__
void
ReduceMeanAndVar
(
T
*
mean
,
T
*
var
,
T
x_mean
,
T
x_var
,
int
size
)
{
const
int
nc
=
blockIdx
.
x
;
x_mean
=
kps
::
details
::
BlockXReduce
<
T
,
kps
::
AddFunctor
<
T
>>
(
x_mean
,
kps
::
AddFunctor
<
T
>
());
x_var
=
kps
::
details
::
BlockXReduce
<
T
,
kps
::
AddFunctor
<
T
>>
(
x_var
,
kps
::
AddFunctor
<
T
>
());
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
mean
[
nc
]
=
static_cast
<
T
>
(
x_mean
/
size
);
var
[
nc
]
=
static_cast
<
T
>
(
x_var
/
size
);
}
}
template
<
typename
T
>
__global__
void
ScalarGetMeanAndVarNCHW
(
const
T
*
x
,
T
*
mean
,
T
*
var
,
int
size
)
{
int
i
=
blockIdx
.
x
;
T
x_mean
=
0
,
x_var
=
0
;
for
(
int
j
=
threadIdx
.
x
;
j
<
size
;
j
+=
blockDim
.
x
)
{
T
val
;
val
=
x
[
i
*
size
+
j
];
x_mean
+=
val
;
x_var
+=
val
*
val
;
}
ReduceMeanAndVar
<
T
>
(
mean
,
var
,
x_mean
,
x_var
,
size
);
}
template
<
typename
T
,
typename
AccT
,
int
VecSize
>
__global__
void
VectorizedGetMeanAndVarNCHW
(
const
T
*
x
,
T
*
mean
,
T
*
var
,
int
size
)
{
int
i
=
blockIdx
.
x
;
AccT
x_mean
=
static_cast
<
AccT
>
(
0
);
AccT
x_var
=
static_cast
<
AccT
>
(
0
);
x
+=
i
*
size
;
const
int
input_offset
=
((
uint64_t
)
x
)
%
ALIGN_BYTES
/
sizeof
(
T
);
phi
::
Array
<
const
T
*
,
1
>
ins
;
ins
[
0
]
=
x
;
ThreadReduce
<
T
,
AccT
,
VecSize
,
1
>
(
ins
,
size
,
input_offset
,
&
x_mean
,
&
x_var
);
ReduceMeanAndVar
<
AccT
>
(
mean
,
var
,
x_mean
,
x_var
,
size
);
}
template
<
typename
T
,
int
flags
>
__global__
void
GroupNormForward
(
const
T
*
x
,
const
T
*
mean
,
const
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
N
,
int
C
,
int
W
,
int
imsize
,
int
groups
,
int
group_size
,
T
epsilon
,
T
*
y
,
T
*
real_var
,
const
DataLayout
data_layout
)
{
int
gid
=
blockIdx
.
y
;
int
cid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
z
;
int
H
=
imsize
/
W
;
int
ccid
=
gid
*
group_size
+
cid
;
if
(
ccid
>=
C
)
return
;
auto
ng
=
bid
*
groups
+
gid
;
T
x_mean
=
mean
[
ng
];
T
x_var
=
var
[
ng
];
x_var
=
x_var
-
x_mean
*
x_mean
;
T
var_inv
=
rsqrt
(
x_var
+
epsilon
);
if
(
cid
==
0
&&
threadIdx
.
x
==
0
)
{
real_var
[
ng
]
=
x_var
;
}
for
(
int
imid
=
threadIdx
.
x
;
imid
<
imsize
;
imid
+=
blockDim
.
x
)
{
T
val
;
int
hid
,
wid
;
int
index
=
(
bid
*
C
+
ccid
)
*
imsize
+
imid
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
val
=
x
[
index
];
}
else
{
hid
=
imid
/
W
;
wid
=
imid
%
W
;
val
=
x
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
];
}
val
=
(
val
-
x_mean
)
*
var_inv
;
if
(
flags
&
kHasScale
)
{
val
*=
scale
[
ccid
];
}
if
(
flags
&
kHasBias
)
{
val
+=
bias
[
ccid
];
}
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
y
[
index
]
=
val
;
}
else
{
y
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
]
=
val
;
}
}
}
template
<
typename
T
>
class
GroupNormKernel
<
phi
::
GPUContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
std
::
string
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
const
DataLayout
data_layout
=
phi
::
StringToDataLayout
(
data_layout_str
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
*
scale
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Scale"
);
auto
*
bias
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Bias"
);
auto
*
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Y"
);
auto
*
mean
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Mean"
);
auto
*
var
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Variance"
);
const
auto
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
const
auto
x_dims
=
x
->
dims
();
const
int
C
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
const
int
group_size
=
C
/
groups
;
const
int
W
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
x_dims
.
size
()
-
1
]
:
x_dims
[
x_dims
.
size
()
-
2
]);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
var
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
phi
::
GPUContext
>();
phi
::
DenseTensor
temp_var
;
temp_var
.
mutable_data
<
T
>
(
var
->
dims
(),
ctx
.
GetPlace
());
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
data
<
T
>
();
auto
*
mean_data
=
mean
->
data
<
T
>
();
auto
*
var_data
=
var
->
data
<
T
>
();
auto
*
temp_var_data
=
temp_var
.
data
<
T
>
();
const
T
*
scale_data
=
nullptr
;
if
(
scale
)
scale_data
=
scale
->
data
<
T
>
();
const
T
*
bias_data
=
nullptr
;
if
(
bias
)
bias_data
=
bias
->
data
<
T
>
();
int
imsize
=
1
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
else
{
for
(
int
i
=
1
;
i
<
x_dims
.
size
()
-
1
;
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
#ifdef __HIPCC__
int
block_size
=
std
::
max
(
std
::
min
(
256
,
imsize
),
64
);
#else
int
block_size
=
std
::
min
(
1024
,
imsize
);
#endif
dim3
grid
(
group_size
,
groups
,
x_dims
[
0
]);
dim3
threads
(
block_size
,
1
,
1
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
using
AccT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
constexpr
int
vec_size
=
sizeof
(
float4
)
/
sizeof
(
T
);
int
size
=
group_size
*
imsize
;
const
int
max_num_threads
=
1024
;
int
max_block_size
=
std
::
min
(
size
/
vec_size
,
max_num_threads
);
int
block_size_nchw
=
1
;
while
(
block_size_nchw
<
max_block_size
)
{
block_size_nchw
*=
2
;
}
block_size_nchw
=
std
::
max
(
block_size_nchw
,
kps
::
details
::
kWarpSize
);
dim3
grids
(
x_dims
[
0
]
*
groups
);
dim3
blocks
(
block_size_nchw
);
if
(
size
<
vec_size
*
block_size_nchw
)
{
ScalarGetMeanAndVarNCHW
<
T
><<<
grids
,
blocks
,
0
,
dev_ctx
.
stream
()
>>>
(
x_data
,
mean_data
,
temp_var_data
,
size
);
}
else
{
VectorizedGetMeanAndVarNCHW
<
T
,
AccT
,
vec_size
>
<<<
grids
,
blocks
,
0
,
dev_ctx
.
stream
()
>>>
(
x_data
,
mean_data
,
temp_var_data
,
size
);
}
}
else
{
set_zero
(
dev_ctx
,
mean
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
&
temp_var
,
static_cast
<
T
>
(
0
));
GroupNormForwardGetMeanAndVar
<
T
>
<<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
x_data
,
x_dims
[
0
],
C
,
W
,
imsize
,
groups
,
group_size
,
mean_data
,
temp_var_data
);
}
int
flags
=
(
scale_data
!=
nullptr
)
*
kHasScale
+
(
bias_data
!=
nullptr
)
*
kHasBias
;
UNROLL_ALL_CASES
(
flags
,
GroupNormForward
,
x_data
,
mean_data
,
temp_var_data
,
scale_data
,
bias_data
,
x_dims
[
0
],
C
,
W
,
imsize
,
groups
,
group_size
,
epsilon
,
y_data
,
var_data
,
data_layout
);
}
};
template
<
typename
T
,
int
flags
>
__global__
void
GroupNormBackwardGetMeanAndVar
(
const
T
*
x
,
const
T
*
scale
,
const
T
*
bias
,
const
T
*
d_y
,
int
N
,
int
C
,
int
W
,
int
imsize
,
int
groups
,
int
group_size
,
T
epsilon
,
T
*
d_mean
,
T
*
d_var
,
T
*
d_scale
,
T
*
d_bias
)
{
int
gid
=
blockIdx
.
y
;
int
cid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
z
;
int
H
=
imsize
/
W
;
int
number
=
min
(
group_size
,
static_cast
<
int
>
(
C
-
gid
*
group_size
));
int
ccid
=
gid
*
group_size
+
cid
;
if
(
ccid
>=
C
)
return
;
T
x_scale
=
(
flags
&
kHasScale
)
?
scale
[
ccid
]
:
1
;
T
x_bias
=
(
flags
&
kHasBias
)
?
bias
[
ccid
]
:
0
;
T
x_scale_inv
=
0
;
if
(
x_scale
!=
0
)
x_scale_inv
=
1.0
/
x_scale
;
T
d_mean_data
=
0
,
d_var_data
=
0
,
d_scale_data
=
0
,
d_bias_data
=
0
;
for
(
int
imid
=
threadIdx
.
x
;
imid
<
imsize
;
imid
+=
blockDim
.
x
)
{
T
val
,
dval
;
int
hid
=
imid
/
W
;
int
wid
=
imid
%
W
;
val
=
x
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
]
-
x_bias
;
dval
=
d_y
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
];
d_var_data
+=
val
*
dval
;
d_mean_data
+=
dval
*
x_scale
;
val
=
val
*
x_scale_inv
;
d_bias_data
+=
dval
;
d_scale_data
+=
val
*
dval
;
}
CudaAtomicAddWithWarp
(
&
(
d_mean
[
bid
*
groups
+
gid
]),
d_mean_data
);
CudaAtomicAddWithWarp
(
&
(
d_var
[
bid
*
groups
+
gid
]),
d_var_data
);
if
(
flags
&
kHasScale
)
{
#if CUDA_VERSION >= 11070
phi
::
CudaAtomicAdd
(
&
(
d_scale
[
ccid
]),
d_scale_data
);
#else
CudaAtomicAddWithWarp
(
&
(
d_scale
[
ccid
]),
d_scale_data
);
#endif
}
if
(
flags
&
kHasBias
)
{
#if CUDA_VERSION >= 11070
phi
::
CudaAtomicAdd
(
&
(
d_bias
[
ccid
]),
d_bias_data
);
#else
CudaAtomicAddWithWarp
(
&
(
d_bias
[
ccid
]),
d_bias_data
);
#endif
}
}
template
<
typename
T
,
int
flags
>
__global__
void
GroupNormBackward
(
const
T
*
x
,
const
T
*
d_y
,
const
T
*
scale
,
const
T
*
bias
,
const
T
*
var
,
const
T
*
d_mean
,
const
T
*
d_var
,
int
N
,
int
C
,
int
W
,
int
imsize
,
int
groups
,
int
group_size
,
T
epsilon
,
T
*
d_x
)
{
int
gid
=
blockIdx
.
y
;
int
cid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
z
;
int
H
=
imsize
/
W
;
int
number
=
min
(
group_size
,
static_cast
<
int
>
(
C
-
gid
*
group_size
));
int
ccid
=
gid
*
group_size
+
cid
;
if
(
ccid
>=
C
)
return
;
T
x_var
=
var
[
bid
*
groups
+
gid
];
T
d_x_mean
=
d_mean
[
bid
*
groups
+
gid
];
T
d_x_var
=
d_var
[
bid
*
groups
+
gid
];
T
x_var_inv
=
1.0
/
sqrt
(
x_var
+
epsilon
);
T
number_inv
=
1.0
/
(
number
*
imsize
);
T
x_scale
=
(
flags
&
kHasScale
)
?
scale
[
ccid
]
:
1
;
T
x_bias
=
(
flags
&
kHasBias
)
?
bias
[
ccid
]
:
0
;
T
x_scale_inv
=
0
;
if
(
x_scale
!=
0
)
x_scale_inv
=
1.0
/
x_scale
;
for
(
int
imid
=
threadIdx
.
x
;
imid
<
imsize
;
imid
+=
blockDim
.
x
)
{
int
hid
=
imid
/
W
;
int
wid
=
imid
%
W
;
T
tmp
=
x
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
];
T
v_y
=
(
tmp
-
x_bias
)
*
x_scale_inv
;
T
dly
=
d_y
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
];
d_x
[(
bid
*
H
+
hid
)
*
W
*
C
+
wid
*
C
+
ccid
]
=
x_var_inv
*
(
dly
*
x_scale
-
number_inv
*
d_x_var
*
v_y
-
number_inv
*
d_x_mean
);
}
}
template
<
typename
T
>
__global__
void
ScalarGetDsDbCUDAKernel
(
int
imsize
,
const
T
*
x
,
const
T
*
dy
,
T
*
ds
,
T
*
db
)
{
const
int
nc
=
blockIdx
.
x
;
T
ds_sum
=
0
;
T
db_sum
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
imsize
;
i
+=
blockDim
.
x
)
{
const
int
index
=
nc
*
imsize
+
i
;
ds_sum
+=
dy
[
index
]
*
x
[
index
];
db_sum
+=
dy
[
index
];
}
ReduceMeanAndVar
<
T
>
(
db
,
ds
,
db_sum
,
ds_sum
,
1
);
}
template
<
typename
T
>
__global__
void
GetScaleBiasGradientCUDAKernel
(
int
N
,
int
C
,
int
group
,
T
epsilon
,
const
T
*
mean
,
const
T
*
var
,
const
T
*
ds
,
const
T
*
db
,
T
*
d_scale
,
T
*
d_bias
)
{
const
int
c
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
c
<
C
)
{
const
int
G
=
group
;
const
int
D
=
C
/
G
;
T
sum1
=
0
;
T
sum2
=
0
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
const
int
nc
=
n
*
C
+
c
;
const
int
ng
=
n
*
G
+
c
/
D
;
sum1
+=
(
d_scale
==
nullptr
)
?
T
(
0
)
:
((
ds
[
nc
]
-
db
[
nc
]
*
static_cast
<
T
>
(
mean
[
ng
]))
*
static_cast
<
T
>
(
rsqrt
(
var
[
ng
]
+
epsilon
)));
sum2
+=
(
d_bias
==
nullptr
)
?
T
(
0
)
:
db
[
nc
];
}
if
(
d_scale
!=
nullptr
)
{
d_scale
[
c
]
=
sum1
;
}
if
(
d_bias
!=
nullptr
)
{
d_bias
[
c
]
=
sum2
;
}
}
}
template
<
typename
T
,
int
BlockDim
>
__global__
void
GetBackwardParamsCUDAKernel
(
int
imsize
,
int
groups
,
int
group_size
,
T
epsilon
,
const
T
*
mean
,
const
T
*
var
,
const
T
*
scale
,
const
T
*
ds
,
const
T
*
db
,
T
*
p1
,
T
*
p2
,
T
*
p3
)
{
const
int
n
=
blockIdx
.
x
;
const
int
g
=
blockIdx
.
y
;
const
int
ng
=
n
*
groups
+
g
;
T
sum1
=
0
;
T
sum2
=
0
;
T
var_inv
=
rsqrt
(
var
[
ng
]
+
epsilon
);
for
(
int64_t
i
=
threadIdx
.
x
;
i
<
group_size
;
i
+=
blockDim
.
x
)
{
const
int64_t
index
=
ng
*
group_size
+
i
;
const
int64_t
c
=
g
*
group_size
+
i
;
const
T
scale_v
=
scale
==
nullptr
?
T
(
1
)
:
static_cast
<
T
>
(
scale
[
c
]);
sum1
+=
ds
[
index
]
*
scale_v
;
sum2
+=
db
[
index
]
*
scale_v
;
const
T
scale_c
=
scale
==
nullptr
?
T
(
0
)
:
static_cast
<
T
>
(
scale
[
c
]);
p1
[
index
]
=
scale_c
*
var_inv
;
}
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
ds_storage
;
__shared__
typename
BlockReduce
::
TempStorage
db_storage
;
sum1
=
BlockReduce
(
ds_storage
).
Reduce
(
sum1
,
cub
::
Sum
());
sum2
=
BlockReduce
(
db_storage
).
Reduce
(
sum2
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
const
T
s
=
T
(
1
)
/
static_cast
<
T
>
(
group_size
*
imsize
);
const
T
x
=
(
sum2
*
static_cast
<
T
>
(
mean
[
ng
])
-
sum1
)
*
static_cast
<
T
>
(
var_inv
)
*
static_cast
<
T
>
(
var_inv
)
*
static_cast
<
T
>
(
var_inv
)
*
s
;
p2
[
ng
]
=
x
;
p3
[
ng
]
=
-
x
*
static_cast
<
T
>
(
mean
[
ng
])
-
sum2
*
static_cast
<
T
>
(
var_inv
)
*
s
;
}
}
template
<
typename
T
>
__global__
void
GetXGradientCUDAKernel
(
int
imsize
,
int
C
,
int
group_size
,
int
groups
,
T
*
p1
,
T
*
p2
,
T
*
p3
,
const
T
*
x
,
const
T
*
dy
,
T
*
dx
)
{
int
cid
=
blockIdx
.
x
;
int
gid
=
blockIdx
.
y
;
int
bid
=
blockIdx
.
z
;
int
ccid
=
bid
*
C
+
gid
*
group_size
+
cid
;
int
ng
=
bid
*
groups
+
gid
;
int
nc
=
gid
*
group_size
+
cid
;
for
(
int
imid
=
threadIdx
.
x
;
imid
<
imsize
;
imid
+=
blockDim
.
x
)
{
int
index
=
(
bid
*
C
+
nc
)
*
imsize
+
imid
;
dx
[
index
]
=
p1
[
ccid
]
*
dy
[
index
]
+
p2
[
ng
]
*
x
[
index
]
+
p3
[
ng
];
}
}
template
<
typename
T
>
class
GroupNormGradKernel
<
phi
::
GPUContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
std
::
string
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
const
DataLayout
data_layout
=
phi
::
StringToDataLayout
(
data_layout_str
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
*
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Y"
);
auto
*
mean
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Mean"
);
auto
*
var
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Variance"
);
auto
*
scale
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Scale"
);
auto
*
bias
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Bias"
);
auto
*
d_y
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Y"
));
const
auto
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
// init output
auto
*
d_x
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_scale
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
d_bias
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Bias"
));
const
auto
&
x_dims
=
x
->
dims
();
const
int
C
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
const
int
group_size
=
C
/
groups
;
const
int
W
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
x_dims
.
size
()
-
1
]
:
x_dims
[
x_dims
.
size
()
-
2
]);
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
phi
::
GPUContext
>();
phi
::
DenseTensor
ds
,
db
;
ds
.
mutable_data
<
T
>
({
x_dims
[
0
],
C
},
ctx
.
GetPlace
());
db
.
mutable_data
<
T
>
({
x_dims
[
0
],
C
},
ctx
.
GetPlace
());
T
*
ds_data
=
ds
.
data
<
T
>
();
T
*
db_data
=
db
.
data
<
T
>
();
auto
*
y_data
=
y
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
T
*
d_x_data
=
nullptr
;
if
(
d_x
)
d_x_data
=
d_x
->
data
<
T
>
();
auto
*
dy_data
=
d_y
->
data
<
T
>
();
auto
*
var_data
=
var
->
data
<
T
>
();
auto
*
mean_data
=
mean
->
data
<
T
>
();
T
*
d_scale_data
=
nullptr
;
if
(
d_scale
)
{
d_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_scale_data
=
d_scale
->
data
<
T
>
();
}
T
*
d_bias_data
=
nullptr
;
if
(
d_bias
)
{
d_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_bias_data
=
d_bias
->
data
<
T
>
();
}
const
T
*
scale_data
=
nullptr
;
if
(
scale
)
scale_data
=
scale
->
data
<
T
>
();
const
T
*
bias_data
=
nullptr
;
if
(
bias
)
bias_data
=
bias
->
data
<
T
>
();
int
imsize
=
1
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
else
{
for
(
int
i
=
1
;
i
<
x_dims
.
size
()
-
1
;
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
#ifdef __HIPCC__
int
block_size
=
std
::
max
(
std
::
min
(
256
,
imsize
),
64
);
const
int
block_dims
=
256
;
#else
int
block_size
=
std
::
min
(
1024
,
imsize
);
const
int
block_dims
=
1024
;
#endif
dim3
grid
(
group_size
,
groups
,
x_dims
[
0
]);
dim3
threads
(
block_size
,
1
,
1
);
int
flags
=
(
scale_data
!=
nullptr
)
*
kHasScale
+
(
bias_data
!=
nullptr
)
*
kHasBias
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
const
int
max_num_threads
=
1024
;
int
max_block_size
=
std
::
min
(
imsize
,
max_num_threads
);
int
block_size_nchw
=
1
;
while
(
block_size_nchw
<
max_block_size
)
{
block_size_nchw
*=
2
;
}
block_size_nchw
=
std
::
max
(
block_size_nchw
,
kps
::
details
::
kWarpSize
);
dim3
blocks
(
block_size_nchw
);
ScalarGetDsDbCUDAKernel
<
T
>
<<<
x_dims
[
0
]
*
C
,
blocks
,
0
,
dev_ctx
.
stream
()
>>>
(
imsize
,
x_data
,
dy_data
,
ds_data
,
db_data
);
if
(
d_scale
||
d_bias
)
{
const
int
block
=
256
;
GetScaleBiasGradientCUDAKernel
<
T
>
<<<
(
C
+
block
-
1
)
/
block
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x_dims
[
0
],
C
,
groups
,
epsilon
,
mean_data
,
var_data
,
ds_data
,
db_data
,
d_scale_data
,
d_bias_data
);
}
if
(
d_x_data
!=
nullptr
)
{
// p1 * dy + p2 * x + p3,
// p1, p2, p3 represent the reverse calculation of temporary variables
// p1 = scale * var_inv
// p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n)
// p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n);
phi
::
DenseTensor
p1
,
p2
,
p3
;
p1
.
mutable_data
<
T
>
({
x_dims
[
0
]
*
C
},
ctx
.
GetPlace
());
p2
.
mutable_data
<
T
>
({
x_dims
[
0
],
groups
},
ctx
.
GetPlace
());
p3
.
mutable_data
<
T
>
({
x_dims
[
0
],
groups
},
ctx
.
GetPlace
());
T
*
p1_data
=
p1
.
data
<
T
>
();
T
*
p2_data
=
p2
.
data
<
T
>
();
T
*
p3_data
=
p3
.
data
<
T
>
();
GetBackwardParamsCUDAKernel
<
T
,
block_dims
>
<<<
dim3
(
x_dims
[
0
],
groups
),
block_dims
,
0
,
dev_ctx
.
stream
()
>>>
(
imsize
,
groups
,
group_size
,
epsilon
,
mean_data
,
var_data
,
scale_data
,
ds_data
,
db_data
,
p1_data
,
p2_data
,
p3_data
);
GetXGradientCUDAKernel
<
T
>
<<<
grid
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
imsize
,
C
,
group_size
,
groups
,
p1_data
,
p2_data
,
p3_data
,
x_data
,
dy_data
,
d_x_data
);
}
}
else
{
if
(
d_scale
)
{
set_zero
(
dev_ctx
,
d_scale
,
static_cast
<
T
>
(
0
));
}
if
(
d_bias
)
{
set_zero
(
dev_ctx
,
d_bias
,
static_cast
<
T
>
(
0
));
}
phi
::
DenseTensor
temp_var
;
temp_var
.
mutable_data
<
T
>
(
var
->
dims
(),
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
&
temp_var
,
static_cast
<
T
>
(
0
));
T
*
temp_var_data
=
temp_var
.
data
<
T
>
();
phi
::
DenseTensor
temp_mean
;
temp_mean
.
mutable_data
<
T
>
(
var
->
dims
(),
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
&
temp_mean
,
static_cast
<
T
>
(
0
));
T
*
temp_mean_data
=
temp_mean
.
data
<
T
>
();
int
flags
=
(
scale_data
!=
nullptr
)
*
kHasScale
+
(
bias_data
!=
nullptr
)
*
kHasBias
;
UNROLL_ALL_CASES
(
flags
,
GroupNormBackwardGetMeanAndVar
,
y_data
,
scale_data
,
bias_data
,
dy_data
,
x_dims
[
0
],
C
,
W
,
imsize
,
groups
,
group_size
,
epsilon
,
temp_mean_data
,
temp_var_data
,
d_scale_data
,
d_bias_data
);
if
(
d_x_data
!=
nullptr
)
{
UNROLL_ALL_CASES
(
flags
,
GroupNormBackward
,
y_data
,
dy_data
,
scale_data
,
bias_data
,
var_data
,
temp_mean_data
,
temp_var_data
,
x_dims
[
0
],
C
,
W
,
imsize
,
groups
,
group_size
,
epsilon
,
d_x_data
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
group_norm
,
ops
::
GroupNormKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
GroupNormKernel
<
phi
::
GPUContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
group_norm_grad
,
ops
::
GroupNormGradKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
GroupNormGradKernel
<
phi
::
GPUContext
,
double
>
);
paddle/fluid/operators/group_norm_op.h
已删除
100644 → 0
浏览文件 @
aac8da90
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
using
DataLayout
=
phi
::
DataLayout
;
template
<
typename
DeviceContext
,
typename
T
>
class
GroupNormKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
std
::
string
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
const
DataLayout
data_layout
=
phi
::
StringToDataLayout
(
data_layout_str
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
*
scale
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Scale"
);
auto
*
bias
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Bias"
);
auto
*
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Y"
);
auto
*
mean
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Mean"
);
auto
*
var
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Variance"
);
const
auto
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
const
auto
x_dims
=
x
->
dims
();
const
int
C
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
const
int
group_size
=
C
/
groups
;
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
var
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
data
<
T
>
();
auto
*
mean_data
=
mean
->
data
<
T
>
();
auto
*
var_data
=
var
->
data
<
T
>
();
const
T
*
scale_data
=
nullptr
;
if
(
scale
)
scale_data
=
scale
->
data
<
T
>
();
const
T
*
bias_data
=
nullptr
;
if
(
bias
)
bias_data
=
bias
->
data
<
T
>
();
int
imsize
=
1
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
else
{
for
(
int
i
=
1
;
i
<
x_dims
.
size
()
-
1
;
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
auto
*
iter_x_data
=
x_data
;
auto
*
iter_y_data
=
y_data
;
for
(
int
bid
=
0
;
bid
<
x_dims
[
0
];
bid
++
)
{
for
(
int
gid
=
0
;
gid
<
groups
;
gid
++
)
{
const
int64_t
M
=
8
;
std
::
array
<
T
,
M
>
x_mean_arr
;
std
::
array
<
T
,
M
>
x_var_arr
;
std
::
fill
(
x_mean_arr
.
begin
(),
x_mean_arr
.
end
(),
T
(
0
));
std
::
fill
(
x_var_arr
.
begin
(),
x_var_arr
.
end
(),
T
(
0
));
T
x_mean
=
0
,
x_var
=
0
;
int
number
=
std
::
min
(
group_size
,
static_cast
<
int
>
(
C
-
gid
*
group_size
));
auto
*
tmp_x
=
iter_x_data
;
auto
*
x_src_data
=
iter_x_data
;
auto
*
tmp_y
=
iter_y_data
;
auto
*
y_src_data
=
iter_y_data
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
int
imid
;
for
(
imid
=
0
;
imid
<
imsize
-
(
imsize
%
M
);
imid
+=
M
,
iter_x_data
+=
M
)
{
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
x_mean_arr
[
0
]
+=
iter_x_data
[
0
];
x_var_arr
[
0
]
+=
iter_x_data
[
0
]
*
iter_x_data
[
0
];
x_mean_arr
[
1
]
+=
iter_x_data
[
1
];
x_var_arr
[
1
]
+=
iter_x_data
[
1
]
*
iter_x_data
[
1
];
x_mean_arr
[
2
]
+=
iter_x_data
[
2
];
x_var_arr
[
2
]
+=
iter_x_data
[
2
]
*
iter_x_data
[
2
];
x_mean_arr
[
3
]
+=
iter_x_data
[
3
];
x_var_arr
[
3
]
+=
iter_x_data
[
3
]
*
iter_x_data
[
3
];
x_mean_arr
[
4
]
+=
iter_x_data
[
4
];
x_var_arr
[
4
]
+=
iter_x_data
[
4
]
*
iter_x_data
[
4
];
x_mean_arr
[
5
]
+=
iter_x_data
[
5
];
x_var_arr
[
5
]
+=
iter_x_data
[
5
]
*
iter_x_data
[
5
];
x_mean_arr
[
6
]
+=
iter_x_data
[
6
];
x_var_arr
[
6
]
+=
iter_x_data
[
6
]
*
iter_x_data
[
6
];
x_mean_arr
[
7
]
+=
iter_x_data
[
7
];
x_var_arr
[
7
]
+=
iter_x_data
[
7
]
*
iter_x_data
[
7
];
}
x_mean
=
std
::
accumulate
(
x_mean_arr
.
cbegin
(),
x_mean_arr
.
cend
(),
x_mean
);
x_var
=
std
::
accumulate
(
x_var_arr
.
cbegin
(),
x_var_arr
.
cend
(),
x_var
);
std
::
fill
(
x_mean_arr
.
begin
(),
x_mean_arr
.
end
(),
T
(
0
));
std
::
fill
(
x_var_arr
.
begin
(),
x_var_arr
.
end
(),
T
(
0
));
for
(;
imid
<
imsize
;
imid
++
,
iter_x_data
++
)
{
x_mean
+=
iter_x_data
[
0
];
x_var
+=
iter_x_data
[
0
]
*
iter_x_data
[
0
];
}
}
}
else
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
iter_x_data
=
tmp_x
+
cid
;
int
imid
;
for
(
imid
=
0
;
imid
<
imsize
-
(
imsize
%
M
);
imid
+=
M
,
iter_x_data
+=
M
*
C
)
{
// TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
// in template class/function, before we complete high
// performance cpu vector extension, temporarily unrolling
// loop to get high precision and performance
x_mean_arr
[
0
]
+=
iter_x_data
[
0
*
C
];
x_var_arr
[
0
]
+=
iter_x_data
[
0
*
C
]
*
iter_x_data
[
0
*
C
];
x_mean_arr
[
1
]
+=
iter_x_data
[
1
*
C
];
x_var_arr
[
1
]
+=
iter_x_data
[
1
*
C
]
*
iter_x_data
[
1
*
C
];
x_mean_arr
[
2
]
+=
iter_x_data
[
2
*
C
];
x_var_arr
[
2
]
+=
iter_x_data
[
2
*
C
]
*
iter_x_data
[
2
*
C
];
x_mean_arr
[
3
]
+=
iter_x_data
[
3
*
C
];
x_var_arr
[
3
]
+=
iter_x_data
[
3
*
C
]
*
iter_x_data
[
3
*
C
];
x_mean_arr
[
4
]
+=
iter_x_data
[
4
*
C
];
x_var_arr
[
4
]
+=
iter_x_data
[
4
*
C
]
*
iter_x_data
[
4
*
C
];
x_mean_arr
[
5
]
+=
iter_x_data
[
5
*
C
];
x_var_arr
[
5
]
+=
iter_x_data
[
5
*
C
]
*
iter_x_data
[
5
*
C
];
x_mean_arr
[
6
]
+=
iter_x_data
[
6
*
C
];
x_var_arr
[
6
]
+=
iter_x_data
[
6
*
C
]
*
iter_x_data
[
6
*
C
];
x_mean_arr
[
7
]
+=
iter_x_data
[
7
*
C
];
x_var_arr
[
7
]
+=
iter_x_data
[
7
*
C
]
*
iter_x_data
[
7
*
C
];
}
x_mean
=
std
::
accumulate
(
x_mean_arr
.
cbegin
(),
x_mean_arr
.
cend
(),
x_mean
);
x_var
=
std
::
accumulate
(
x_var_arr
.
cbegin
(),
x_var_arr
.
cend
(),
x_var
);
std
::
fill
(
x_mean_arr
.
begin
(),
x_mean_arr
.
end
(),
T
(
0
));
std
::
fill
(
x_var_arr
.
begin
(),
x_var_arr
.
end
(),
T
(
0
));
for
(;
imid
<
imsize
;
imid
++
,
iter_x_data
+=
C
)
{
x_mean
+=
iter_x_data
[
0
];
x_var
+=
iter_x_data
[
0
]
*
iter_x_data
[
0
];
}
}
iter_x_data
=
tmp_x
+
group_size
;
}
x_mean
/=
number
*
imsize
;
x_var
/=
number
*
imsize
;
x_var
=
std
::
max
(
x_var
-
x_mean
*
x_mean
,
T
(
0
));
T
var_inv
=
T
(
1
)
/
std
::
sqrt
(
x_var
+
epsilon
);
mean_data
[
bid
*
groups
+
gid
]
=
x_mean
;
var_data
[
bid
*
groups
+
gid
]
=
x_var
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
tmp_x
++
,
iter_y_data
++
)
{
T
val
=
(
tmp_x
[
0
]
-
x_mean
)
*
var_inv
;
if
(
scale_data
)
val
*=
scale_data
[
gid
*
group_size
+
cid
];
if
(
bias_data
)
val
+=
bias_data
[
gid
*
group_size
+
cid
];
iter_y_data
[
0
]
=
val
;
}
}
}
else
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
tmp_x
=
x_src_data
+
cid
;
iter_y_data
=
y_src_data
+
cid
;
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
tmp_x
+=
C
,
iter_y_data
+=
C
)
{
T
val
=
(
tmp_x
[
0
]
-
x_mean
)
*
var_inv
;
if
(
scale_data
)
val
*=
scale_data
[
gid
*
group_size
+
cid
];
if
(
bias_data
)
val
+=
bias_data
[
gid
*
group_size
+
cid
];
iter_y_data
[
0
]
=
val
;
}
}
iter_y_data
=
tmp_y
+
group_size
;
}
}
if
(
data_layout
==
DataLayout
::
kNHWC
)
{
iter_x_data
=
x_data
+
(
bid
+
1
)
*
C
*
imsize
;
iter_y_data
=
y_data
+
(
bid
+
1
)
*
C
*
imsize
;
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
GroupNormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
std
::
string
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
const
DataLayout
data_layout
=
phi
::
StringToDataLayout
(
data_layout_str
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
*
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Y"
);
auto
*
var
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Variance"
);
auto
*
scale
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Scale"
);
auto
*
bias
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Bias"
);
auto
*
d_y
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Y"
));
const
auto
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
// init output
auto
*
d_x
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_scale
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
d_bias
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
framework
::
GradVarName
(
"Bias"
));
const
auto
&
x_dims
=
x
->
dims
();
const
int
C
=
(
data_layout
==
DataLayout
::
kNCHW
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
const
int
group_size
=
C
/
groups
;
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
d_x_data
=
d_x
->
data
<
T
>
();
auto
*
y_data
=
d_y
->
data
<
T
>
();
auto
*
var_data
=
var
->
data
<
T
>
();
T
*
d_scale_data
=
nullptr
;
if
(
d_scale
)
{
d_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
d_scale
,
static_cast
<
T
>
(
0
));
d_scale_data
=
d_scale
->
data
<
T
>
();
}
T
*
d_bias_data
=
nullptr
;
if
(
d_bias
)
{
d_bias
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
set_zero
(
dev_ctx
,
d_bias
,
static_cast
<
T
>
(
0
));
d_bias_data
=
d_bias
->
data
<
T
>
();
}
const
T
*
scale_data
=
nullptr
;
if
(
scale
)
scale_data
=
scale
->
data
<
T
>
();
const
T
*
bias_data
=
nullptr
;
if
(
bias
)
bias_data
=
bias
->
data
<
T
>
();
int
imsize
=
1
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
i
=
2
;
i
<
x_dims
.
size
();
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
else
{
for
(
int
i
=
1
;
i
<
x_dims
.
size
()
-
1
;
++
i
)
{
imsize
*=
x_dims
[
i
];
}
}
auto
*
iter_x_data
=
x_data
;
auto
*
iter_d_x_data
=
d_x_data
;
auto
*
iter_y_data
=
y_data
;
for
(
int
bid
=
0
;
bid
<
x_dims
[
0
];
bid
++
)
{
for
(
int
gid
=
0
;
gid
<
groups
;
gid
++
)
{
T
x_var
=
var_data
[
bid
*
groups
+
gid
];
T
var_inv
=
1.0
/
sqrt
(
x_var
+
epsilon
);
int
number
=
std
::
min
(
group_size
,
static_cast
<
int
>
(
C
-
gid
*
group_size
));
T
number_inv
=
1.0
/
(
number
*
imsize
);
auto
*
tmp_x
=
iter_x_data
;
auto
*
tmp_y
=
iter_y_data
;
auto
*
tmp_d_x
=
iter_d_x_data
;
auto
*
x_src_data
=
iter_x_data
;
auto
*
y_src_data
=
iter_y_data
;
auto
*
iter_x_data_backup
=
iter_x_data
;
auto
*
iter_y_data_backup
=
iter_y_data
;
auto
*
iter_d_x_data_backup
=
iter_d_x_data
;
T
dp_scale
=
0
,
dp_bias
=
0
;
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
iter_x_data
++
,
iter_y_data
++
)
{
T
val
=
iter_x_data
[
0
];
if
(
bias_data
)
val
-=
bias_data
[
gid
*
group_size
+
cid
];
T
dval
=
iter_y_data
[
0
];
dp_scale
+=
val
*
dval
;
if
(
scale_data
)
dp_bias
+=
dval
*
scale_data
[
gid
*
group_size
+
cid
];
if
(
scale_data
&&
scale_data
[
gid
*
group_size
+
cid
]
!=
0
)
val
/=
scale_data
[
gid
*
group_size
+
cid
];
if
(
d_bias_data
)
d_bias_data
[
gid
*
group_size
+
cid
]
+=
dval
;
if
(
d_scale_data
)
d_scale_data
[
gid
*
group_size
+
cid
]
+=
val
*
dval
;
}
}
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
iter_d_x_data
++
,
tmp_x
++
,
tmp_y
++
)
{
T
v_y
=
tmp_x
[
0
];
T
dly
=
tmp_y
[
0
];
T
dss
=
dp_scale
;
T
dbs
=
dp_bias
;
T
v_scale
=
1.
,
v_bias
=
0.
;
if
(
scale_data
)
v_scale
=
scale_data
[
gid
*
group_size
+
cid
];
if
(
bias_data
)
v_bias
=
bias_data
[
gid
*
group_size
+
cid
];
v_y
-=
v_bias
;
if
(
v_scale
!=
0
)
v_y
/=
v_scale
;
iter_d_x_data
[
0
]
=
(
dly
*
v_scale
-
number_inv
*
dss
*
v_y
-
number_inv
*
dbs
)
*
var_inv
;
}
}
}
else
{
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
iter_x_data
=
x_src_data
+
cid
;
iter_y_data
=
y_src_data
+
cid
;
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
iter_x_data
+=
C
,
iter_y_data
+=
C
)
{
T
val
=
iter_x_data
[
0
];
if
(
bias_data
)
val
-=
bias_data
[
gid
*
group_size
+
cid
];
T
dval
=
iter_y_data
[
0
];
dp_scale
+=
val
*
dval
;
if
(
scale_data
)
dp_bias
+=
dval
*
scale_data
[
gid
*
group_size
+
cid
];
if
(
scale_data
&&
scale_data
[
gid
*
group_size
+
cid
]
!=
0
)
val
/=
scale_data
[
gid
*
group_size
+
cid
];
if
(
d_bias_data
)
d_bias_data
[
gid
*
group_size
+
cid
]
+=
dval
;
if
(
d_scale_data
)
d_scale_data
[
gid
*
group_size
+
cid
]
+=
val
*
dval
;
}
}
for
(
int
cid
=
0
;
cid
<
number
;
cid
++
)
{
tmp_x
=
x_src_data
+
cid
;
tmp_y
=
y_src_data
+
cid
;
iter_d_x_data
=
tmp_d_x
+
cid
;
for
(
int
imid
=
0
;
imid
<
imsize
;
imid
++
,
iter_d_x_data
+=
C
,
tmp_x
+=
C
,
tmp_y
+=
C
)
{
T
v_y
=
tmp_x
[
0
];
T
dly
=
tmp_y
[
0
];
T
dss
=
dp_scale
;
T
dbs
=
dp_bias
;
T
v_scale
=
1.0
,
v_bias
=
0.
;
if
(
scale_data
)
v_scale
=
scale_data
[
gid
*
group_size
+
cid
];
if
(
bias_data
)
v_bias
=
bias_data
[
gid
*
group_size
+
cid
];
v_y
-=
v_bias
;
if
(
v_scale
!=
0
)
v_y
/=
v_scale
;
iter_d_x_data
[
0
]
=
(
dly
*
v_scale
-
number_inv
*
dss
*
v_y
-
number_inv
*
dbs
)
*
var_inv
;
}
}
iter_x_data
=
iter_x_data_backup
+
group_size
;
iter_y_data
=
iter_y_data_backup
+
group_size
;
iter_d_x_data
=
iter_d_x_data_backup
+
group_size
;
}
}
if
(
data_layout
==
DataLayout
::
kNHWC
)
{
iter_x_data
=
x_data
+
(
bid
+
1
)
*
C
*
imsize
;
iter_d_x_data
=
d_x_data
+
(
bid
+
1
)
*
C
*
imsize
;
iter_y_data
=
y_data
+
(
bid
+
1
)
*
C
*
imsize
;
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/l1_norm_op.cc
浏览文件 @
e93e8a3f
...
@@ -91,10 +91,13 @@ REGISTER_OPERATOR(l1_norm,
...
@@ -91,10 +91,13 @@ REGISTER_OPERATOR(l1_norm,
ops
::
L1NormGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
L1NormGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
L1NormGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
L1NormGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
l1_norm_grad
,
ops
::
L1NormGradOp
);
REGISTER_OPERATOR
(
l1_norm_grad
,
ops
::
L1NormGradOp
);
REGISTER_OP_CPU_KERNEL
(
l1_norm
,
ops
::
L1NormKernel
<
phi
::
CPUContext
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
l1_norm_grad
,
ops
::
L1NormGradKernel
<
phi
::
CPUContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
l1_norm
,
ops
::
L1NormKernel
<
phi
::
GPUContext
,
float
>
);
PD_REGISTER_STRUCT_KERNEL
(
l1_norm
,
CPU
,
ALL_LAYOUT
,
ops
::
L1NormKernel
,
float
)
{}
REGISTER_OP_CUDA_KERNEL
(
l1_norm_grad
,
PD_REGISTER_STRUCT_KERNEL
(
ops
::
L1NormGradKernel
<
phi
::
GPUContext
,
float
>
);
l1_norm_grad
,
CPU
,
ALL_LAYOUT
,
ops
::
L1NormGradKernel
,
float
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_STRUCT_KERNEL
(
l1_norm
,
GPU
,
ALL_LAYOUT
,
ops
::
L1NormKernel
,
float
)
{}
PD_REGISTER_STRUCT_KERNEL
(
l1_norm_grad
,
GPU
,
ALL_LAYOUT
,
ops
::
L1NormGradKernel
,
float
)
{}
#endif
paddle/fluid/operators/l1_norm_op.h
浏览文件 @
e93e8a3f
...
@@ -21,7 +21,7 @@ namespace paddle {
...
@@ -21,7 +21,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
// Out = sum(abs(X))
// Out = sum(abs(X))
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
L1NormKernel
:
public
framework
::
OpKernel
<
T
>
{
class
L1NormKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
@@ -39,7 +39,7 @@ class L1NormKernel : public framework::OpKernel<T> {
...
@@ -39,7 +39,7 @@ class L1NormKernel : public framework::OpKernel<T> {
};
};
// dX = dout * sign(X)
// dX = dout * sign(X)
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
class
L1NormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
L1NormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录