Unverified commit 19a7524f, authored by tiancaishaonvjituizi and committed by GitHub

[Hackathon No.28] implement logcumsumexp (#42267)

Parent 06de4891
@@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
    AddComment(R"DOC(
The cumulative sum of the elements along a given axis.
By default, the first element of the result is the same as the first element of
the input. If exclusive is true, the first element of the result is 0.
)DOC");
  }
};

@@ -74,17 +74,87 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
  }
};
class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of logcumsumexp operator");
AddOutput("Out", "Output of logcumsumexp operator");
AddAttr<int>("axis",
"The dimension to accumulate along. -1 means the last "
"dimension [default -1].")
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the logcumsumexp over the flattened array. "
"[default false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive logcumsumexp. [default false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"If true, the logcumsumexp is performed in the reversed direction. "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
Returns the logarithm of the cumulative summation of the exponentiation of the elements of the input along the given axis.
By default, the first element of the result is the same as the first element of
the input. If exclusive is true, the first element of the result is the lowest finite value of the dtype of the output tensor.
)DOC");
}
};
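As an aside on the semantics the DOC string above describes, here is a minimal NumPy sketch of an equivalent computation (illustration only, not part of the patch; the helper name logcumsumexp_ref is made up, and np.logaddexp.accumulate is used as the numerically stable cumulative log-add-exp):

import numpy as np

def logcumsumexp_ref(x, axis=-1, exclusive=False, reverse=False):
    # Stable log(cumsum(exp(x))) along `axis`, mirroring the attributes above.
    x = np.asarray(x, dtype=np.float64)
    if reverse:
        x = np.flip(x, axis)
    out = np.logaddexp.accumulate(x, axis=axis)
    if exclusive:
        # Shift by one; the first element becomes the lowest finite value,
        # whose exp underflows to 0, so it acts as the log-add-exp identity.
        out = np.roll(out, 1, axis)
        np.moveaxis(out, axis, 0)[0] = np.finfo(x.dtype).min
    if reverse:
        out = np.flip(out, axis)
    return out

print(logcumsumexp_ref([1.0, 2.0, 3.0]))   # [1.         2.31326169 3.40760596]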
class LogcumsumexpGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logcumsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
template <typename T>
class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logcumsumexp_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("exclusive",
BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
                            PD_INFER_META(phi::CumInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor,
                            PD_INFER_META(phi::CumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
                  ops::CumsumGradMaker<paddle::framework::OpDesc>,
                  ops::CumsumGradMaker<paddle::imperative::OpBase>,
                  CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker,
                  ops::LogcumsumexpGradMaker<paddle::framework::OpDesc>,
                  ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
                  LogcumsumexpInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp);

REGISTER_OP_VERSION(cumsum).AddCheckpoint(
    R"ROC(
...
@@ -235,12 +235,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
  out->set_layout(x.layout());
}

void CumInferMeta(const MetaTensor& x,
                  int axis,
                  bool flatten,
                  bool exclusive,
                  bool reverse,
                  MetaTensor* out) {
  auto x_dims = x.dims();
  if (flatten) {
    out->set_dims(phi::make_ddim({phi::product(x_dims)}));
...
@@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumInferMeta(const MetaTensor& x,
                  int axis,
                  bool flatten,
                  bool exclusive,
                  bool reverse,
                  MetaTensor* out);

void DiagInferMeta(const MetaTensor& x,
                   int offset,
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cum_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -21,44 +21,42 @@
namespace phi {

template <typename Device,
          typename Dim,
          typename X,
          typename Out,
          typename Reducer>
void ComputeImp(Device d,
                const Dim& dims,
                X x,
                Out out,
                int axis,
                bool reverse,
                bool exclusive,
                Reducer reducer) {
  if (!reverse) {
    out.reshape(dims).device(d) =
        x.reshape(dims).scan(axis, reducer, exclusive);
  } else {
    std::array<bool, Dim::count> rev;
    rev.fill(false);
    rev[axis] = reverse;
    out.reshape(dims).device(d) = x.reshape(dims)
                                      .reverse(rev)
                                      .scan(axis, reducer, exclusive)
                                      .reverse(rev);
  }
}

template <typename T, typename Context, typename Reducer>
void ScanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                int axis,
                bool flatten,
                bool exclusive,
                bool reverse,
                Reducer reducer,
                DenseTensor* out) {
  auto out_dims = out->dims();

  PADDLE_ENFORCE_EQ(
@@ -99,7 +97,8 @@ void CumsumKernel(const Context& dev_ctx,
                 out0,
                 /* axis= */ 0,
                 reverse,
                 exclusive,
                 reducer);
    } else {
      ComputeImp(place,
                 Eigen::DSizes<IndexT, 2>(mid, post),
@@ -107,7 +106,8 @@ void CumsumKernel(const Context& dev_ctx,
                 out0,
                 /* axis= */ 0,
                 reverse,
                 exclusive,
                 reducer);
    }
  } else {
    if (post == 1) {
@@ -117,7 +117,8 @@ void CumsumKernel(const Context& dev_ctx,
                 out0,
                 /* axis= */ 1,
                 reverse,
                 exclusive,
                 reducer);
    } else {
      ComputeImp(place,
                 Eigen::DSizes<IndexT, 3>(pre, mid, post),
@@ -125,11 +126,135 @@ void CumsumKernel(const Context& dev_ctx,
                 out0,
                 /* axis= */ 1,
                 reverse,
                 exclusive,
                 reducer);
    }
  }
}
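The Eigen path above reduces a scan along an arbitrary axis to a scan over the middle dimension of a (pre, mid, post) view, which is why ComputeImp only ever sees 2-D (mid, post) or 3-D (pre, mid, post) shapes with axis 0 or 1. A small NumPy sketch of that decomposition (illustrative only; scan_along_axis is a made-up helper):

import numpy as np

def scan_along_axis(x, axis, op):
    # Collapse dims before/after `axis` into pre/post and scan over `mid`.
    x = np.asarray(x)
    axis = axis % x.ndim
    pre = int(np.prod(x.shape[:axis], dtype=np.int64))
    mid = x.shape[axis]
    post = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    view = x.reshape(pre, mid, post)
    out = op.accumulate(view, axis=1)   # scan along the middle dimension
    return out.reshape(x.shape)

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
assert np.allclose(scan_along_axis(x, 1, np.add), np.cumsum(x, axis=1))
assert np.allclose(scan_along_axis(x, 2, np.logaddexp),
                   np.logaddexp.accumulate(x, axis=2))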
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Reducer = Eigen::internal::SumReducer<T>;
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
const T& b) const {
auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
auto sub = Eigen::internal::scalar_difference_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto log1p = Eigen::internal::scalar_log1p_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
const T& b) const {
auto mi = Eigen::internal::pmin(a, b);
auto ma = Eigen::internal::pmax(a, b);
using Eigen::internal::padd;
using Eigen::internal::pcmp_lt;
using Eigen::internal::pexp;
using Eigen::internal::plog1p;
using Eigen::internal::pset1;
using Eigen::internal::psub;
auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
return pselect(
pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma, logsumexp);
}
};
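The scalar path above rests on the rewrite log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))), plus a guard that appears to handle the case where both operands are negative infinity (min - max would be NaN there). A hedged NumPy illustration of why the rewrite is used:

import numpy as np

def logaddexp_pair(a, b):
    # Same rewrite as LogSumExp::operator(): pull out the max so the exp
    # never overflows; log1p keeps precision when exp(min - max) is tiny.
    mi, ma = min(a, b), max(a, b)
    if ma == -np.inf:            # both -inf: exp(mi - ma) would be exp(nan)
        return ma
    return ma + np.log1p(np.exp(mi - ma))

print(logaddexp_pair(1000.0, 1001.0))            # ~1001.3133, no overflow
print(np.log(np.exp(1000.0) + np.exp(1001.0)))   # inf: the naive form overflows
print(logaddexp_pair(-np.inf, -np.inf))          # -inf, guarded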
template <typename T>
struct LogSumExpReducer {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp(*accum, t);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
Packet* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return Eigen::NumTraits<T>::lowest();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
return Eigen::internal::pset1(initialize());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
finalizePacket(const Packet& vaccum) const {
return vaccum;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
finalizeBoth(const T saccum, const Packet& vaccum) const {
auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
auto sum_reducer = Eigen::internal::SumReducer<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
auto log = Eigen::internal::scalar_log_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();
using Eigen::internal::pexp;
using Eigen::internal::psub;
// `ma = max(x1, ..., xn)`
// If the max of all of the `xi` is `-infinity` then the result is
// -infinity. If the max is larger than `-infinity` then it's safe to use
// for normalization even if the other elements are `-infinity`.
//
// `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
auto ma = max_reducer.finalizeBoth(saccum, vaccum);
auto logsumexp = add(log(sum_reducer.finalizeBoth(
exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
}
};
template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Reducer = LogSumExpReducer<T>;
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
}  // namespace phi

PD_REGISTER_KERNEL(cumsum,
@@ -141,3 +266,6 @@ PD_REGISTER_KERNEL(cumsum,
                   int16_t,
                   int,
                   int64_t) {}
PD_REGISTER_KERNEL(
logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h"
#include <limits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h"
PD_REGISTER_KERNEL(logcumsumexp_grad,
CPU,
ALL_LAYOUT,
phi::LogcumsumexpGradKernel,
float,
double) {}
@@ -27,4 +27,13 @@ void CumsumKernel(const Context& dev_ctx,
                  bool reverse,
                  DenseTensor* out);
template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out);
}  // namespace phi
@@ -17,7 +17,7 @@
#include <thrust/reverse.h>
#include <thrust/scan.h>

#include "paddle/phi/kernels/cum_kernel.h"

#ifdef __NVCC__
#include <cub/cub.cuh>
#endif

@@ -82,19 +82,20 @@ __global__ void MatrixRowReverse(const T* matrix_data,
  }
}
template <typename T, typename Op>
struct BlockPrefixCallbackOp {
  // Running prefix
  T running_total_;
  Op op_;

  __device__ BlockPrefixCallbackOp(T running_total, Op op)
      : running_total_(running_total), op_(op) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // tid 0 is responsible for returning a value for seeding the block-wide scan.
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total_;
    running_total_ = op_(old_prefix, block_aggregate);
    return old_prefix;
  }
};
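Generalising the running prefix from plain addition to an arbitrary Op is what lets the same block scan serve cumsum and logcumsumexp. A Python sketch of the tile-by-tile pattern the callback implements (illustrative only; tiled_inclusive_scan is a made-up helper, and -inf stands in for the lowest() sentinel the device code uses):

import numpy as np

def tiled_inclusive_scan(x, op, identity, tile=4):
    # Scan each tile locally, then carry a running prefix (seeded with the
    # op's identity) across tiles -- the role of BlockPrefixCallbackOp.
    out = np.empty_like(x)
    running = identity
    for start in range(0, len(x), tile):
        local = op.accumulate(x[start:start + tile])  # "block-wide" scan
        out[start:start + tile] = op(running, local)
        running = op(running, local[-1])              # aggregate seeds next tile
    return out

x = np.random.randn(10)
assert np.allclose(tiled_inclusive_scan(x, np.add, 0.0), np.cumsum(x))
assert np.allclose(tiled_inclusive_scan(x, np.logaddexp, -np.inf),
                   np.logaddexp.accumulate(x))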
@@ -129,13 +130,36 @@ __global__ void MatrixTranspose(T* odata,
  }
}
struct LogAddExp {
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return std::log(1 + std::exp(std::min(a, b) - std::max(a, b))) +
std::max(a, b);
}
};
template <typename T, typename op>
struct Identity;
template <typename T>
struct Identity<T, cub::Sum> {
static constexpr T value = 0;
};
template <typename T>
struct Identity<T, LogAddExp> {
static constexpr T value = std::numeric_limits<T>::lowest();
};
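Why lowest() can act as the neutral element for LogAddExp: exp(lowest - x) underflows to exactly zero for any representable x, so combining with the sentinel leaves the other operand unchanged. A quick NumPy check (illustrative only):

import numpy as np

lowest = np.finfo(np.float64).min        # what Identity<T, LogAddExp> uses
for x in (-1e300, -5.0, 0.0, 7.0, 1e300):
    assert np.logaddexp(lowest, x) == x  # exp(lowest - x) underflows to 0
print("lowest() behaves as an identity for log-add-exp")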
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
__global__ void BlockScanKernel(T* d_out,
                                const T* d_in,
                                int inner_size,
                                int outer_size,
                                int scan_size,
                                bool exclusive,
                                Op op) {
  // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
  typedef cub::
      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
@@ -154,7 +178,7 @@ __global__ void BlockScanKernel(T* d_out,
  int bx = blockIdx.x;
  int by = blockIdx.y;

  BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
  T block_aggregate = static_cast<T>(0);

  // Obtain this block's segment of consecutive keys (blocked across threads)
@@ -176,12 +200,11 @@ __global__ void BlockScanKernel(T* d_out,
    __syncthreads();

    if (exclusive) {
      BlockScanT(temp_storage.scan)
          .ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
    } else {
      BlockScanT(temp_storage.scan)
          .InclusiveScan(thread_keys, thread_keys, op, prefix_op);
    }
    __syncthreads();
@@ -190,14 +213,15 @@ __global__ void BlockScanKernel(T* d_out,
  }
}
template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
                const DenseTensor& x,
                int axis,
                bool flatten,
                bool exclusive,
                bool reverse,
                Op op,
                DenseTensor* out) {
  auto out_dims = out->dims();
  auto size = x.numel();
@@ -219,7 +243,7 @@ void CumsumKernel(const Context& dev_ctx,
  // Use thrust for parallel acceleration when the input size is equal to the
  // length of the ‘axis’ dimension.
  if (std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
#ifdef __HIPCC__
    const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
@@ -247,6 +271,7 @@ void CumsumKernel(const Context& dev_ctx,
    return;
  }

  size_t height = 1;
  size_t width = 1;
  for (size_t i = 0; i <= axis; i++) {
@@ -299,17 +324,18 @@ void CumsumKernel(const Context& dev_ctx,
    }
  }

  if (!transpose && !reverse) {
    BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
        out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);
  } else {
    BlockScanKernel<T, 128, 4, Op>
        <<<scan_grid, 128, 0, dev_ctx.stream()>>>(next_out_data,
                                                  next_in_data,
                                                  outer_size,
                                                  inner_size,
                                                  scan_size,
                                                  exclusive,
                                                  op);
  }
  swap_ptr(next_in_data, next_out_data);
  if (reverse) {
@@ -325,6 +351,34 @@ void CumsumKernel(const Context& dev_ctx,
  }
}
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Op = cub::Sum;
auto op = Op();
ScanKernel<T, Context, Op>(
dev_ctx, x, axis, flatten, exclusive, reverse, op, out);
}
template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Op = LogAddExp;
auto op = Op();
ScanKernel<T, Context, Op>(
dev_ctx, x, axis, flatten, exclusive, reverse, op, out);
}
}  // namespace phi

PD_REGISTER_KERNEL(cumsum,
@@ -336,3 +390,10 @@ PD_REGISTER_KERNEL(cumsum,
                   int16_t,
                   int,
                   int64_t) {}
PD_REGISTER_KERNEL(logcumsumexp,
GPU,
ALL_LAYOUT,
phi::LogcumsumexpKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <limits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h"
#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h"
PD_REGISTER_KERNEL(logcumsumexp_grad,
GPU,
ALL_LAYOUT,
phi::LogcumsumexpGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <limits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
struct LogGradPositiveFunctor {
HOSTDEVICE T operator()(const T& x) const {
const T kMin = std::numeric_limits<T>::lowest();
return x > 0 ? std::log(x) : kMin;
}
};
template <typename T>
struct LogGradNegativeFunctor {
HOSTDEVICE T operator()(const T& x) const {
const T kMin = std::numeric_limits<T>::lowest();
return x < 0 ? std::log(-x) : kMin;
}
};
template <typename T, typename Context>
void LogcumsumexpGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& d_out,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* d_x) {
reverse = !reverse;
dev_ctx.template Alloc<T>(d_x);
auto eigen_x = EigenVector<T>::Flatten(x);
auto eigen_out = EigenVector<T>::Flatten(out);
auto eigen_d_out = EigenVector<T>::Flatten(d_out);
auto& place = *dev_ctx.eigen_device();
DenseTensor output_pos;
output_pos.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&output_pos);
auto eigen_output_pos = EigenVector<T>::Flatten(output_pos);
DenseTensor output_neg;
output_neg.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&output_neg);
auto eigen_output_neg = EigenVector<T>::Flatten(output_neg);
DenseTensor tmp;
tmp.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&tmp);
auto eigen_tmp = EigenVector<T>::Flatten(tmp);
eigen_tmp.device(place) =
eigen_d_out.unaryExpr(LogGradPositiveFunctor<T>()) - eigen_out;
LogcumsumexpKernel<T, Context>(
dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos);
eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp();
eigen_tmp.device(place) =
eigen_d_out.unaryExpr(LogGradNegativeFunctor<T>()) - eigen_out;
LogcumsumexpKernel<T, Context>(
dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg);
eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp();
auto eigen_d_x = EigenVector<T>::Flatten(*d_x);
eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg;
}
} // namespace phi
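For reference, a sketch of the derivation that the positive/negative split above appears to follow (my notation, forward inclusive case). With

    y_j = \log \sum_{i \le j} e^{x_i}, \qquad
    \frac{\partial y_j}{\partial x_i} = e^{x_i - y_j} \quad (i \le j),

the chain rule gives

    \frac{\partial L}{\partial x_i}
      = \sum_{j \ge i} \frac{\partial L}{\partial y_j}\, e^{x_i - y_j}
      = e^{x_i} \sum_{j \ge i} dy_j\, e^{-y_j}.

Splitting dy into its positive part dy^{+} and negative part dy^{-} keeps each sum non-negative so it can be evaluated in log space:

    \frac{\partial L}{\partial x_i}
      = \exp\big(x_i + \operatorname{logcumsumexp}_{\mathrm{rev}}(\log dy^{+} - y)_i\big)
      - \exp\big(x_i + \operatorname{logcumsumexp}_{\mathrm{rev}}(\log dy^{-} - y)_i\big),

where the reversed scan accumulates over j >= i (hence reverse = !reverse in the kernel) and the log of a zero part is replaced by the lowest finite value, which is what LogGradPositiveFunctor / LogGradNegativeFunctor do.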
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogcumsumexpGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& d_out,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* d_x);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LogcumsumexpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logcumsumexp",
{"X"},
{"axis", "flatten", "exclusive", "reverse"},
{"Out"});
}
KernelSignature LogcumsumexpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logcumsumexp_grad",
{"X", "Out", "Out@GRAD"},
{"axis", "flatten", "exclusive", "reverse"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad,
phi::LogcumsumexpGradOpArgumentMapping);
@@ -193,6 +193,7 @@ from .tensor.math import tan # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
@@ -407,6 +408,7 @@ __all__ = [ # noqa
    'eye',
    'cumsum',
    'cumprod',
    'logcumsumexp',
    'logit',
    'sign',
    'is_empty',
...
@@ -684,6 +684,7 @@ endif()
foreach(TEST_OP ${TEST_OPS})
  py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
set_tests_properties(test_logcumsumexp_op PROPERTIES TIMEOUT 30)
py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS
                FLAGS_inner_op_parallelism=4)
if(WITH_GPU
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from typing import Optional
import unittest
import itertools
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
from op_test import OpTest
def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int] = None):
return np.log(np.cumsum(np.exp(x), axis=axis))
def np_logcumsumexp(x: np.ndarray,
axis: Optional[int] = None,
flatten: Optional[bool] = None,
reverse: bool = False,
exclusive: bool = False):
# `flatten` aligns with c++ op
if flatten:
assert axis in [0, None]
axis = None
x = np.copy(x)
if axis is None:
x = x.flatten()
axis = 0
if reverse:
x = np.flip(x, axis)
dimensions = [range(dim) for dim in x.shape[:axis]]
if exclusive:
x = np.roll(x, 1, axis)
for prefix_dim in itertools.product(*dimensions):
x[prefix_dim][0] = np.finfo(x.dtype).min
for prefix_dim in itertools.product(*dimensions):
arr = x[prefix_dim]
for dim in range(1, arr.shape[0]):
arr[dim] = np.logaddexp(arr[dim - 1], arr[dim])
if reverse:
x = np.flip(x, axis)
return x
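As an aside, the reference above can be sanity-checked against NumPy's built-in stable accumulation; this snippet is an editor's illustration, not part of the test file:

import numpy as np

# np.logaddexp.accumulate computes log(cumsum(exp(x))) without overflow and
# should agree with np_logcumsumexp in the plain inclusive, forward case.
x = np.random.randn(3, 4)
assert np.allclose(np_logcumsumexp(x, axis=1),
                   np.logaddexp.accumulate(x, axis=1))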
def np_logcumsumexp_grad(
x: np.ndarray,
dout: np.ndarray,
axis: Optional[int] = None,
flatten: Optional[bool] = None,
reverse: bool = False,
exclusive: bool = False,
):
out = np_logcumsumexp(x, axis, flatten, reverse, exclusive)
log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min)
log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min)
output_pos = np.exp(
np_logcumsumexp(log_grad_positive - out,
axis=axis,
flatten=flatten,
reverse=not reverse,
exclusive=exclusive).reshape(x.shape) + x)
output_neg = np.exp(
np_logcumsumexp(log_grad_negative - out,
axis=axis,
flatten=flatten,
reverse=not reverse,
exclusive=exclusive).reshape(x.shape) + x)
return output_pos - output_neg
class TestLogcumsumexp(unittest.TestCase):
def run_imperative(self):
data_np = np.arange(12, dtype=np.float32).reshape(3, 4)
data = paddle.to_tensor(data_np)
y = paddle.logcumsumexp(data)
z = np_logcumsumexp(data_np)
self.assertTrue(np.allclose(z, y.numpy()))
y = paddle.logcumsumexp(data, axis=0)
z = np_logcumsumexp(data_np, axis=0)
self.assertTrue(np.allclose(z, y.numpy()))
y = paddle.logcumsumexp(data, axis=-1)
z = np_logcumsumexp(data_np, axis=-1)
self.assertTrue(np.allclose(z, y.numpy()))
y = paddle.logcumsumexp(data, dtype='float32')
self.assertTrue(y.dtype == core.VarDesc.VarType.FP32)
y = paddle.logcumsumexp(data, axis=-2)
z = np_logcumsumexp(data_np, axis=-2)
self.assertTrue(np.allclose(z, y.numpy()))
with self.assertRaises(IndexError):
y = paddle.logcumsumexp(data, axis=-3)
with self.assertRaises(IndexError):
y = paddle.logcumsumexp(data, axis=2)
data_np = np.arange(10000, 10024, dtype=np.float32)
data = paddle.to_tensor(data_np)
y = paddle.logcumsumexp(data)
z = np_naive_logcumsumexp(data_np)
# check that naive algorithm overflows
self.assertTrue(all(z == np.inf))
z = np_logcumsumexp(data_np)
# check that our algorithm doesn't overflow
self.assertTrue(all(z != np.inf))
self.assertTrue(np.allclose(z, y.numpy()))
def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((5, 4)).astype(np.float32)
x = paddle.static.data('X', [5, 4])
y = paddle.logcumsumexp(x)
y2 = paddle.logcumsumexp(x, axis=0)
y3 = paddle.logcumsumexp(x, axis=-1)
y4 = paddle.logcumsumexp(x, dtype='float64')
y5 = paddle.logcumsumexp(x, axis=-2)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out = exe.run(feed={'X': data_np},
fetch_list=[
y.name,
y2.name,
y3.name,
y4.name,
y5.name,
])
z = np_logcumsumexp(data_np)
self.assertTrue(np.allclose(z, out[0]))
z = np_logcumsumexp(data_np, axis=0)
self.assertTrue(np.allclose(z, out[1]))
z = np_logcumsumexp(data_np, axis=-1)
self.assertTrue(np.allclose(z, out[2]))
self.assertTrue(out[3].dtype == np.float64)
z = np_logcumsumexp(data_np, axis=-2)
self.assertTrue(np.allclose(z, out[4]))
def test_cpu(self):
paddle.disable_static(paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()
self.run_static(use_gpu=True)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = paddle.static.data('x', [3, 4])
y = paddle.logcumsumexp(x, name='out')
self.assertTrue('out' in y.name)
def test_type_error(self):
with fluid.program_guard(fluid.Program()):
with self.assertRaises(TypeError):
data_np = np.random.random((100, 100), dtype=np.int32)
x = paddle.static.data('X', [100, 100], dtype='int32')
y = paddle.logcumsumexp(x)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out = exe.run(feed={'X': data_np}, fetch_list=[y.name])
class BaseTestCases:
class BaseOpTest(OpTest):
def setUp(self):
self.op_type = "logcumsumexp"
input, attrs = self.input_and_attrs()
self.inputs = {'X': input}
self.attrs = attrs
self.outputs = {'Out': np_logcumsumexp(input, **attrs)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'],
'Out',
user_defined_grads=[
np_logcumsumexp_grad(self.inputs['X'],
1 / self.inputs['X'].size,
**self.attrs)
])
def input_and_attrs(self):
raise NotImplementedError()
class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.arange(100, dtype=np.float64).reshape(10, 10), {
'axis': 0,
'flatten': True,
'reverse': True
}
class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.arange(100, dtype=np.float64).reshape(10, 10), {
'axis': 1,
'reverse': True
}
class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.arange(100, dtype=np.float64).reshape(10, 10), {'axis': 1}
class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.arange(100, dtype=np.float64).reshape(10, 10), {
'axis': 0,
'flatten': True,
'reverse': True,
'exclusive': True
}
if __name__ == '__main__':
unittest.main()
@@ -139,6 +139,7 @@ from .math import tan # noqa: F401
from .math import cosh # noqa: F401
from .math import cumsum # noqa: F401
from .math import cumprod # noqa: F401
from .math import logcumsumexp # noqa: F401
from .math import logit # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
@@ -310,6 +311,7 @@ tensor_method_func = [ #noqa
    'cosh',
    'cumsum',
    'cumprod',
    'logcumsumexp',
    'logit',
    'exp',
    'exp_',
...
@@ -2909,7 +2909,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
    The cumulative sum of the elements along a given axis.

    **Note**:
        The first element of the result is the same as the first element of the input.

    Args:
        x (Tensor): The input tensor needed to be cumsumed.
@@ -2970,6 +2970,79 @@ def cumsum(x, axis=None, dtype=None, name=None):
    _cum_sum_ = generate_layer_fn('cumsum')
    return _cum_sum_(**kwargs)
def logcumsumexp(x, axis=None, dtype=None, name=None):
r"""
The logarithm of the cumulative summation of the exponentiation of the elements along a given axis.
For summation index j given by `axis` and other indices i, the result is

.. math::

    logcumsumexp(x)_{ij} = log \sum_{k=0}^{j} exp(x_{ik})
Note:
The first element of the result is the same as the first element of the input.
Args:
x (Tensor): The input tensor.
axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the logcumsumexp over the flattened array.
dtype (str, optional): The data type of the output tensor, can be float32 or float64. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the result of logcumsumexp operator.
Examples:
.. code-block:: python
import paddle
data = paddle.arange(12, dtype='float64')
data = paddle.reshape(data, (3, 4))
y = paddle.logcumsumexp(data)
# [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934
# 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ]
y = paddle.logcumsumexp(data, axis=0)
# [[ 0. 1. 2. 3. ]
# [ 4.01815 5.01815 6.01815 7.01815 ]
# [ 8.018479 9.018479 10.018479 11.018479]]
y = paddle.logcumsumexp(data, axis=-1)
# [[ 0. 1.3132617 2.4076061 3.4401898]
# [ 4. 5.3132615 6.407606 7.44019 ]
# [ 8. 9.313262 10.407606 11.440189 ]]
y = paddle.logcumsumexp(data, dtype='float64')
print(y.dtype)
# paddle.float64
"""
if axis is None:
flatten = True
else:
flatten = False
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
if in_dygraph_mode():
if axis is None: axis = -1
return _C_ops.final_state_logcumsumexp(x, axis, flatten, False, False)
if _in_legacy_dygraph():
if axis is None:
return _C_ops.logcumsumexp(x, 'flatten', flatten)
else:
return _C_ops.logcumsumexp(x, 'axis', axis, 'flatten', flatten)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp")
helper = LayerHelper('logcumsumexp', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='logcumsumexp', inputs={'X': x}, outputs={'Out': out}, attrs={'axis': axis, 'flatten': flatten})
return out
def cumprod(x, dim=None, dtype=None, name=None):
    """
    Compute the cumulative product of the input tensor x along a given dimension dim.
...
@@ -482,7 +482,7 @@
  args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse)
  output : Tensor(out)
  infer_meta :
    func : CumInferMeta
  kernel :
    func : cumsum
  backward : cumsum_grad
@@ -1259,6 +1259,15 @@
    func : log_softmax
  backward : log_softmax_grad
- api : logcumsumexp
args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse)
output : Tensor(out)
infer_meta :
func : CumInferMeta
kernel :
func : logcumsumexp
backward : logcumsumexp_grad
# logical_and
- api : logical_and
  args : (Tensor x, Tensor y)
...
@@ -1137,6 +1137,16 @@
  kernel :
    func : log_softmax_grad
- backward_api : logcumsumexp_grad
forward : logcumsumexp(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
args : (Tensor x, Tensor out, Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse)
output : Tensor(x_grad)
kernel :
func : logcumsumexp_grad
- backward_api : logit_grad
  forward : logit (Tensor x, float eps = 1e-6f) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, float eps)
...