Unverified commit 3a3fb1fe, authored by NetPunk, committed by GitHub

【Hackathon 4 No.17】Add cummax / cummin API to Paddle (#53546)

Parent 1bcf437a
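This commit adds the paddle.cummax and paddle.cummin APIs, together with their op definitions, infer-meta functions, CPU/GPU forward and backward kernels, and unit tests. A minimal usage sketch, based on the docstring examples added in this commit:

import paddle

data = paddle.to_tensor([-1, 5, 0, -2, -3, 2])
data = paddle.reshape(data, (2, 3))

# Cumulative max along the last axis; also returns the argmax indices.
value, indices = paddle.cummax(data, axis=-1)
# value:   [[-1, 5, 5], [-2, -2, 2]]
# indices: [[0, 1, 1], [0, 0, 2]]

# Cumulative min along axis 0.
value, indices = paddle.cummin(data, axis=0)
# value:   [[-1, 5, 0], [-2, -3, 0]]
# indices: [[0, 0, 0], [1, 1, 0]]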
......@@ -463,6 +463,26 @@
func : cross_grad
data_type : out_grad
- backward_op : cummax_grad
forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummax_grad
- backward_op : cummin_grad
forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummin_grad
- backward_op : cumprod_grad
forward : cumprod (Tensor x, int dim) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int dim)
......
......@@ -552,6 +552,24 @@
data_type : input
backward : cross_entropy_with_softmax_grad
- op : cummax
args : (Tensor x, int axis=-1, int dtype=3)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummax
backward : cummax_grad
- op : cummin
args : (Tensor x, int axis=-1, int dtype=3)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummin
backward : cummin_grad
- op : cumprod
args : (Tensor x, int dim)
output : Tensor(out)
......
......@@ -506,6 +506,69 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
CumInferMeta(x, axis.to<int>(), flatten, exclusive, reverse, out);
}
void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
MetaTensor* out,
MetaTensor* indices) {
auto x_dims = x.dims();
auto indices_type = phi::TransToPhiDataType(dtype);
PADDLE_ENFORCE_EQ(
(indices_type == DataType::INT32 || indices_type == DataType::INT64),
true,
phi::errors::InvalidArgument("dtype of indices must be int32 or int64"));
if (indices_type == DataType::INT32) {
int _axis;
if (axis < 0) {
_axis = axis + x_dims.size();
} else {
_axis = axis;
}
PADDLE_ENFORCE_LT(
phi::vectorize(x_dims)[_axis],
INT32_MAX,
phi::errors::OutOfRange(
"cummax with axis %ld may be overflow, set dtype int64 to continue",
axis));
}
if (x_dims.size() > 0) {
PADDLE_ENFORCE_GE(
axis,
-x_dims.size(),
phi::errors::OutOfRange(
"axis is out of range (expected to be in range of [%ld, "
"%ld), but got %ld).",
-(x_dims.size()),
x_dims.size(),
axis));
PADDLE_ENFORCE_LT(
axis,
x_dims.size(),
phi::errors::OutOfRange(
"axis is out of range (expected to be in range of [%ld, "
"%ld), but got %ld).",
-(x_dims.size()),
x_dims.size(),
axis));
} else {
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
errors::InvalidArgument("The axis must be -1 or 0 for a 0D Tensor, "
"but the value given is %d.",
axis));
}
out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
indices->set_dims(x_dims);
indices->set_dtype(indices_type);
indices->share_lod(x);
}
void CropInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
......
......@@ -120,6 +120,12 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);
void CumWithIndicesInferMeta(const MetaTensor& x,
int axis,
int dtype,
MetaTensor* out,
MetaTensor* indices);
void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
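// Route each out_grad element back to the input position recorded in
// `indices` (the element that produced the running max); scatter-add
// accumulates when one input element wins for several output positions.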
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}
template <typename T, typename Context>
void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(cummax_grad,
CPU,
ALL_LAYOUT,
phi::CummaxGradKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(cummin_grad,
CPU,
ALL_LAYOUT,
phi::CumminGradKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_maxmin_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
#ifdef _MSC_VER
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type isnan_(T x) {
return false;
}
template <typename T>
typename std::enable_if<!std::is_integral<T>::value, bool>::type isnan_(T x) {
return std::isnan(x);
}
#else
template <typename T>
bool isnan_(T x) {
return std::isnan(x);
}
#endif
template <typename T>
T compute_stride(T axis, phi::DDim dims) {
T size = 1;
for (T i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T1, typename T2, typename BinaryFunction>
void ComputeImp(const DenseTensor& x,
DenseTensor* out,
DenseTensor* indices,
int64_t axis) {
int ndims = x.dims().size();
int finished = 0;
std::vector<int64_t> counter(ndims, 0);
const T1* x_data = x.data<T1>();
T1* values_data = out->data<T1>();
T2* indices_data = indices->data<T2>();
int64_t x_stride = compute_stride<int64_t>(axis, x.dims());
int64_t values_stride = compute_stride<int64_t>(axis, out->dims());
int64_t indices_stride = compute_stride<int64_t>(axis, indices->dims());
auto x_dim_vec = phi::vectorize<int>(x.dims());
int x_dim_size = x_dim_vec[axis];
BinaryFunction op;
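// Odometer-style traversal over all 1-D slices along `axis`: the inner loop
// computes the running extremum and its index for one slice, then `counter`
// is advanced over the remaining dimensions to move the data pointers to the
// next slice.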
while (!finished) {
T1 max = *reinterpret_cast<const T1*>(x_data);
int idx = 0;
for (int i = 0; i < x_dim_size; i++) {
T1 curr_elem = *reinterpret_cast<const T1*>(&x_data[i * x_stride]);
if (isnan_(curr_elem) || (!isnan_(max) && op(curr_elem, max))) {
max = curr_elem;
idx = i;
}
values_data[i * values_stride] = max;
indices_data[i * indices_stride] = idx;
}
if (ndims == 1) break;
for (int dim_i = 0; dim_i < ndims; dim_i++) {
if (dim_i == axis) {
if (dim_i == (ndims - 1)) {
finished = 1;
break;
}
continue;
}
int64_t x_stride_ = compute_stride<int64_t>(dim_i, x.dims());
int64_t values_stride_ = compute_stride<int64_t>(dim_i, out->dims());
int64_t indices_stride_ = compute_stride<int64_t>(dim_i, indices->dims());
counter[dim_i]++;
x_data += x_stride_;
values_data += values_stride_;
indices_data += indices_stride_;
if (counter[dim_i] == x_dim_vec[dim_i]) {
if (dim_i == ndims - 1) {
finished = 1;
break;
} else {
x_data -= counter[dim_i] * x_stride_;
values_data -= counter[dim_i] * values_stride_;
indices_data -= counter[dim_i] * indices_stride_;
counter[dim_i] = 0;
}
} else {
break;
}
}
}
}
template <typename T1, typename T2, typename BinaryFunction, typename Context>
void ScanWithIndicesKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out,
DenseTensor* indices) {
dev_ctx.template Alloc<T1>(out);
dev_ctx.template Alloc<T2>(indices);
// For 0D Tensor
if (x.numel() == 1) {
auto raw_dims = out->dims();
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
phi::funcs::SetConstant<Context, T2> set_zero;
set_zero(dev_ctx, indices, static_cast<T2>(0.0));
out->Resize(raw_dims);
indices->Resize(raw_dims);
return;
}
auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(),
out_dims.size() - 1,
axis));
if (axis < 0) {
axis = axis + out_dims.size();
}
ComputeImp<T1, T2, BinaryFunction>(x, out, indices, axis);
}
template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
}
template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
} else if (indices_type == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, out, indices);
}
}
} // namespace phi
PD_REGISTER_KERNEL(cummax,
CPU,
ALL_LAYOUT,
phi::CummaxKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(cummin,
CPU,
ALL_LAYOUT,
phi::CumminKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad);
template <typename T, typename Context>
void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices);
template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}
template <typename T, typename Context>
void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::gpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(cummax_grad,
GPU,
ALL_LAYOUT,
phi::CummaxGradKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(cummin_grad,
GPU,
ALL_LAYOUT,
phi::CumminGradKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cum_maxmin_kernel.h"
#include <numeric>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <
typename T1,
typename T2,
typename BinaryOperation,
typename std::enable_if<std::is_floating_point<T1>::value, int>::type = 0>
__device__ void binary_op_update(const T1 lhs,
T1* rhs,
const T2 lhs_idx,
T2* rhs_idx,
BinaryOperation binary_op) {
if (!isnan(*rhs) && (isnan(lhs) || !binary_op(*rhs, lhs))) {
*rhs = lhs;
*rhs_idx = lhs_idx;
}
}
template <typename T1,
typename T2,
typename BinaryOperation,
typename std::enable_if<std::is_integral<T1>::value, int>::type = 0>
__device__ void binary_op_update(const T1 lhs,
T1* rhs,
const T2 lhs_idx,
T2* rhs_idx,
BinaryOperation binary_op) {
if (!binary_op(*rhs, lhs)) {
*rhs = lhs;
*rhs_idx = lhs_idx;
}
}
template <
typename T1,
typename T2,
typename BinaryOperation,
typename std::enable_if<std::is_floating_point<T1>::value, int>::type = 0>
__device__ void binary_op_update_v(const T1 lhs,
T1* rhs,
const T2 lhs_idx,
T2* rhs_idx,
BinaryOperation binary_op) {
if (isnan(lhs) || (!isnan(*rhs) && binary_op(lhs, *rhs))) {
*rhs = lhs;
*rhs_idx = lhs_idx;
}
}
template <typename T1,
typename T2,
typename BinaryOperation,
typename std::enable_if<std::is_integral<T1>::value, int>::type = 0>
__device__ void binary_op_update_v(const T1 lhs,
T1* rhs,
const T2 lhs_idx,
T2* rhs_idx,
BinaryOperation binary_op) {
if (binary_op(lhs, *rhs)) {
*rhs = lhs;
*rhs_idx = lhs_idx;
}
}
template <typename T1,
typename T2,
int num_threads_x,
int num_threads_y,
class BinaryFunction>
__global__ void KernelScanInnerWithIndices(const T1* x_data,
T1* values_data,
T2* indices_data,
int num_rows,
int row_size,
T1 init,
BinaryFunction binary_op) {
__shared__ T1 vbuf[num_threads_y][2 * num_threads_x];
__shared__ T2 ibuf[num_threads_y][2 * num_threads_x];
T1* row_buf = vbuf[threadIdx.y];
T2* row_idx_buf = ibuf[threadIdx.y];
for (int block_row = blockIdx.x * blockDim.y; block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
int row = block_row + threadIdx.y;
const T1* row_self = x_data + row * row_size;
T1* row_values = values_data + row * row_size;
T2* row_indices = indices_data + row * row_size;
T1 block_total = init;
T2 block_idx_final = 0;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (int block_col = 0; block_col < row_size;
block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
int col1 = block_col + threadIdx.x;
int col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = *reinterpret_cast<const T1*>(&row_self[col1]);
row_idx_buf[threadIdx.x] = col1;
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[num_threads_x + threadIdx.x] =
*reinterpret_cast<const T1*>(&row_self[col2]);
row_idx_buf[num_threads_x + threadIdx.x] = col2;
} else {
row_buf[num_threads_x + threadIdx.x] = init;
}
if (threadIdx.x == 0) {
binary_op_update(block_total,
&row_buf[0],
block_idx_final,
&row_idx_buf[0],
binary_op);
}
}
__syncthreads();
// Parallel reduction (up-sweep).
for (int s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
int offset = (2 * threadIdx.x + 1) * d - 1;
binary_op_update(row_buf[offset],
&row_buf[offset + d],
row_idx_buf[offset],
&row_idx_buf[offset + d],
binary_op);
}
__syncthreads();
}
// Down-sweep.
for (int s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
int offset = 2 * (threadIdx.x + 1) * d - 1;
binary_op_update(row_buf[offset],
&row_buf[offset + d],
row_idx_buf[offset],
&row_idx_buf[offset + d],
binary_op);
}
__syncthreads();
}
// Write back to output.
if (row < num_rows) {
if (col1 < row_size) {
row_values[col1] = row_buf[threadIdx.x];
row_indices[col1] = row_idx_buf[threadIdx.x];
}
if (col2 < row_size) {
row_values[col2] = row_buf[num_threads_x + threadIdx.x];
row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
}
}
block_total = row_buf[2 * num_threads_x - 1];
block_idx_final = row_idx_buf[2 * num_threads_x - 1];
__syncthreads();
}
}
}
template <typename T1, typename T2, class BinaryFunction>
__global__ void KernelScanOuterWithIndices(const T1* x_data,
T1* values_data,
T2* indices_data,
const uint32_t num_orows,
const uint32_t num_irows,
const uint32_t row_size,
T1 init,
BinaryFunction binary_op) {
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x;
irow < num_irows;
irow += gridDim.y * blockDim.x) {
const T1* x = x_data + orow * row_size * num_irows + irow;
T1* values = values_data + orow * row_size * num_irows + irow;
T2* indices = indices_data + orow * row_size * num_irows + irow;
T1 out = init;
T2 out_idx = 0;
for (T2 col = 0; col < row_size; ++col) {
const auto val = *reinterpret_cast<const T1*>(x);
binary_op_update_v(val, &out, col, &out_idx, binary_op);
*values = out;
*indices = out_idx;
x += num_irows;
values += num_irows;
indices += num_irows;
}
}
}
}
template <typename T1, typename T2, typename BinaryFunction, typename Context>
void ScanWithIndicesKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
T1 init,
DenseTensor* out,
DenseTensor* indices) {
dev_ctx.template Alloc<T1>(out);
dev_ctx.template Alloc<T2>(indices);
// For 0D Tensor
if (out->numel() == 1) {
auto raw_dims = out->dims();
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
phi::funcs::SetConstant<Context, T2> set_zero;
set_zero(dev_ctx, indices, static_cast<T2>(0.0));
out->Resize(raw_dims);
indices->Resize(raw_dims);
return;
}
BinaryFunction op;
auto out_dims = out->dims();
auto size = x.numel();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(),
out_dims.size() - 1,
axis));
if (axis < 0) {
axis += out_dims.size();
}
const T1* x_data = x.data<T1>();
T1* values_data = out->data<T1>();
T2* indices_data = indices->data<T2>();
if (axis == out_dims.size() - 1) {
int ndim = x.dims().size();
int row_size = x.dims()[ndim - 1];
int num_rows = x.numel() / row_size;
dim3 threads(16, 32);
dim3 grid(
std::min(dev_ctx.GetCUDAMaxGridDimSize()[0],
static_cast<int>(std::ceil(static_cast<float>(num_rows) /
static_cast<float>(threads.y)))));
KernelScanInnerWithIndices<T1, T2, 16, 32>
<<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, values_data, indices_data, num_rows, row_size, init, op);
} else {
int64_t row_size = x.dims()[axis];
auto sizes = phi::vectorize(x.dims());
const int64_t num_orows =
std::accumulate(sizes.begin(),
sizes.begin() + axis,
int64_t(1),
[](int64_t& a, int64_t& b) { return a * b; });
const int64_t num_irows =
std::accumulate(sizes.begin() + axis + 1,
sizes.end(),
int64_t(1),
[](int64_t& a, int64_t& b) { return a * b; });
dim3 threads(std::min(512, static_cast<int>(num_irows)));
int64_t maxGridDim = dev_ctx.GetCUDAMaxGridDimSize()[1];
dim3 grid(std::min(maxGridDim, num_orows),
std::min(maxGridDim,
static_cast<int64_t>(
std::ceil(static_cast<double>(num_irows) /
static_cast<double>(threads.x)))));
KernelScanOuterWithIndices<T1, T2>
<<<grid, threads, 0, dev_ctx.stream()>>>(x_data,
values_data,
indices_data,
num_orows,
num_irows,
row_size,
init,
op);
}
}
template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
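// Identity element for the max scan: -inf for floating-point types, the
// lowest representable value for integral types.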
T init = std::is_floating_point<T>::value
? (-1 * std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::lowest();
if (indices_type == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::greater_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
}
template <typename T, typename Context>
void CumminKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
int dtype,
DenseTensor* out,
DenseTensor* indices) {
auto indices_type = phi::TransToPhiDataType(dtype);
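// Identity element for the min scan: +inf for floating-point types, the
// largest representable value for integral types.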
T init = std::is_floating_point<T>::value ? std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::max();
if (indices_type == DataType::INT32) {
ScanWithIndicesKernel<T, int32_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
} else if (indices_type == DataType::INT64) {
ScanWithIndicesKernel<T, int64_t, std::less_equal<T>, Context>(
dev_ctx, x, axis, init, out, indices);
}
}
} // namespace phi
PD_REGISTER_KERNEL(cummax,
GPU,
ALL_LAYOUT,
phi::CummaxKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(cummin,
GPU,
ALL_LAYOUT,
phi::CumminKernel,
float,
double,
int32_t,
int64_t) {}
......@@ -212,6 +212,8 @@ from .tensor.math import cos # noqa: F401
from .tensor.math import tan # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cummax # noqa: F401
from .tensor.math import cummin # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
......@@ -447,6 +449,8 @@ __all__ = [ # noqa
'empty_like',
'eye',
'cumsum',
'cummax',
'cummin',
'cumprod',
'logaddexp',
'logcumsumexp',
......
......@@ -149,6 +149,8 @@ from .math import cos # noqa: F401
from .math import tan # noqa: F401
from .math import cosh # noqa: F401
from .math import cumsum # noqa: F401
from .math import cummax # noqa: F401
from .math import cummin # noqa: F401
from .math import cumprod # noqa: F401
from .math import logcumsumexp # noqa: F401
from .math import logit # noqa: F401
......@@ -344,6 +346,8 @@ tensor_method_func = [ # noqa
'cos',
'cosh',
'cumsum',
'cummax',
'cummin',
'cumprod',
'logcumsumexp',
'logit',
......
......@@ -3384,6 +3384,155 @@ def cumsum(x, axis=None, dtype=None, name=None):
return _cum_sum_(**kwargs)
def cummax(x, axis=None, dtype='int64', name=None):
"""
The cumulative max of the elements along a given axis.
Note:
The first element of the result is the same as the first element of the input.
Args:
x (Tensor): The input tensor over which the cumulative max is computed.
axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cummax over the flattened array.
dtype (str, optional): The data type of the indices tensor, which can be int32 or int64. The default value is int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor), The result of the cummax operation. The dtype of the result is the same as that of the input x.
indices (Tensor), The corresponding indices of the cummax results.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([-1, 5, 0, -2, -3, 2])
data = paddle.reshape(data, (2, 3))
y = paddle.cummax(data)
# value: [-1, 5, 5, 5, 5, 5]
# indices: [0, 1, 1, 1, 1, 1]
y = paddle.cummax(data, axis=0)
# value: [[-1, 5, 0]
# [-1, 5, 2]]
# indices: [[0, 0, 0]
# [0, 0, 1]]
y = paddle.cummax(data, axis=-1)
# value: [[-1, 5, 5]
# [-2, -2, 2]]
# indices: [[0, 1, 1]
# [0, 0, 2]]
y = paddle.cummax(data, dtype='int64')
print(y[1].dtype)
# indices dtype: paddle.int64
"""
if axis is None:
axis = -1
x = x.flatten(0, len(x.shape) - 1)
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummax')
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dynamic_mode():
return _C_ops.cummax(x, axis, dtype)
else:
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int32', 'int64'],
'cummax',
)
check_type(x, 'x', (Variable), 'cummax')
helper = LayerHelper('cummax', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
indices = helper.create_variable_for_type_inference(dtype='int64')
helper.append_op(
type='cummax',
inputs={'x': x},
outputs={'out': out, 'indices': indices},
attrs={'axis': axis, 'dtype': dtype},
)
return out, indices
def cummin(x, axis=None, dtype='int64', name=None):
"""
The cumulative min of the elements along a given axis.
Note:
The first element of the result is the same as the first element of the input.
Args:
x (Tensor): The input tensor over which the cumulative min is computed.
axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cummin over the flattened array.
dtype (str, optional): The data type of the indices tensor, which can be int32 or int64. The default value is int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor), The result of the cummin operation. The dtype of the result is the same as that of the input x.
indices (Tensor), The corresponding indices of the cummin results.
Examples:
.. code-block:: python
import paddle
data = paddle.to_tensor([-1, 5, 0, -2, -3, 2])
data = paddle.reshape(data, (2, 3))
y = paddle.cummin(data)
# value: [-1, -1, -1, -2, -3, -3]
# indices: [0, 0, 0, 3, 4, 4]
y = paddle.cummin(data, axis=0)
# value: [[-1, 5, 0]
# [-2, -3, 0]]
# indices: [[0, 0, 0]
# [1, 1, 0]]
y = paddle.cummin(data, axis=-1)
# value: [[-1, -1, -1]
# [-2, -3, -3]]
# indices: [[0, 0, 0]
# [0, 1, 1]]
y = paddle.cummin(data, dtype='int64')
print(y[1].dtype)
# indices dtype: paddle.int64
"""
if axis is None:
axis = -1
x = x.flatten(0, len(x.shape) - 1)
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'cummin')
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dynamic_mode():
return _C_ops.cummin(x, axis, dtype)
else:
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int32', 'int64'],
'cummin',
)
check_type(x, 'x', (Variable), 'cummin')
helper = LayerHelper('cummin', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
indices = helper.create_variable_for_type_inference(dtype='int64')
helper.append_op(
type='cummin',
inputs={'x': x},
outputs={'out': out, 'indices': indices},
attrs={'axis': axis, 'dtype': dtype},
)
return out, indices
def logcumsumexp(x, axis=None, dtype=None, name=None):
r"""
The logarithm of the cumulative summation of the exponentiation of the elements along a given axis.
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
from eager_op_test import OpTest
import paddle
from paddle import fluid
from paddle.fluid import core
def cummax_dim2(arr, axis=None):
if axis is None:
arr = arr.flatten()
cummax = np.maximum.accumulate(arr)
shape = arr.shape
indices = np.zeros(shape).astype('int32')
max_val = -sys.maxsize
max_ind = 0
for i in range(shape[0]):
if arr[i] >= max_val:
max_val = max(arr[i], max_val)
max_ind = i
indices[i] = i
else:
indices[i] = max_ind
else:
cummax = np.maximum.accumulate(arr, axis)
shape = arr.shape
indices = np.zeros(shape).astype('int32')
if axis < 0:
axis = axis + len(shape)
if axis == 0:
for j in range(shape[1]):
max_ind = 0
max_val = -sys.maxsize
for i in range(shape[0]):
if arr[i][j] >= max_val:
max_val = arr[i][j]
max_ind = i
indices[i][j] = i
else:
indices[i][j] = max_ind
elif axis == 1:
for i in range(shape[0]):
max_ind = 0
max_val = -sys.maxsize
for j in range(shape[1]):
if arr[i][j] >= max_val:
max_val = arr[i][j]
max_ind = j
indices[i][j] = j
else:
indices[i][j] = max_ind
else:
raise Exception("unfeasible axis")
return cummax, indices
class TestCummaxOp(OpTest):
def setUp(self):
self.op_type = "cummax"
self.python_api = paddle.cummax
self.dtype = np.float64
self.axis = -1
self.indices_type = 3
self.input_data = np.random.random((10, 10)).astype(self.dtype)
self.set_attrs()
self.inputs = {'x': self.input_data}
self.attrs = {'axis': self.axis, 'dtype': self.indices_type}
self.np_res, self.np_ind = cummax_dim2(self.input_data, axis=self.axis)
self.outputs = {'out': self.np_res, 'indices': self.np_ind}
def set_attrs(self):
pass
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad(['x'], 'out')
class TestCummaxOpAxis1(TestCummaxOp):
def set_attrs(self):
self.axis = 0
class TestCummaxOpAxis2(TestCummaxOp):
def set_attrs(self):
self.axis = -2
class TestCummaxOpIndexType(TestCummaxOp):
def set_attrs(self):
self.indices_type = 2
class TestCummaxAPI(unittest.TestCase):
def run_cases(self):
data_np = np.random.random((100, 100)).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummax(data)
z, ind = cummax_dim2(data_np)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummax(data, axis=0)
z, ind = cummax_dim2(data_np, axis=0)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummax(data, axis=-1)
z, ind = cummax_dim2(data_np, axis=-1)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummax(data, axis=-2)
z, ind = cummax_dim2(data_np, axis=-2)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummax(data, axis=-2, dtype='int32')
z, ind = cummax_dim2(data_np, axis=-2)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
self.assertTrue(indices.dtype == core.VarDesc.VarType.INT32)
data_np = np.random.randint(0, 10, size=(100, 100)).astype(np.int32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummax(data, axis=0)
z, ind = cummax_dim2(data_np, axis=0)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
x = paddle.static.data('x', [100, 100])
y1, indices1 = paddle.cummax(x)
y2, indices2 = paddle.cummax(x, axis=0)
y3, indices3 = paddle.cummax(x, axis=-1)
y4, indices4 = paddle.cummax(x, axis=-2)
y5, indices5 = paddle.cummax(x, axis=-2, dtype=np.int32)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out = exe.run(
feed={'x': data_np},
fetch_list=[
y1.name,
indices1.name,
y2.name,
indices2.name,
y3.name,
indices3.name,
y4.name,
indices4.name,
y5.name,
indices5.name,
],
)
z, ind = cummax_dim2(data_np)
np.testing.assert_allclose(z, out[0], rtol=1e-05)
np.testing.assert_allclose(ind, out[1], rtol=1e-05)
z, ind = cummax_dim2(data_np, axis=0)
np.testing.assert_allclose(z, out[2], rtol=1e-05)
np.testing.assert_allclose(ind, out[3], rtol=1e-05)
z, ind = cummax_dim2(data_np, axis=-1)
np.testing.assert_allclose(z, out[4], rtol=1e-05)
np.testing.assert_allclose(ind, out[5], rtol=1e-05)
z, ind = cummax_dim2(data_np, axis=-2)
np.testing.assert_allclose(z, out[6], rtol=1e-05)
np.testing.assert_allclose(ind, out[7], rtol=1e-05)
z, ind = cummax_dim2(data_np, axis=-2)
np.testing.assert_allclose(z, out[8], rtol=1e-05)
np.testing.assert_allclose(ind, out[9], rtol=1e-05)
def test_cpu(self):
paddle.disable_static(paddle.fluid.CPUPlace())
self.run_cases()
paddle.enable_static()
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(paddle.fluid.CUDAPlace(0))
self.run_cases()
paddle.enable_static()
self.run_static(use_gpu=True)
def test_errors(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
def test_x_type():
data = [1, 2, 3]
y, indices = paddle.cummax(data, axis=0)
self.assertRaises(TypeError, test_x_type)
paddle.disable_static()
def test_indices_type():
data_np = np.random.random((10, 10)).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummax(data, dtype='float32')
self.assertRaises(ValueError, test_indices_type)
def test_axis_outrange():
data_np = np.random.random(100).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummax(data, axis=-2)
self.assertRaises(IndexError, test_axis_outrange)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
from eager_op_test import OpTest
import paddle
from paddle import fluid
from paddle.fluid import core
def cummin_dim2(arr, axis=None):
if axis is None:
arr = arr.flatten()
cummin = np.minimum.accumulate(arr)
shape = arr.shape
indices = np.zeros(shape).astype('int32')
min_val = sys.maxsize
min_ind = 0
for i in range(shape[0]):
if arr[i] <= min_val:
min_val = min(arr[i], min_val)
min_ind = i
indices[i] = i
else:
indices[i] = min_ind
else:
cummin = np.minimum.accumulate(arr, axis)
shape = arr.shape
indices = np.zeros(shape).astype('int32')
if axis < 0:
axis = axis + len(shape)
if axis == 0:
for j in range(shape[1]):
min_ind = 0
min_val = sys.maxsize
for i in range(shape[0]):
if arr[i][j] <= min_val:
min_val = arr[i][j]
min_ind = i
indices[i][j] = i
else:
indices[i][j] = min_ind
elif axis == 1:
for i in range(shape[0]):
min_ind = 0
min_val = sys.maxsize
for j in range(shape[1]):
if arr[i][j] <= min_val:
min_val = arr[i][j]
min_ind = j
indices[i][j] = j
else:
indices[i][j] = min_ind
else:
raise Exception("unfeasible axis")
return cummin, indices
class TestCumminOp(OpTest):
def setUp(self):
self.op_type = "cummin"
self.python_api = paddle.cummin
self.dtype = np.float64
self.axis = -1
self.indices_type = 3
self.input_data = np.random.random((10, 10)).astype(self.dtype)
self.set_attrs()
self.inputs = {'x': self.input_data}
self.attrs = {'axis': self.axis, 'dtype': self.indices_type}
self.np_res, self.np_ind = cummin_dim2(self.input_data, axis=self.axis)
self.outputs = {'out': self.np_res, 'indices': self.np_ind}
def set_attrs(self):
pass
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad(['x'], 'out')
class TestCumminOpAxis1(TestCumminOp):
def set_attrs(self):
self.axis = 0
class TestCumminOpAxis2(TestCumminOp):
def set_attrs(self):
self.axis = -2
class TestCumminOpIndexType(TestCumminOp):
def set_attrs(self):
self.indices_type = 2
class TestCumminAPI(unittest.TestCase):
def run_cases(self):
data_np = np.random.random((100, 100)).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummin(data)
z, ind = cummin_dim2(data_np)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummin(data, axis=0)
z, ind = cummin_dim2(data_np, axis=0)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummin(data, axis=-1)
z, ind = cummin_dim2(data_np, axis=-1)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummin(data, axis=-2)
z, ind = cummin_dim2(data_np, axis=-2)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
y, indices = paddle.cummin(data, axis=-2, dtype='int32')
z, ind = cummin_dim2(data_np, axis=-2)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
self.assertTrue(indices.dtype == core.VarDesc.VarType.INT32)
data_np = np.random.randint(0, 10, size=(100, 100)).astype(np.int32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummin(data, axis=0)
z, ind = cummin_dim2(data_np, axis=0)
np.testing.assert_array_equal(z, y.numpy())
np.testing.assert_array_equal(ind, indices.numpy())
def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
x = paddle.static.data('x', [100, 100])
y1, indices1 = paddle.cummin(x)
y2, indices2 = paddle.cummin(x, axis=0)
y3, indices3 = paddle.cummin(x, axis=-1)
y4, indices4 = paddle.cummin(x, axis=-2)
y5, indices5 = paddle.cummin(x, axis=-2, dtype=np.int32)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out = exe.run(
feed={'x': data_np},
fetch_list=[
y1.name,
indices1.name,
y2.name,
indices2.name,
y3.name,
indices3.name,
y4.name,
indices4.name,
y5.name,
indices5.name,
],
)
z, ind = cummin_dim2(data_np)
np.testing.assert_allclose(z, out[0], rtol=1e-05)
np.testing.assert_allclose(ind, out[1], rtol=1e-05)
z, ind = cummin_dim2(data_np, axis=0)
np.testing.assert_allclose(z, out[2], rtol=1e-05)
np.testing.assert_allclose(ind, out[3], rtol=1e-05)
z, ind = cummin_dim2(data_np, axis=-1)
np.testing.assert_allclose(z, out[4], rtol=1e-05)
np.testing.assert_allclose(ind, out[5], rtol=1e-05)
z, ind = cummin_dim2(data_np, axis=-2)
np.testing.assert_allclose(z, out[6], rtol=1e-05)
np.testing.assert_allclose(ind, out[7], rtol=1e-05)
z, ind = cummin_dim2(data_np, axis=-2)
np.testing.assert_allclose(z, out[8], rtol=1e-05)
np.testing.assert_allclose(ind, out[9], rtol=1e-05)
def test_cpu(self):
paddle.disable_static(paddle.fluid.CPUPlace())
self.run_cases()
paddle.enable_static()
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(paddle.fluid.CUDAPlace(0))
self.run_cases()
paddle.enable_static()
self.run_static(use_gpu=True)
def test_errors(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
def test_x_type():
data = [1, 2, 3]
y, indices = paddle.cummin(data, axis=0)
self.assertRaises(TypeError, test_x_type)
paddle.disable_static()
def test_indices_type():
data_np = np.random.random((10, 10)).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummin(data, dtype='float32')
self.assertRaises(ValueError, test_indices_type)
def test_axis_outrange():
data_np = np.random.random(100).astype(np.float32)
data = paddle.to_tensor(data_np)
y, indices = paddle.cummin(data, axis=-2)
self.assertRaises(IndexError, test_axis_outrange)
if __name__ == '__main__':
unittest.main()