Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
337bb47b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
337bb47b
编写于
7月 07, 2022
作者:
L
Leo Chen
提交者:
GitHub
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine dist kernel, reuse norm (#44154)
* refine dist kernel, reuse norm * follow comments
上级
fa6333f9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
20 additions
and
194 deletions
+20
-194
paddle/phi/kernels/dist_kernel.cc
paddle/phi/kernels/dist_kernel.cc
+20
-1
paddle/phi/kernels/gpu/dist_kernel.cu
paddle/phi/kernels/gpu/dist_kernel.cu
+0
-27
paddle/phi/kernels/impl/dist_kernel_impl.h
paddle/phi/kernels/impl/dist_kernel_impl.h
+0
-166
未找到文件。
paddle/phi/kernels/
cpu/
dist_kernel.cc
→
paddle/phi/kernels/dist_kernel.cc
浏览文件 @
337bb47b
...
...
@@ -16,6 +16,25 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/dist_kernel_impl.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/p_norm_kernel.h"
namespace phi {

// Computes dist(x, y, p) = the p-norm of (x - y), composed entirely from the
// existing Subtract and PNorm kernels instead of a bespoke implementation.
// p follows p_norm semantics: p == 0 counts non-zero elements of (x - y),
// p == +/-inf takes the max/min of |x - y|, otherwise the usual Lp-norm.
template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                float p,
                DenseTensor* out) {
  // Subtract performs the broadcasted element-wise difference of x and y.
  auto difference = Subtract<T, Context>(dev_ctx, x, y);
  // Reduce the difference to a scalar norm. The trailing arguments are
  // presumably (porder, axis, epsilon, keepdim, asvector) per
  // p_norm_kernel.h — axis = -1 with asvector = true flattens the tensor;
  // confirm against the PNormKernel declaration.
  PNormKernel<T, Context>(dev_ctx, difference, p, -1, 1e-12, false, true, out);
}

}  // namespace phi
// Register the composed dist kernel for CPU with float/double.
PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {}

// Because DistKernel is written purely in terms of other PHI kernels
// (Subtract + PNorm), the identical source also serves as the GPU kernel;
// double is registered unconditionally here, unlike the old Eigen-based
// .cu file which skipped double on HIP.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
#endif
paddle/phi/kernels/gpu/dist_kernel.cu
已删除
100644 → 0
浏览文件 @
fa6333f9
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/dist_kernel_impl.h"
// GPU registration for the (removed) Eigen-based dist kernel.
#ifdef PADDLE_WITH_HIP
// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
// does not support double on the HIPCC platform (Eigen3 to be fixed),
// so only float is registered for HIP builds.
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float) {}
#else
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
#endif
paddle/phi/kernels/impl/dist_kernel_impl.h
已删除
100644 → 0
浏览文件 @
fa6333f9
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {

// Shorthand for a rank-D Eigen tensor view over a phi tensor's data;
// defaults (RowMajor, DenseIndex) mirror phi::EigenTensor's own defaults.
template <typename T,
          size_t D,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using ETensor = phi::EigenTensor<T, D, MajorType, IndexType>;
// Fills the Eigen broadcast factors for a pair of same-rank shapes
// (see GetNewDims for the rank alignment). Per dimension, the side with
// the smaller extent is replicated larger/smaller times and the other
// side keeps a factor of 1; any non-zero division remainder means the
// shapes cannot be broadcast against each other and is rejected.
// NOTE(review): the "Braodcast" typo in the name is preserved on purpose —
// fixing it would require touching every call site.
template <int Rank>
static void GetBraodcastDims(const phi::DDim& x_dims,
                             const phi::DDim& y_dims,
                             Eigen::DSizes<int, Rank>* x_bcast_dims,
                             Eigen::DSizes<int, Rank>* y_bcast_dims) {
  int remainder_sum = 0;
  for (int d = 0; d < x_dims.size(); ++d) {
    if (y_dims[d] > x_dims[d]) {
      // y is longer along d: replicate x to match.
      (*y_bcast_dims)[d] = 1;
      (*x_bcast_dims)[d] = y_dims[d] / x_dims[d];
      remainder_sum += y_dims[d] % x_dims[d];
    } else {
      // x is at least as long along d: replicate y to match.
      (*x_bcast_dims)[d] = 1;
      (*y_bcast_dims)[d] = x_dims[d] / y_dims[d];
      remainder_sum += x_dims[d] % y_dims[d];
    }
  }
  PADDLE_ENFORCE_EQ(remainder_sum,
                    0,
                    phi::errors::PreconditionNotMet(
                        "The input tensor of Op(dist) could not be broadcast, "
                        "X's shape is [%s], Y's shape is [%s].",
                        x_dims,
                        y_dims));
}
// Left-pads `in_dims` with leading 1s so the result has exactly `rank`
// dimensions, e.g. rank=3 with (4, 3) yields (1, 4, 3). Shapes whose rank
// already reaches `rank` are returned unchanged.
static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) {
  if (in_dims.size() >= rank) {
    return phi::make_ddim(vectorize(in_dims));
  }
  // Start from all-ones, then copy the original extents into the tail.
  std::vector<int64_t> padded(rank, 1);
  const int offset = rank - in_dims.size();
  for (int i = 0; i < in_dims.size(); ++i) {
    padded[i + offset] = in_dims[i];
  }
  return phi::make_ddim(padded);
}
// Computes dist(x, y, p) for a fixed compile-time Rank using Eigen tensor
// expressions: both inputs are reshaped to the same rank, broadcast against
// each other, and then reduced according to p. The result written to `out`
// is a scalar (a rank-1 Eigen view over *out).
template <typename Context, typename T, int Rank>
static void DistFunction(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& y,
                         float p,
                         DenseTensor* out) {
  if (out) {
    dev_ctx.template Alloc<T>(out);
  }
  auto x_dims = x.dims();
  auto y_dims = y.dims();

  // new dims with same size as rank, e.g. (rank=3, (4, 3) => (1, 4, 3))
  phi::DDim x_new_dims = GetNewDims(x_dims, Rank);
  phi::DDim y_new_dims = GetNewDims(y_dims, Rank);

  auto x_t = ETensor<T, Rank>::From(x, x_new_dims);
  auto y_t = ETensor<T, Rank>::From(y, y_new_dims);
  auto out_t = ETensor<T, 1>::From(*out);
  auto& place = *dev_ctx.eigen_device();

  // Per-dimension replication factors for Eigen's broadcast().
  Eigen::DSizes<int, Rank> x_bcast_dims;
  Eigen::DSizes<int, Rank> y_bcast_dims;
  GetBraodcastDims<Rank>(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims);

  // p=0 means number of non-zero elements of (x-y)
  // p=inf means the maximum of |x-y|
  // p=-inf means the minimum of |x-y|
  // otherwise, Lp-norm = pow(sum(pow(|x-y|, p)), 1/p)
  if (p == 0) {
    // Count positions where the broadcast inputs differ.
    out_t.device(place) =
        (x_t.broadcast(x_bcast_dims) != y_t.broadcast(y_bcast_dims))
            .template cast<T>()
            .sum();
  } else if (p == INFINITY) {
    out_t.device(place) =
        (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
            .abs()
            .maximum();
  } else if (p == -INFINITY) {
    out_t.device(place) =
        (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
            .abs()
            .minimum();
  } else {
    out_t.device(place) =
        (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
            .abs()
            .pow(p)
            .sum()
            .pow(1.0 / p);
  }
}
// Entry point of the Eigen-based dist kernel: validates the ranks and
// dispatches to the DistFunction instantiation matching the larger of the
// two input ranks (Eigen needs the rank at compile time, hence the switch).
template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                float p,
                DenseTensor* out) {
  auto rank_x = x.dims().size();
  auto rank_y = y.dims().size();
  auto max_rank = std::max(rank_x, rank_y);
  PADDLE_ENFORCE_LE(max_rank,
                    6,
                    phi::errors::Unimplemented(
                        "Op(dist) only support tensors with no more than 6 "
                        "dimensions, but X's rank is %d, Y's rank is %d.",
                        rank_x,
                        rank_y));
  switch (max_rank) {
    case 1:
      DistFunction<Context, T, 1>(dev_ctx, x, y, p, out);
      break;
    case 2:
      DistFunction<Context, T, 2>(dev_ctx, x, y, p, out);
      break;
    case 3:
      DistFunction<Context, T, 3>(dev_ctx, x, y, p, out);
      break;
    case 4:
      DistFunction<Context, T, 4>(dev_ctx, x, y, p, out);
      break;
    case 5:
      DistFunction<Context, T, 5>(dev_ctx, x, y, p, out);
      break;
    case 6:
      DistFunction<Context, T, 6>(dev_ctx, x, y, p, out);
      break;
  }
}

}  // namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录