未验证 提交 ee8eeb45 编写于 作者: C Chen Weihang 提交者: GitHub

Revert "Revert "[Phi] trans logsumexp op (#40790)" (#41068)" (#41109)

This reverts commit 054fc997.
上级 91bb52cd
......@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -23,80 +26,6 @@ namespace operators {
class LogsumexpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 4,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank, x_dims));
auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
PADDLE_ENFORCE_GT(
axis.size(), 0,
platform::errors::InvalidArgument(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d.",
axis.size()));
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_LT(axis[i], x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i, x_rank, i, axis[i]));
PADDLE_ENFORCE_GE(axis[i], -x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i, x_rank, i, axis[i]));
if (axis[i] < 0) {
axis[i] += x_rank;
}
}
bool keepdim = ctx->Attrs().Get<bool>("keepdim");
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
auto dims_vector = vectorize(x_dims);
if (reduce_all) {
if (keepdim)
ctx->SetOutputDim("Out",
phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
ctx->SetOutputDim("Out", {1});
} else {
auto dims_vector = vectorize(x_dims);
if (keepdim) {
for (size_t i = 0; i < axis.size(); ++i) {
dims_vector[axis[i]] = 1;
}
} else {
const int kDelFlag = -1;
for (size_t i = 0; i < axis.size(); ++i) {
dims_vector[axis[i]] = kDelFlag;
}
dims_vector.erase(
std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keepdim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = phi::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (axis.size() > 0 && axis[0] != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
};
class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -164,16 +93,10 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(logsumexp, LogsumexpInferShapeFunctor,
PD_INFER_META(phi::LogsumexpInferMeta));
REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker,
ops::LogsumexpGradOpMaker<paddle::framework::OpDesc>,
ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>);
ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>,
LogsumexpInferShapeFunctor);
REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp);
REGISTER_OP_CPU_KERNEL(
logsumexp, ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
logsumexp_grad,
ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
namespace paddle {
namespace operators {
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
paddle::operators::ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, \
LogsumexpFunctor>( \
context.template device_context<DeviceContext>(), *input, output, \
axis, keepdim); \
}
struct LogsumexpFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
auto x_dim = x->dimensions();
auto t_dim = x_dim;
for (int i = 0; i < static_cast<int>(dim.size()); i++) {
t_dim[dim[i]] = 1;
}
auto r_dim = x_dim;
for (int i = 0; i < static_cast<int>(r_dim.size()); i++) {
r_dim[i] = 1;
}
for (int i = 0; i < static_cast<int>(dim.size()); i++) {
r_dim[dim[i]] = x_dim[dim[i]];
}
auto y_dim = y->dimensions();
auto x_max = x->maximum(dim);
y->device(place) =
(x_max +
(*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
.reshape(y_dim);
}
};
struct LogsumexpGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp();
}
};
template <typename DeviceContext, typename OutT>
class LogsumexpKernel : public framework::OpKernel<OutT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<OutT>(context.GetPlace());
auto axis = context.Attr<std::vector<int>>("axis");
auto keepdim = context.Attr<bool>("keepdim");
auto reduce_all = context.Attr<bool>("reduce_all");
const auto& input_dim_size = input->dims().size();
// The dims has full dim, set the reduce_all is True
reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto x = EigenVector<OutT>::Flatten(*input);
auto out = EigenScalar<OutT>::From(*output);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor()(place, &x, &out, reduce_dim);
} else {
int ndim = input_dim_size;
int rdim = axis.size();
// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
// HANDLE_DIM(6, 3);
// HANDLE_DIM(6, 2);
// HANDLE_DIM(6, 1);
// HANDLE_DIM(5, 4);
// HANDLE_DIM(5, 3);
// HANDLE_DIM(5, 2);
// HANDLE_DIM(5, 1);
HANDLE_DIM(4, 3);
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
}
}
};
template <typename DeviceContext, typename T>
class LogsumexpGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
auto* output = context.Input<Tensor>("Out");
auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* input_grad = context.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>(context.GetPlace());
auto axis = context.Attr<std::vector<int>>("axis");
auto reduce_all = context.Attr<bool>("reduce_all");
const auto input_dim_size = context.Input<Tensor>("X")->dims().size();
reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
if (reduce_all) {
auto x = EigenVector<T>::Flatten(*input);
auto y = EigenVector<T>::Flatten(*output);
auto dy = EigenVector<T>::Flatten(*output_grad);
auto dx = EigenVector<T>::Flatten(*input_grad);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto broadcast_dim =
Eigen::array<int, 1>({{static_cast<int>(input->numel())}});
LogsumexpGradFunctor()(place, &x, &y, &dx, &dy, broadcast_dim,
broadcast_dim[0]);
} else {
int rank = input->dims().size();
LogsumexpGradFunctor functor;
switch (rank) {
case 1:
ReduceGradFunctor<DeviceContext, T, 1, LogsumexpGradFunctor>(
context.template device_context<DeviceContext>(), *input, *output,
*output_grad, input_grad, functor, axis);
break;
case 2:
ReduceGradFunctor<DeviceContext, T, 2, LogsumexpGradFunctor>(
context.template device_context<DeviceContext>(), *input, *output,
*output_grad, input_grad, functor, axis);
break;
case 3:
ReduceGradFunctor<DeviceContext, T, 3, LogsumexpGradFunctor>(
context.template device_context<DeviceContext>(), *input, *output,
*output_grad, input_grad, functor, axis);
break;
case 4:
ReduceGradFunctor<DeviceContext, T, 4, LogsumexpGradFunctor>(
context.template device_context<DeviceContext>(), *input, *output,
*output_grad, input_grad, functor, axis);
break;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -14,7 +14,7 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device_context.h"
......
......@@ -804,6 +804,91 @@ void KthvalueInferMeta(const MetaTensor& x,
indices->set_dtype(x.dtype());
}
void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis,
bool keepdim,
bool reduce_all,
MetaTensor* out) {
auto x_dims = input.dims();
auto x_rank = x_dims.size();
std::vector<int64_t> formated_axis = axis;
PADDLE_ENFORCE_LE(x_rank,
4,
errors::InvalidArgument(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank,
x_dims));
PADDLE_ENFORCE_GT(
axis.size(),
0,
errors::InvalidArgument(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d.",
axis.size()));
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] += x_rank;
}
}
auto dims_vector = vectorize(x_dims);
if (reduce_all) {
if (keepdim)
out->set_dims(phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
out->set_dims({1});
} else {
auto dims_vector = vectorize(x_dims);
if (keepdim) {
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = 1;
}
} else {
const int kDelFlag = -1;
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = kDelFlag;
}
dims_vector.erase(
std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keepdim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = phi::make_ddim(dims_vector);
out->set_dims(out_dims);
if (formated_axis.size() > 0 && formated_axis[0] != 0) {
// Only pass LoD when not reducing on the first dim.
out->share_lod(input);
}
}
out->set_dtype(input.dtype());
}
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
auto dims = x.dims();
auto n_dim = dims.size();
......
......@@ -136,6 +136,12 @@ void KthvalueInferMeta(const MetaTensor& x,
MetaTensor* indices,
MetaConfig = MetaConfig());
void LogsumexpInferMeta(const MetaTensor& input,
const std::vector<int64_t>& axis,
bool keepdim,
bool reduce_all,
MetaTensor* out);
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MaxOutInferMeta(const MetaTensor& x,
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// .part used to speed up nvcc compile
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
namespace ops = paddle::operators;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL(
logsumexp_grad,
ops::LogsumexpGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogsumexpGradKernel<paddle::platform::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(
logsumexp_grad, CPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
#include "paddle/phi/kernels/logsumexp_kernel.h"
namespace ops = paddle::operators;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
REGISTER_OP_CUDA_KERNEL(
logsumexp, ops::LogsumexpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogsumexpKernel<paddle::platform::CUDADeviceContext, double>);
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
PD_REGISTER_KERNEL(
logsumexp, CPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logsumexp_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
PD_REGISTER_KERNEL(
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"
namespace phi {
struct LogsumexpGradFunctor {
template <typename Context,
typename X,
typename Y,
typename DX,
typename DY,
typename Dim>
void operator()(const Context& place,
X* x,
Y* y,
DX* dx,
DY* dy,
const Dim& dim,
int size) {
dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp();
}
};
template <typename T, typename Context>
void LogsumexpGradKernel(const Context& dev_ctx,
const DenseTensor& in,
const DenseTensor& out,
const DenseTensor& out_grad,
const std::vector<int>& axis,
bool keepdim,
bool reduce_all,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);
const auto input_dim_size = in.dims().size();
reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
if (reduce_all) {
auto x = phi::EigenVector<T>::Flatten(in);
auto y = phi::EigenVector<T>::Flatten(out);
auto dy = phi::EigenVector<T>::Flatten(out_grad);
auto dx = phi::EigenVector<T>::Flatten(*in_grad);
auto& place = *dev_ctx.eigen_device();
auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}});
LogsumexpGradFunctor()(
place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]);
} else {
int rank = in.dims().size();
LogsumexpGradFunctor functor;
switch (rank) {
case 1:
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>(
dev_ctx, in, out, out_grad, in_grad, functor, axis);
break;
case 2:
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>(
dev_ctx, in, out, out_grad, in_grad, functor, axis);
break;
case 3:
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>(
dev_ctx, in, out, out_grad, in_grad, functor, axis);
break;
case 4:
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>(
dev_ctx, in, out, out_grad, in_grad, functor, axis);
break;
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/logsumexp_kernel.h"
namespace phi {
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \
dev_ctx, x, out, axis, keepdim); \
}
struct LogsumexpFunctor {
template <typename Context, typename X, typename Y, typename Dim>
void operator()(const Context& place, X* x, Y* y, const Dim& dim) {
auto x_dim = x->dimensions();
auto t_dim = x_dim;
for (int i = 0; i < static_cast<int>(dim.size()); i++) {
t_dim[dim[i]] = 1;
}
auto r_dim = x_dim;
for (int i = 0; i < static_cast<int>(r_dim.size()); i++) {
r_dim[i] = 1;
}
for (int i = 0; i < static_cast<int>(dim.size()); i++) {
r_dim[dim[i]] = x_dim[dim[i]];
}
auto y_dim = y->dimensions();
auto x_max = x->maximum(dim);
y->device(place) =
(x_max +
(*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
.reshape(y_dim);
}
};
template <typename T, typename Context>
void LogsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
bool keepdim,
bool reduce_all,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const auto& input_dim_size = x.dims().size();
// The dims has full dim, set the reduce_all is True
reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto input = phi::EigenVector<T>::Flatten(x);
auto output = phi::EigenScalar<T>::From(*out);
auto& place = *dev_ctx.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor()(place, &input, &output, reduce_dim);
} else {
int ndim = input_dim_size;
int rdim = axis.size();
// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
// HANDLE_DIM(6, 3);
// HANDLE_DIM(6, 2);
// HANDLE_DIM(6, 1);
// HANDLE_DIM(5, 4);
// HANDLE_DIM(5, 3);
// HANDLE_DIM(5, 2);
// HANDLE_DIM(5, 1);
HANDLE_DIM(4, 3);
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogsumexpGradKernel(const Context& ctx,
const DenseTensor& in,
const DenseTensor& out,
const DenseTensor& out_grad,
const std::vector<int>& axis,
bool keepdim,
bool reduce_all,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogsumexpKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int64_t>& axis,
bool keepdim,
bool reduce_all,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LogsumexpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logsumexp_grad",
{"X", "Out", GradVarName("Out")},
{"axis", "keepdim", "reduce_all"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(logsumexp_grad, phi::LogsumexpGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册