Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ee8eeb45
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ee8eeb45
编写于
3月 30, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Revert "[Phi] trans logsumexp op (#40790)" (#41068)" (#41109)
This reverts commit
054fc997
.
上级
91bb52cd
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
440 addition
and
270 deletion
+440
-270
paddle/fluid/operators/reduce_ops/logsumexp_op.cc
paddle/fluid/operators/reduce_ops/logsumexp_op.cc
+8
-85
paddle/fluid/operators/reduce_ops/logsumexp_op.h
paddle/fluid/operators/reduce_ops/logsumexp_op.h
+0
-170
paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc
paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc
+1
-1
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+85
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+6
-0
paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc
paddle/phi/kernels/cpu/logsumexp_grad_kernel.cc
+7
-8
paddle/phi/kernels/cpu/logsumexp_kernel.cc
paddle/phi/kernels/cpu/logsumexp_kernel.cc
+8
-6
paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
+22
-0
paddle/phi/kernels/gpu/logsumexp_kernel.cu
paddle/phi/kernels/gpu/logsumexp_kernel.cu
+23
-0
paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
+91
-0
paddle/phi/kernels/impl/logsumexp_kernel_impl.h
paddle/phi/kernels/impl/logsumexp_kernel_impl.h
+100
-0
paddle/phi/kernels/logsumexp_grad_kernel.h
paddle/phi/kernels/logsumexp_grad_kernel.h
+31
-0
paddle/phi/kernels/logsumexp_kernel.h
paddle/phi/kernels/logsumexp_kernel.h
+29
-0
paddle/phi/ops/compat/logsumexp_sig.cc
paddle/phi/ops/compat/logsumexp_sig.cc
+29
-0
未找到文件。
paddle/fluid/operators/reduce_ops/logsumexp_op.cc
浏览文件 @
ee8eeb45
...
...
@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,80 +26,6 @@ namespace operators {
class
LogsumexpOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"logsumexp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"logsumexp"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_rank
=
x_dims
.
size
();
PADDLE_ENFORCE_LE
(
x_rank
,
4
,
platform
::
errors
::
InvalidArgument
(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s]."
,
x_rank
,
x_dims
));
auto
axis
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axis"
);
PADDLE_ENFORCE_GT
(
axis
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d."
,
axis
.
size
()));
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
PADDLE_ENFORCE_LT
(
axis
[
i
],
x_rank
,
platform
::
errors
::
InvalidArgument
(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d."
,
i
,
x_rank
,
i
,
axis
[
i
]));
PADDLE_ENFORCE_GE
(
axis
[
i
],
-
x_rank
,
platform
::
errors
::
InvalidArgument
(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d."
,
i
,
x_rank
,
i
,
axis
[
i
]));
if
(
axis
[
i
]
<
0
)
{
axis
[
i
]
+=
x_rank
;
}
}
bool
keepdim
=
ctx
->
Attrs
().
Get
<
bool
>
(
"keepdim"
);
bool
reduce_all
=
ctx
->
Attrs
().
Get
<
bool
>
(
"reduce_all"
);
auto
dims_vector
=
vectorize
(
x_dims
);
if
(
reduce_all
)
{
if
(
keepdim
)
ctx
->
SetOutputDim
(
"Out"
,
phi
::
make_ddim
(
std
::
vector
<
int64_t
>
(
x_rank
,
1
)));
else
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
}
else
{
auto
dims_vector
=
vectorize
(
x_dims
);
if
(
keepdim
)
{
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
dims_vector
[
axis
[
i
]]
=
1
;
}
}
else
{
const
int
kDelFlag
=
-
1
;
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
dims_vector
[
axis
[
i
]]
=
kDelFlag
;
}
dims_vector
.
erase
(
std
::
remove
(
dims_vector
.
begin
(),
dims_vector
.
end
(),
kDelFlag
),
dims_vector
.
end
());
}
if
(
!
keepdim
&&
dims_vector
.
size
()
==
0
)
{
dims_vector
.
push_back
(
1
);
}
auto
out_dims
=
phi
::
make_ddim
(
dims_vector
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
axis
.
size
()
>
0
&&
axis
[
0
]
!=
0
)
{
// Only pass LoD when not reducing on the first dim.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
}
};
class
LogsumexpOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -164,16 +93,10 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
logsumexp
,
LogsumexpInferShapeFunctor
,
PD_INFER_META
(
phi
::
LogsumexpInferMeta
));
REGISTER_OPERATOR
(
logsumexp
,
ops
::
LogsumexpOp
,
ops
::
LogsumexpOpMaker
,
ops
::
LogsumexpGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LogsumexpGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
LogsumexpGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
LogsumexpInferShapeFunctor
);
REGISTER_OPERATOR
(
logsumexp_grad
,
ops
::
LogsumexpGrapOp
);
REGISTER_OP_CPU_KERNEL
(
logsumexp
,
ops
::
LogsumexpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LogsumexpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
logsumexp_grad
,
ops
::
LogsumexpGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LogsumexpGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/reduce_ops/logsumexp_op.h
已删除
100644 → 0
浏览文件 @
91bb52cd
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
namespace
paddle
{
namespace
operators
{
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
paddle::operators::ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, \
LogsumexpFunctor>( \
context.template device_context<DeviceContext>(), *input, output, \
axis, keepdim); \
}
struct
LogsumexpFunctor
{
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
const
Dim
&
dim
)
{
auto
x_dim
=
x
->
dimensions
();
auto
t_dim
=
x_dim
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
dim
.
size
());
i
++
)
{
t_dim
[
dim
[
i
]]
=
1
;
}
auto
r_dim
=
x_dim
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
r_dim
.
size
());
i
++
)
{
r_dim
[
i
]
=
1
;
}
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
dim
.
size
());
i
++
)
{
r_dim
[
dim
[
i
]]
=
x_dim
[
dim
[
i
]];
}
auto
y_dim
=
y
->
dimensions
();
auto
x_max
=
x
->
maximum
(
dim
);
y
->
device
(
place
)
=
(
x_max
+
(
*
x
-
x_max
.
reshape
(
t_dim
).
broadcast
(
r_dim
)).
exp
().
sum
(
dim
).
log
())
.
reshape
(
y_dim
);
}
};
struct
LogsumexpGradFunctor
{
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
dx
->
device
(
place
)
=
dy
->
broadcast
(
dim
)
*
(
*
x
-
y
->
broadcast
(
dim
)).
exp
();
}
};
template
<
typename
DeviceContext
,
typename
OutT
>
class
LogsumexpKernel
:
public
framework
::
OpKernel
<
OutT
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
OutT
>
(
context
.
GetPlace
());
auto
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
auto
keepdim
=
context
.
Attr
<
bool
>
(
"keepdim"
);
auto
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
const
auto
&
input_dim_size
=
input
->
dims
().
size
();
// The dims has full dim, set the reduce_all is True
reduce_all
|=
(
static_cast
<
const
int
>
(
axis
.
size
())
==
input_dim_size
);
if
(
reduce_all
)
{
// Flatten and reduce 1-D tensor
auto
x
=
EigenVector
<
OutT
>::
Flatten
(
*
input
);
auto
out
=
EigenScalar
<
OutT
>::
From
(
*
output
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
reduce_dim
=
Eigen
::
array
<
int
,
1
>
({{
0
}});
LogsumexpFunctor
()(
place
,
&
x
,
&
out
,
reduce_dim
);
}
else
{
int
ndim
=
input_dim_size
;
int
rdim
=
axis
.
size
();
// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
// HANDLE_DIM(6, 3);
// HANDLE_DIM(6, 2);
// HANDLE_DIM(6, 1);
// HANDLE_DIM(5, 4);
// HANDLE_DIM(5, 3);
// HANDLE_DIM(5, 2);
// HANDLE_DIM(5, 1);
HANDLE_DIM
(
4
,
3
);
HANDLE_DIM
(
4
,
2
);
HANDLE_DIM
(
4
,
1
);
HANDLE_DIM
(
3
,
2
);
HANDLE_DIM
(
3
,
1
);
HANDLE_DIM
(
2
,
1
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
LogsumexpGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Input
<
Tensor
>
(
"Out"
);
auto
*
output_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
auto
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
const
auto
input_dim_size
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
reduce_all
|=
(
static_cast
<
const
int
>
(
axis
.
size
())
==
input_dim_size
);
if
(
reduce_all
)
{
auto
x
=
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
y
=
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
dy
=
EigenVector
<
T
>::
Flatten
(
*
output_grad
);
auto
dx
=
EigenVector
<
T
>::
Flatten
(
*
input_grad
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
broadcast_dim
=
Eigen
::
array
<
int
,
1
>
({{
static_cast
<
int
>
(
input
->
numel
())}});
LogsumexpGradFunctor
()(
place
,
&
x
,
&
y
,
&
dx
,
&
dy
,
broadcast_dim
,
broadcast_dim
[
0
]);
}
else
{
int
rank
=
input
->
dims
().
size
();
LogsumexpGradFunctor
functor
;
switch
(
rank
)
{
case
1
:
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
case
2
:
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
case
3
:
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
case
4
:
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc
浏览文件 @
ee8eeb45
...
...
@@ -14,7 +14,7 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/reduce_ops/
logsumexp_op
.h"
#include "paddle/fluid/operators/reduce_ops/
reduce_op_function
.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device_context.h"
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
ee8eeb45
...
...
@@ -804,6 +804,91 @@ void KthvalueInferMeta(const MetaTensor& x,
indices
->
set_dtype
(
x
.
dtype
());
}
void
LogsumexpInferMeta
(
const
MetaTensor
&
input
,
const
std
::
vector
<
int64_t
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
MetaTensor
*
out
)
{
auto
x_dims
=
input
.
dims
();
auto
x_rank
=
x_dims
.
size
();
std
::
vector
<
int64_t
>
formated_axis
=
axis
;
PADDLE_ENFORCE_LE
(
x_rank
,
4
,
errors
::
InvalidArgument
(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s]."
,
x_rank
,
x_dims
));
PADDLE_ENFORCE_GT
(
axis
.
size
(),
0
,
errors
::
InvalidArgument
(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d."
,
axis
.
size
()));
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
PADDLE_ENFORCE_LT
(
axis
[
i
],
x_rank
,
errors
::
InvalidArgument
(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d."
,
i
,
x_rank
,
i
,
axis
[
i
]));
PADDLE_ENFORCE_GE
(
axis
[
i
],
-
x_rank
,
errors
::
InvalidArgument
(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d."
,
i
,
x_rank
,
i
,
axis
[
i
]));
if
(
axis
[
i
]
<
0
)
{
formated_axis
[
i
]
+=
x_rank
;
}
}
auto
dims_vector
=
vectorize
(
x_dims
);
if
(
reduce_all
)
{
if
(
keepdim
)
out
->
set_dims
(
phi
::
make_ddim
(
std
::
vector
<
int64_t
>
(
x_rank
,
1
)));
else
out
->
set_dims
({
1
});
}
else
{
auto
dims_vector
=
vectorize
(
x_dims
);
if
(
keepdim
)
{
for
(
size_t
i
=
0
;
i
<
formated_axis
.
size
();
++
i
)
{
dims_vector
[
formated_axis
[
i
]]
=
1
;
}
}
else
{
const
int
kDelFlag
=
-
1
;
for
(
size_t
i
=
0
;
i
<
formated_axis
.
size
();
++
i
)
{
dims_vector
[
formated_axis
[
i
]]
=
kDelFlag
;
}
dims_vector
.
erase
(
std
::
remove
(
dims_vector
.
begin
(),
dims_vector
.
end
(),
kDelFlag
),
dims_vector
.
end
());
}
if
(
!
keepdim
&&
dims_vector
.
size
()
==
0
)
{
dims_vector
.
push_back
(
1
);
}
auto
out_dims
=
phi
::
make_ddim
(
dims_vector
);
out
->
set_dims
(
out_dims
);
if
(
formated_axis
.
size
()
>
0
&&
formated_axis
[
0
]
!=
0
)
{
// Only pass LoD when not reducing on the first dim.
out
->
share_lod
(
input
);
}
}
out
->
set_dtype
(
input
.
dtype
());
}
void
MatrixPowerInferMeta
(
const
MetaTensor
&
x
,
int
n
,
MetaTensor
*
out
)
{
auto
dims
=
x
.
dims
();
auto
n_dim
=
dims
.
size
();
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
ee8eeb45
...
...
@@ -136,6 +136,12 @@ void KthvalueInferMeta(const MetaTensor& x,
MetaTensor
*
indices
,
MetaConfig
=
MetaConfig
());
void
LogsumexpInferMeta
(
const
MetaTensor
&
input
,
const
std
::
vector
<
int64_t
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
MetaTensor
*
out
);
void
MatrixPowerInferMeta
(
const
MetaTensor
&
x
,
int
n
,
MetaTensor
*
out
);
void
MaxOutInferMeta
(
const
MetaTensor
&
x
,
...
...
paddle/
fluid/operators/reduce_ops/logsumexp_op.part.cu
→
paddle/
phi/kernels/cpu/logsumexp_grad_kernel.cc
浏览文件 @
ee8eeb45
// Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// .part used to speed up nvcc compile
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
namespace
ops
=
paddle
::
operators
;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL
(
logsumexp_grad
,
ops
::
LogsumexpGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LogsumexpGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
PD_REGISTER_KERNEL
(
logsumexp_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
LogsumexpGradKernel
,
float
,
double
)
{}
paddle/
fluid/operators/reduce_ops/logsumexp_op.cu
→
paddle/
phi/kernels/cpu/logsumexp_kernel.cc
浏览文件 @
ee8eeb45
// Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/
fluid/operators/reduce_ops/logsumexp_op
.h"
#include "paddle/
phi/kernels/logsumexp_kernel
.h"
namespace
ops
=
paddle
::
operators
;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
REGISTER_OP_CUDA_KERNEL
(
logsumexp
,
ops
::
LogsumexpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LogsumexpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
PD_REGISTER_KERNEL
(
logsumexp
,
CPU
,
ALL_LAYOUT
,
phi
::
LogsumexpKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
logsumexp_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
LogsumexpGradKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/logsumexp_kernel.cu
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logsumexp_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
PD_REGISTER_KERNEL
(
logsumexp
,
GPU
,
ALL_LAYOUT
,
phi
::
LogsumexpKernel
,
float
,
double
)
{}
paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
namespace
phi
{
struct
LogsumexpGradFunctor
{
template
<
typename
Context
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
void
operator
()(
const
Context
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
dx
->
device
(
place
)
=
dy
->
broadcast
(
dim
)
*
(
*
x
-
y
->
broadcast
(
dim
)).
exp
();
}
};
template
<
typename
T
,
typename
Context
>
void
LogsumexpGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
in
,
const
DenseTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
DenseTensor
*
in_grad
)
{
dev_ctx
.
template
Alloc
<
T
>(
in_grad
);
const
auto
input_dim_size
=
in
.
dims
().
size
();
reduce_all
|=
(
static_cast
<
const
int
>
(
axis
.
size
())
==
input_dim_size
);
if
(
reduce_all
)
{
auto
x
=
phi
::
EigenVector
<
T
>::
Flatten
(
in
);
auto
y
=
phi
::
EigenVector
<
T
>::
Flatten
(
out
);
auto
dy
=
phi
::
EigenVector
<
T
>::
Flatten
(
out_grad
);
auto
dx
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
in_grad
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
broadcast_dim
=
Eigen
::
array
<
int
,
1
>
({{
static_cast
<
int
>
(
in
.
numel
())}});
LogsumexpGradFunctor
()(
place
,
&
x
,
&
y
,
&
dx
,
&
dy
,
broadcast_dim
,
broadcast_dim
[
0
]);
}
else
{
int
rank
=
in
.
dims
().
size
();
LogsumexpGradFunctor
functor
;
switch
(
rank
)
{
case
1
:
phi
::
funcs
::
ReduceGradFunctor
<
Context
,
T
,
1
,
LogsumexpGradFunctor
>
(
dev_ctx
,
in
,
out
,
out_grad
,
in_grad
,
functor
,
axis
);
break
;
case
2
:
phi
::
funcs
::
ReduceGradFunctor
<
Context
,
T
,
2
,
LogsumexpGradFunctor
>
(
dev_ctx
,
in
,
out
,
out_grad
,
in_grad
,
functor
,
axis
);
break
;
case
3
:
phi
::
funcs
::
ReduceGradFunctor
<
Context
,
T
,
3
,
LogsumexpGradFunctor
>
(
dev_ctx
,
in
,
out
,
out_grad
,
in_grad
,
functor
,
axis
);
break
;
case
4
:
phi
::
funcs
::
ReduceGradFunctor
<
Context
,
T
,
4
,
LogsumexpGradFunctor
>
(
dev_ctx
,
in
,
out
,
out_grad
,
in_grad
,
functor
,
axis
);
break
;
}
}
}
}
// namespace phi
paddle/phi/kernels/impl/logsumexp_kernel_impl.h
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/logsumexp_kernel.h"
namespace
phi
{
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \
dev_ctx, x, out, axis, keepdim); \
}
struct
LogsumexpFunctor
{
template
<
typename
Context
,
typename
X
,
typename
Y
,
typename
Dim
>
void
operator
()(
const
Context
&
place
,
X
*
x
,
Y
*
y
,
const
Dim
&
dim
)
{
auto
x_dim
=
x
->
dimensions
();
auto
t_dim
=
x_dim
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
dim
.
size
());
i
++
)
{
t_dim
[
dim
[
i
]]
=
1
;
}
auto
r_dim
=
x_dim
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
r_dim
.
size
());
i
++
)
{
r_dim
[
i
]
=
1
;
}
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
dim
.
size
());
i
++
)
{
r_dim
[
dim
[
i
]]
=
x_dim
[
dim
[
i
]];
}
auto
y_dim
=
y
->
dimensions
();
auto
x_max
=
x
->
maximum
(
dim
);
y
->
device
(
place
)
=
(
x_max
+
(
*
x
-
x_max
.
reshape
(
t_dim
).
broadcast
(
r_dim
)).
exp
().
sum
(
dim
).
log
())
.
reshape
(
y_dim
);
}
};
template
<
typename
T
,
typename
Context
>
void
LogsumexpKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int64_t
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
DenseTensor
*
out
)
{
dev_ctx
.
template
Alloc
<
T
>(
out
);
const
auto
&
input_dim_size
=
x
.
dims
().
size
();
// The dims has full dim, set the reduce_all is True
reduce_all
|=
(
static_cast
<
const
int
>
(
axis
.
size
())
==
input_dim_size
);
if
(
reduce_all
)
{
// Flatten and reduce 1-D tensor
auto
input
=
phi
::
EigenVector
<
T
>::
Flatten
(
x
);
auto
output
=
phi
::
EigenScalar
<
T
>::
From
(
*
out
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
reduce_dim
=
Eigen
::
array
<
int
,
1
>
({{
0
}});
LogsumexpFunctor
()(
place
,
&
input
,
&
output
,
reduce_dim
);
}
else
{
int
ndim
=
input_dim_size
;
int
rdim
=
axis
.
size
();
// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
// HANDLE_DIM(6, 3);
// HANDLE_DIM(6, 2);
// HANDLE_DIM(6, 1);
// HANDLE_DIM(5, 4);
// HANDLE_DIM(5, 3);
// HANDLE_DIM(5, 2);
// HANDLE_DIM(5, 1);
HANDLE_DIM
(
4
,
3
);
HANDLE_DIM
(
4
,
2
);
HANDLE_DIM
(
4
,
1
);
HANDLE_DIM
(
3
,
2
);
HANDLE_DIM
(
3
,
1
);
HANDLE_DIM
(
2
,
1
);
}
}
}
// namespace phi
paddle/phi/kernels/logsumexp_grad_kernel.h
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LogsumexpGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
in
,
const
DenseTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
DenseTensor
*
in_grad
);
}
// namespace phi
paddle/phi/kernels/logsumexp_kernel.h
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LogsumexpKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int64_t
>&
axis
,
bool
keepdim
,
bool
reduce_all
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/logsumexp_sig.cc
0 → 100644
浏览文件 @
ee8eeb45
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
LogsumexpGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"logsumexp_grad"
,
{
"X"
,
"Out"
,
GradVarName
(
"Out"
)},
{
"axis"
,
"keepdim"
,
"reduce_all"
},
{
GradVarName
(
"X"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
logsumexp_grad
,
phi
::
LogsumexpGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录