未验证 提交 3d4d995f 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor (#54500)

* [Zero-Dim] paddle.nanmedian support 0D Tensor

* fix CI
上级 ca59c72b
...@@ -2323,37 +2323,47 @@ void NanmedianInferMeta(const MetaTensor& x, ...@@ -2323,37 +2323,47 @@ void NanmedianInferMeta(const MetaTensor& x,
for (int64_t i = 0; i < x_rank; i++) { for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1); out_dim.push_back(1);
} }
} else {
out_dim.push_back(1);
} }
} else { } else {
std::vector<int64_t> cleaned_axis; std::vector<int64_t> formated_axis;
for (auto& axis : axis_list) { for (auto& axis : axis_list) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, each element of the axis "
"can only be -1, 0, None"));
} else {
PADDLE_ENFORCE_LT(axis,
x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
PADDLE_ENFORCE_GE(axis,
-x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
}
if (axis < 0) axis += x_rank; if (axis < 0) axis += x_rank;
PADDLE_ENFORCE_LT(
axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis), std::find(formated_axis.begin(), formated_axis.end(), axis),
cleaned_axis.end(), formated_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.", errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis))); static_cast<int>(axis)));
cleaned_axis.push_back(axis); formated_axis.push_back(axis);
} }
for (int64_t i = 0; i < x_rank; i++) { for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) == if (std::find(formated_axis.begin(), formated_axis.end(), i) ==
cleaned_axis.end()) { formated_axis.end()) {
out_dim.push_back(x_dim[i]); out_dim.push_back(x_dim[i]);
} else if (keep_dim) { } else if (keep_dim) {
out_dim.push_back(1); out_dim.push_back(1);
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi { namespace phi {
...@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes UNUSED, DenseTensor* x_grad) {
DenseTensor* x_grad, T* dx_data = dev_ctx.template Alloc<T>(x_grad);
T* x_grad_ptr) { if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero; phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0)); set_zero(dev_ctx, x_grad, static_cast<T>(0));
if (!x_grad_ptr) return;
const int64_t* m_ptr = median_index.data<int64_t>(); const int64_t* m_data = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>(); const T* dout_data = out_grad.data<T>();
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
int64_t rank = x_dim.size(); int64_t rank = x_dim.size();
int64_t stride = x_dim[rank - 1]; int64_t stride = x_dim[rank - 1];
int64_t pre_dim = numel / stride; int64_t pre_dim = numel / stride;
int64_t i = 0; int64_t i = 0;
int64_t offset = 0; int64_t offset = 0;
T div_factor = static_cast<T>(2.0);
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
if (m_ptr[2 * i] >= 0) { if (m_data[2 * i] >= 0) {
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) { if (m_data[2 * i] == m_data[2 * i + 1]) {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; dx_data[offset + m_data[2 * i]] = dout_data[i];
} else { } else {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor; dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast<T>(2.0);
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor; dx_data[offset + m_data[2 * i + 1]] =
dout_data[i] / static_cast<T>(2.0);
} }
} }
offset += stride; offset += stride;
} }
} }
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
}
}
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx, void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes, const IntArray& axes,
bool keep_dim UNUSED, bool keepdim UNUSED,
DenseTensor* x_grad) { DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>( DenseTensor tmp_x;
dev_ctx, input, median_index, out_grad, axes, x_grad); auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
} }
} // namespace phi } // namespace phi
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h" #include "paddle/phi/kernels/top_k_kernel.h"
namespace phi { namespace phi {
...@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t pre_dim, int64_t pre_dim,
T* o_ptr, T* o_ptr,
int64_t* m_ptr) { int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor sort_out; DenseTensor sort_out;
DenseTensor sort_indices; DenseTensor sort_indices;
auto sort_dim = x.dims(); auto sort_dim = x.dims();
...@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t offset = 0; int64_t offset = 0;
int64_t i = 0; int64_t i = 0;
bool is_ori_odd = stride & 1; bool is_ori_odd = stride & 1;
if (should_ignore_nan) { if (ignore_nan) {
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
offset = i * sort_k; offset = i * sort_k;
if (nan_counts[i] == stride) { if (nan_counts[i] == stride) {
...@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx, void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
T* o_ptr, DenseTensor* out,
int64_t* m_ptr, DenseTensor* median_index) {
bool ignore_nan) { const T* x_data = x.data<T>();
bool should_ignore_nan = ignore_nan; T* out_data = dev_ctx.template Alloc<T>(out);
const T* x_ptr = x.data<T>(); int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
...@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
int64_t max_valid_num = 0; int64_t max_valid_num = 0;
std::vector<int64_t> nan_counts; std::vector<int64_t> nan_counts;
if (should_ignore_nan) { bool ignore_nan = true;
if (ignore_nan) {
int64_t total_nan_num = 0; int64_t total_nan_num = 0;
std::vector<T> col_vec; std::vector<T> col_vec;
col_vec.reserve(stride); col_vec.reserve(stride);
...@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
for (int64_t i = 0; i < pre_dim; i++) { for (int64_t i = 0; i < pre_dim; i++) {
col_vec.clear(); col_vec.clear();
col_vec.insert( col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); col_vec.begin(), x_data + i * stride, x_data + (i + 1) * stride);
nan_counts[i] = nan_counts[i] =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val)); return std::isnan(static_cast<float>(val));
...@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elems are nan // all elems are nan
if (total_nan_num == numel) { if (total_nan_num == numel) {
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0]; out_data[i] = std::numeric_limits<T>::quiet_NaN();
m_ptr[2 * i] = -1; m_data[2 * i] = -1;
m_ptr[2 * i + 1] = -1; m_data[2 * i + 1] = -1;
} }
return; return;
} }
should_ignore_nan = total_nan_num > 0; ignore_nan = total_nan_num > 0;
} }
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
CalcMedianFunc<T, Context>(dev_ctx, CalcMedianFunc<T, Context>(dev_ctx,
x, x,
nan_counts, nan_counts,
should_ignore_nan, ignore_nan,
sort_k, sort_k,
stride, stride,
pre_dim, pre_dim,
o_ptr, out_data,
m_ptr); m_data);
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
DenseTensor* out,
DenseTensor* median_index,
bool ignore_nan) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
out->Resize(out->dims());
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx, ...@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim UNUSED, bool keepdim UNUSED,
DenseTensor* out, DenseTensor* out,
DenseTensor* median_index) { DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true); DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
} }
} // namespace phi } // namespace phi
......
...@@ -15,9 +15,51 @@ ...@@ -15,9 +15,51 @@
#pragma once #pragma once
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"
namespace phi { namespace phi {
namespace funcs {
/**
 * @brief Rearranges a gradient tensor into the axis order described by
 *        `raw_axes` and writes the result into `x`.
 *
 * The dimensions of `input` are split into two groups: those NOT listed in
 * `raw_axes` (kept in their original relative order) followed by those listed
 * in `raw_axes`. `input` is resized to that grouped shape, and a permutation
 * mapping every original dimension index to its position in the grouped shape
 * is handed to `funcs::TransCompute`.
 *
 * @param dev_ctx   Device context forwarded to `funcs::TransCompute`.
 * @param input     Tensor to rearrange. NOTE: it is modified in place, its
 *                  dims are overwritten by `Resize` before the transpose.
 * @param raw_axes  Axes to group last; negative entries are wrapped by adding
 *                  the rank of `input`.
 * @param x         Destination of the transpose.
 *                  NOTE(review): presumably must already be allocated by the
 *                  caller, confirm against `TransCompute`.
 */
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
                                 DenseTensor* input,
                                 const IntArray& raw_axes,
                                 DenseTensor* x) {
  const auto input_dim = input->dims();
  const auto rank = input_dim.size();

  // Work on a copy so that negative axes can be normalized in place.
  std::vector<int64_t> axes = raw_axes.GetData();
  for (auto& axis : axes) {
    if (axis < 0) {
      axis += rank;
    }
  }

  const auto is_listed_axis = [&axes](int64_t dim) {
    return std::find(axes.begin(), axes.end(), dim) != axes.end();
  };

  // trans_back[i] is the slot that original dimension i occupies in the
  // grouped shape; reshape_back is that grouped shape.
  std::vector<int> trans_back(rank);
  std::vector<int> reshape_back;
  reshape_back.reserve(rank);
  int offset = 0;

  // First the dimensions that are not listed in `axes` ...
  for (int64_t i = 0; i < rank; i++) {
    if (!is_listed_axis(i)) {
      // The shape vector is int-typed, so the narrowing is spelled out.
      reshape_back.push_back(static_cast<int>(input_dim[i]));
      trans_back[i] = offset;
      offset += 1;
    }
  }
  // ... then the listed ones, which therefore end up trailing.
  for (int64_t i = 0; i < rank; i++) {
    if (is_listed_axis(i)) {
      reshape_back.push_back(static_cast<int>(input_dim[i]));
      trans_back[i] = offset;
      offset += 1;
    }
  }

  input->Resize(make_ddim(reshape_back));
  funcs::TransCompute<Context, T>(
      static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
template <typename T, typename Context> template <typename T, typename Context>
void PreprocessMedianKernel(const Context& dev_ctx, void PreprocessMedianKernel(const Context& dev_ctx,
...@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx, ...@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
x->Resize(make_ddim(reshape)); x->Resize(make_ddim(reshape));
} }
} // namespace funcs
} // namespace phi } // namespace phi
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi { namespace phi {
...@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) { ...@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) {
} }
template <typename T> template <typename T>
__global__ void KernelNanmedianGrad(const T* x_ptr, __global__ void KernelNanmedianGrad(const T* x_data,
const int64_t* medians_ptr, const int64_t* medians_ptr,
const T* out_grad_ptr, const T* out_grad_ptr,
T* x_grad_ptr, T* dx_data,
int64_t stride, int64_t stride,
int64_t pre_dim, int64_t pre_dim) {
T div_factor) {
CUDA_KERNEL_LOOP(index, pre_dim) { CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t offset = index * stride; int64_t offset = index * stride;
printf("index: %d\n", index);
printf("medians_ptr[2 * index]: %d\n", medians_ptr[2 * index]);
printf("medians_ptr[2 * index+1]: %d\n", medians_ptr[2 * index + 1]);
if (medians_ptr[2 * index] >= 0) { if (medians_ptr[2 * index] >= 0) {
if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) { if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) {
x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; dx_data[offset + medians_ptr[2 * index]] = out_grad_ptr[index];
} else { } else {
x_grad_ptr[offset + medians_ptr[2 * index]] = dx_data[offset + medians_ptr[2 * index]] =
out_grad_ptr[index] / div_factor; out_grad_ptr[index] / static_cast<T>(2.0);
x_grad_ptr[offset + medians_ptr[2 * index + 1]] = dx_data[offset + medians_ptr[2 * index + 1]] =
out_grad_ptr[index] / div_factor; out_grad_ptr[index] / static_cast<T>(2.0);
} }
} }
} }
...@@ -57,14 +60,17 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -57,14 +60,17 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
DenseTensor* x_grad, DenseTensor* x_grad) {
T* x_grad_ptr) { T* dx_data = dev_ctx.template Alloc<T>(x_grad);
if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero; phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0)); set_zero(dev_ctx, x_grad, static_cast<T>(0));
VLOG(0) << "x_grad->dims(): " << x_grad->dims();
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
const T* x_ptr = x.data<T>(); const T* x_data = x.data<T>();
const int64_t* m_ptr = median_index.data<int64_t>(); const int64_t* m_data = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>(); const T* out_grad_ptr = out_grad.data<T>();
int64_t numel = x.numel(); int64_t numel = x.numel();
...@@ -73,42 +79,38 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -73,42 +79,38 @@ void CalcMedianGradKernel(const Context& dev_ctx,
int64_t stride = x_dim[x_rank - 1]; int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride; int64_t pre_dim = numel / stride;
T div_factor = static_cast<T>(2.0);
KernelNanmedianGrad<T> KernelNanmedianGrad<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor); x_data, m_data, out_grad_ptr, dx_data, stride, pre_dim);
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr);
}
} }
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx, void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes, const IntArray& axes,
bool keep_dim, bool keepdim UNUSED,
DenseTensor* x_grad) { DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>( DenseTensor tmp_x;
dev_ctx, input, median_index, out_grad, axes, x_grad); auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
} }
} // namespace phi } // namespace phi
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h" #include "paddle/phi/kernels/top_k_kernel.h"
namespace phi { namespace phi {
...@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr, ...@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr,
template <typename T, typename Context> template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx, void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
bool ignore_nan,
DenseTensor* out, DenseTensor* out,
int64_t* m_ptr) { DenseTensor* median_index) {
bool should_ignore_nan = ignore_nan;
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
int64_t x_rank = x_dim.size(); int64_t x_rank = x_dim.size();
...@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx,
DenseTensor nan_counts, nan_stat; DenseTensor nan_counts, nan_stat;
int64_t* nan_counts_ptr; int64_t* nan_counts_ptr;
int64_t max_valid_num = 0; int64_t max_valid_num = 0;
if (should_ignore_nan) {
bool ignore_nan = true;
if (ignore_nan) {
nan_counts.Resize(phi::make_ddim({pre_dim})); nan_counts.Resize(phi::make_ddim({pre_dim}));
dev_ctx.template Alloc<int64_t>(&nan_counts); dev_ctx.template Alloc<int64_t>(&nan_counts);
nan_counts_ptr = nan_counts.data<int64_t>(); nan_counts_ptr = nan_counts.data<int64_t>();
...@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
KernelNanCounts<T><<<GET_BLOCKS(numel), KernelNanCounts<T><<<GET_BLOCKS(numel),
PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS,
pre_dim * sizeof(int64_t), pre_dim * sizeof(int64_t),
stream>>>(x_ptr, stream>>>(x_data,
numel, numel,
pre_dim, pre_dim,
stride, stride,
...@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elements are nan values // all elements are nan values
T nan_val = std::numeric_limits<T>::quiet_NaN(); T nan_val = std::numeric_limits<T>::quiet_NaN();
if (nan_stat_cpu_ptr[0] == numel) { if (nan_stat_cpu_ptr[0] == numel) {
FullLikeKernel<T, Context>(dev_ctx, x, nan_val, x.dtype(), out); phi::funcs::SetConstant<Context, T> set_nan;
set_nan(dev_ctx, out, nan_val);
phi::funcs::SetConstant<Context, int64_t> set_negatvie;
set_negatvie(dev_ctx, median_index, static_cast<int64_t>(-1));
return; return;
} }
should_ignore_nan = nan_stat_cpu_ptr[0] > 0; ignore_nan = nan_stat_cpu_ptr[0] > 0;
max_valid_num = nan_stat_cpu_ptr[1]; max_valid_num = nan_stat_cpu_ptr[1];
} }
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
bool is_ori_odd = stride & 1; bool is_ori_odd = stride & 1;
DenseTensor sort_out, sort_indices; DenseTensor sort_out, sort_indices;
...@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx,
T div_factor = static_cast<T>(2.0); T div_factor = static_cast<T>(2.0);
T nan_val = std::numeric_limits<T>::quiet_NaN(); T nan_val = std::numeric_limits<T>::quiet_NaN();
if (should_ignore_nan) { if (ignore_nan) {
CalcNanmedianKernel<T> CalcNanmedianKernel<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr, sort_out_ptr,
sort_indices_ptr, sort_indices_ptr,
nan_counts_ptr, nan_counts_ptr,
m_ptr, m_data,
o_ptr, out_data,
is_ori_odd, is_ori_odd,
pre_dim, pre_dim,
max_valid_num, max_valid_num,
...@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr, sort_out_ptr,
sort_indices_ptr, sort_indices_ptr,
m_ptr, m_data,
o_ptr, out_data,
div_factor, div_factor,
is_ori_odd, is_ori_odd,
pre_dim, pre_dim,
...@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx,
} }
} }
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
bool ignore_nan,
DenseTensor* out,
DenseTensor* median_index) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, ignore_nan, out, m_ptr);
out->Resize(out->dims());
}
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx, void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx, ...@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim, bool keepdim,
DenseTensor* out, DenseTensor* out,
DenseTensor* median_index) { DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, true, out, median_index); DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
} }
} // namespace phi } // namespace phi
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
namespace phi {
/**
 * @brief Transposes `input` into `x` using a permutation derived from
 *        `raw_axes`.
 *
 * A grouped shape is built from the dims of `input`: every dimension that is
 * absent from `raw_axes` comes first (original order preserved), every
 * dimension present in `raw_axes` comes after. `input` is resized to this
 * grouped shape and `funcs::TransCompute` is invoked with a permutation that
 * maps each original dimension index to its slot in the grouped shape.
 *
 * @param dev_ctx   Device context passed through to `funcs::TransCompute`.
 * @param input     Source tensor; its dims are replaced via `Resize`, so the
 *                  caller's tensor metadata is changed by this call.
 * @param raw_axes  Axes that are moved to the trailing group. Negative values
 *                  are offset by the rank of `input`.
 * @param x         Output tensor receiving the transposed data.
 *                  NOTE(review): allocation of `x` is not done here, verify
 *                  that callers allocate it beforehand.
 */
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
                                 DenseTensor* input,
                                 const IntArray& raw_axes,
                                 DenseTensor* x) {
  auto input_dim = input->dims();
  auto rank = input_dim.size();

  // Normalize negative axes on a private copy of the axis list.
  std::vector<int64_t> axes = raw_axes.GetData();
  for (int64_t& axis : axes) {
    if (axis < 0) {
      axis += rank;
    }
  }

  // Sized construction already value-initializes every slot, so no separate
  // reserve/resize pair is needed.
  std::vector<int> trans_back(rank);
  std::vector<int> reshape_back;
  reshape_back.reserve(rank);

  // Two passes over the dimensions: pass one collects those not named in
  // `axes`, pass two collects those that are. `offset` keeps counting across
  // both passes so each dimension receives a unique slot.
  int offset = 0;
  for (const bool collect_listed : {false, true}) {
    for (int64_t i = 0; i < rank; i++) {
      const bool listed =
          std::find(axes.begin(), axes.end(), i) != axes.end();
      if (listed != collect_listed) {
        continue;
      }
      // Shape entries are stored as int; make the narrowing explicit.
      reshape_back.push_back(static_cast<int>(input_dim[i]));
      trans_back[i] = offset;
      offset += 1;
    }
  }

  input->Resize(make_ddim(reshape_back));
  funcs::TransCompute<Context, T>(
      static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
} // namespace phi
...@@ -255,7 +255,7 @@ def numel(x, name=None): ...@@ -255,7 +255,7 @@ def numel(x, name=None):
return out return out
def nanmedian(x, axis=None, keepdim=True, name=None): def nanmedian(x, axis=None, keepdim=False, name=None):
r""" r"""
Compute the median along the specified axis, while ignoring NaNs. Compute the median along the specified axis, while ignoring NaNs.
...@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
in the output Tensor. If ``keepdim`` is True, the dimensions of in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is True. the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None). name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
...@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]]) x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
y1 = x.nanmedian() y1 = x.nanmedian()
# y1 is [[2.]] # y1 is 2.
y2 = x.nanmedian(0) y2 = x.nanmedian(0)
# y2 is [[0., 1.5, 2.5]] # y2 is [0., 1.5, 2.5]
y3 = x.nanmedian(0, keepdim=False) y3 = x.nanmedian(0, keepdim=True)
# y3 is [0., 1.5, 2.5] # y3 is [[0., 1.5, 2.5]]
y4 = x.nanmedian((0, 1)) y4 = x.nanmedian((0, 1))
# y4 is [[2.]] # y4 is 2.
""" """
if not isinstance(x, Variable): if not isinstance(x, Variable):
raise TypeError("In median, the input x should be a Tensor.") raise TypeError("In median, the input x should be a Tensor.")
...@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
if isinstance(axis, (list, tuple)) and len(axis) == 0: if isinstance(axis, (list, tuple)) and len(axis) == 0:
raise ValueError("Axis list should not be empty.") raise ValueError("Axis list should not be empty.")
dims = len(x.shape)
if axis is None: if axis is None:
axis = [] axis = []
elif isinstance(axis, tuple): elif isinstance(axis, tuple):
...@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
if not isinstance(axis, list):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
for i in range(len(axis)):
if not isinstance(axis[i], int) or not (
axis[i] < dims and axis[i] >= -dims
):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if axis[i] < 0:
axis[i] += dims
if len(axis) != len(set(axis)):
raise ValueError("Axis has duplicated elements.")
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.nanmedian(x, axis, keepdim) return _C_ops.nanmedian(x, axis, keepdim)
else: else:
......
...@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase): ...@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase):
pd_res = paddle.nanmedian( pd_res = paddle.nanmedian(
paddle.to_tensor(data), keepdim=keep_dim paddle.to_tensor(data), keepdim=keep_dim
) )
assert np_res.shape == pd_res.numpy().shape
np.testing.assert_allclose( np.testing.assert_allclose(
np_res, pd_res.numpy(), rtol=1e-05, equal_nan=True np_res, pd_res.numpy(), rtol=1e-05, equal_nan=True
) )
...@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase): ...@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase):
x_np[0, :] = np.nan x_np[0, :] = np.nan
x_np[1, :3] = np.nan x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan x_np[2, 3:] = np.nan
x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, keepdim=True)
dx = paddle.grad(y, x_tensor)[0].numpy()
np_grad = np.zeros(shape)
np_grad[1, 3] = 0.5
np_grad[3, 2] = 0.5
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_axis(self):
paddle.disable_static(place=self.place)
shape = (4, 5)
x_np = np.random.uniform(-1, 1, shape).astype(np.float64)
x_np[0, :] = np.nan
x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan
x_np_sorted = np.sort(x_np) x_np_sorted = np.sort(x_np)
nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1) nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1)
np_grad = np.zeros(shape) np_grad = np.zeros(shape)
...@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase): ...@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase):
np_grad[i, j] = 1 if is_odd else 0.5 np_grad[i, j] = 1 if is_odd else 0.5
x_tensor = paddle.to_tensor(x_np, stop_gradient=False) x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, axis=1, keepdim=True) y = paddle.nanmedian(x_tensor, axis=1)
dx = paddle.grad(y, x_tensor)[0].numpy() dx = paddle.grad(y, x_tensor)[0].numpy()
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True) np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_0d(self):
    """Check the gradient of paddle.nanmedian for 0-D inputs in dynamic mode.

    A finite scalar input is expected to receive a gradient of 1.0 and a
    NaN scalar input a gradient of 0.0; in both cases the gradient must
    itself be 0-D.
    """
    paddle.disable_static(place=self.place)

    # Case 1: finite 0-D input.
    finite_x = paddle.rand([])
    finite_x.stop_gradient = False
    finite_out = paddle.nanmedian(finite_x)
    finite_out.backward()
    self.assertEqual(finite_x.grad.shape, [])
    np.testing.assert_allclose(finite_x.grad, np.array(1.0))

    # Case 2: 0-D input holding NaN.
    nan_x = paddle.to_tensor(float('nan'), stop_gradient=False)
    nan_out = paddle.nanmedian(nan_x)
    nan_out.backward()
    self.assertEqual(nan_x.grad.shape, [])
    np.testing.assert_allclose(nan_x.grad, np.array(0.0))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -179,6 +179,8 @@ reduce_api_list = [ ...@@ -179,6 +179,8 @@ reduce_api_list = [
paddle.mean, paddle.mean,
paddle.nansum, paddle.nansum,
paddle.nanmean, paddle.nanmean,
paddle.median,
paddle.nanmedian,
paddle.min, paddle.min,
paddle.max, paddle.max,
paddle.amin, paddle.amin,
...@@ -202,7 +204,7 @@ class TestReduceAPI(unittest.TestCase): ...@@ -202,7 +204,7 @@ class TestReduceAPI(unittest.TestCase):
else: else:
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, axis=None)
out.retain_grads() out.retain_grads()
out.backward() out.backward()
...@@ -212,9 +214,10 @@ class TestReduceAPI(unittest.TestCase): ...@@ -212,9 +214,10 @@ class TestReduceAPI(unittest.TestCase):
if api not in [paddle.count_nonzero]: if api not in [paddle.count_nonzero]:
np.testing.assert_allclose(out.numpy(), x.numpy()) np.testing.assert_allclose(out.numpy(), x.numpy())
out_empty_list = api(x, []) if api not in [paddle.median, paddle.nanmedian]:
self.assertEqual(out_empty_list, out) out_empty_list = api(x, axis=[])
self.assertEqual(out_empty_list.shape, []) self.assertEqual(out_empty_list, out)
self.assertEqual(out_empty_list.shape, [])
if x.grad is not None: if x.grad is not None:
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
...@@ -222,12 +225,12 @@ class TestReduceAPI(unittest.TestCase): ...@@ -222,12 +225,12 @@ class TestReduceAPI(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), np.array(1.0)) np.testing.assert_allclose(x.grad.numpy(), np.array(1.0))
np.testing.assert_allclose(out.grad.numpy(), np.array(1.0)) np.testing.assert_allclose(out.grad.numpy(), np.array(1.0))
out1 = api(x, 0) out1 = api(x, axis=0)
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out1, out) self.assertEqual(out1, out)
out1.backward() out1.backward()
out2 = api(x, -1) out2 = api(x, axis=-1)
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
self.assertEqual(out2, out) self.assertEqual(out2, out)
out2.backward() out2.backward()
...@@ -236,13 +239,28 @@ class TestReduceAPI(unittest.TestCase): ...@@ -236,13 +239,28 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0)) np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
# 2) x is ND, reduce to 0D # 2) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, axis=0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])
# 3) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]: if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool') x = paddle.randint(0, 2, [3, 5]).astype('bool')
else: else:
x = paddle.rand([3, 5]) x = paddle.rand([3, 5])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, axis=None)
out.retain_grads() out.retain_grads()
out.backward() out.backward()
...@@ -251,20 +269,20 @@ class TestReduceAPI(unittest.TestCase): ...@@ -251,20 +269,20 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, []) self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5]) self.assertEqual(x.grad.shape, [3, 5])
# 3) x is 1D, axis=0, reduce to 0D # 4) x is ND, reduce to 0D, keepdim=True
if api in [paddle.all, paddle.any]: if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool') x = paddle.randint(0, 2, [3, 5]).astype('bool')
else: else:
x = paddle.rand([5]) x = paddle.rand([3, 5])
x.stop_gradient = False x.stop_gradient = False
out = api(x, 0) out = api(x, keepdim=True)
out.retain_grads() out.retain_grads()
out.backward() out.backward()
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [1, 1])
if x.grad is not None: if x.grad is not None:
self.assertEqual(out.grad.shape, []) self.assertEqual(out.grad.shape, [1, 1])
self.assertEqual(x.grad.shape, [5]) self.assertEqual(x.grad.shape, [3, 5])
paddle.enable_static() paddle.enable_static()
...@@ -283,16 +301,17 @@ class TestReduceAPI(unittest.TestCase): ...@@ -283,16 +301,17 @@ class TestReduceAPI(unittest.TestCase):
else: else:
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, axis=None)
paddle.static.append_backward(out) paddle.static.append_backward(out)
out_empty_list = api(x, None) if api not in [paddle.median, paddle.nanmedian]:
self.assertEqual(out_empty_list.shape, ()) out_empty_list = api(x, axis=[])
self.assertEqual(out_empty_list.shape, ())
out1 = api(x, 0) out1 = api(x, axis=0)
self.assertEqual(out1.shape, ()) self.assertEqual(out1.shape, ())
out2 = api(x, -1) out2 = api(x, axis=-1)
self.assertEqual(out2.shape, ()) self.assertEqual(out2.shape, ())
fetch_list = [x, out] fetch_list = [x, out]
...@@ -317,7 +336,7 @@ class TestReduceAPI(unittest.TestCase): ...@@ -317,7 +336,7 @@ class TestReduceAPI(unittest.TestCase):
else: else:
x = paddle.rand([3, 5]) x = paddle.rand([3, 5])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, axis=None)
paddle.static.append_backward(out) paddle.static.append_backward(out)
fetch_list = [out] fetch_list = [out]
...@@ -336,7 +355,7 @@ class TestReduceAPI(unittest.TestCase): ...@@ -336,7 +355,7 @@ class TestReduceAPI(unittest.TestCase):
else: else:
x = paddle.rand([5]) x = paddle.rand([5])
x.stop_gradient = False x.stop_gradient = False
out = api(x, 0) out = api(x, axis=0)
paddle.static.append_backward(out) paddle.static.append_backward(out)
fetch_list = [out] fetch_list = [out]
...@@ -1200,54 +1219,6 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1200,54 +1219,6 @@ class TestSundryAPI(unittest.TestCase):
out = paddle.argmax(x, keepdim=True) out = paddle.argmax(x, keepdim=True)
self.assertEqual(out.shape, [1, 1]) self.assertEqual(out.shape, [1, 1])
def test_median(self):
    """Check paddle.median output and gradient shapes for 0-D, 1-D and N-D inputs."""
    # 1) x is 0D: axis 0, -1 and None are all applied to the same scalar.
    scalar = paddle.rand([])
    scalar.stop_gradient = False
    scalar_outs = [paddle.median(scalar, ax) for ax in (0, -1, None)]
    for res in scalar_outs:
        res.backward()

    for res in scalar_outs:
        self.assertEqual(res.shape, [])
        np.testing.assert_allclose(res, scalar)
    self.assertEqual(scalar.grad.shape, [])
    # The three backward calls above are expected to sum up to 3.0.
    np.testing.assert_allclose(scalar.grad, 3.0)

    # 2) x is 1D
    vec = paddle.rand([5])
    vec.stop_gradient = False
    vec_out = paddle.median(vec, 0)
    vec_out.backward()
    self.assertEqual(vec_out.shape, [])
    self.assertEqual(vec.grad.shape, [5])

    # 3) x is ND
    mat = paddle.rand([3, 5])
    mat.stop_gradient = False
    mat_out = paddle.median(mat, None)
    mat_out.backward()
    self.assertEqual(mat_out.shape, [])
    self.assertEqual(mat.grad.shape, [3, 5])

    # 4) x is ND, keepdim=True
    mat_keep = paddle.rand([3, 5])
    mat_keep.stop_gradient = False
    keep_out = paddle.median(mat_keep, keepdim=True)
    keep_out.backward()
    self.assertEqual(keep_out.shape, [1, 1])
    self.assertEqual(mat_keep.grad.shape, [3, 5])
def test_kthvalue(self): def test_kthvalue(self):
# 1) x is 0D # 1) x is 0D
x = paddle.randn([]) x = paddle.randn([])
...@@ -1535,6 +1506,40 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1535,6 +1506,40 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad, 1.0) self.assertEqual(out.grad, 1.0)
self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(x.grad.shape, [2, 3])
def test_nanquantile(self):
    """Check paddle.nanquantile on a 0-D input and on an N-D input containing NaN.

    The calls go through paddle.nanquantile (not paddle.quantile) so that
    the NaN-ignoring code path is the one actually being exercised.
    """
    # 1) x is 0D
    x = paddle.rand([])
    x.stop_gradient = False
    out = paddle.nanquantile(x, 0.5, axis=None)
    out.retain_grads()
    out.backward()

    # An empty axis list must give the same result as axis=None.
    out_empty_list = paddle.nanquantile(x, 0.5, axis=[])
    self.assertEqual(out_empty_list, out)

    self.assertEqual(x.shape, [])
    self.assertEqual(out.shape, [])
    self.assertEqual(out, x)

    self.assertEqual(x.grad.shape, [])
    self.assertEqual(x.grad, 1.0)
    self.assertEqual(out.grad.shape, [])
    self.assertEqual(out.grad, 1.0)

    # 2) x is ND with 'nan'
    x = paddle.to_tensor([[float('nan'), 2.0, 3.0], [0.0, 1.0, 2.0]])
    x.stop_gradient = False
    out = paddle.nanquantile(x, 0.5, axis=None)
    out.retain_grads()
    out.backward()

    self.assertEqual(out.shape, [])
    # With the NaN ignored, the median of [0, 1, 2, 2, 3] is 2.0.
    np.testing.assert_allclose(out, 2.0)
    self.assertEqual(out.grad.shape, [])
    self.assertEqual(out.grad, 1.0)
    self.assertEqual(x.grad.shape, [2, 3])
def test_flip(self): def test_flip(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -3442,40 +3447,6 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -3442,40 +3447,6 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_allclose(res[2], 0.0) np.testing.assert_allclose(res[2], 0.0)
self.assertEqual(res[3].shape, ()) self.assertEqual(res[3].shape, ())
@prog_scope()
def test_median(self):
    """Check paddle.median output and gradient shapes in static-graph mode."""
    # 1) x is 0D
    scalar = paddle.rand([])
    scalar.stop_gradient = False
    scalar_out = paddle.median(scalar)
    paddle.static.append_backward(scalar_out)

    # 2) x is ND
    matrix = paddle.rand([3, 5])
    matrix.stop_gradient = False
    matrix_out = paddle.median(matrix)
    paddle.static.append_backward(matrix_out)

    # Fetch order: 0D input, 0D output, 0D grad, ND output, ND grad.
    fetch_targets = [
        scalar,
        scalar_out,
        scalar.grad_name,
        matrix_out,
        matrix.grad_name,
    ]
    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(main_prog, fetch_list=fetch_targets)

    self.assertEqual(fetched[1].shape, ())
    np.testing.assert_allclose(fetched[1], fetched[0])

    self.assertEqual(fetched[2].shape, ())
    np.testing.assert_allclose(fetched[2], 1.0)

    self.assertEqual(fetched[3].shape, ())
    self.assertEqual(fetched[4].shape, (3, 5))
@prog_scope() @prog_scope()
def test_kthvalue(self): def test_kthvalue(self):
# 1) x is 0D # 1) x is 0D
...@@ -3813,12 +3784,12 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -3813,12 +3784,12 @@ class TestSundryAPIStatic(unittest.TestCase):
x1 = paddle.rand([]) x1 = paddle.rand([])
x1.stop_gradient = False x1.stop_gradient = False
out1 = paddle.quantile(x1, 0.5, axis=None) out1 = paddle.quantile(x1, 0.5, axis=None)
paddle.static.append_backward(out1.sum()) paddle.static.append_backward(out1)
x2 = paddle.rand([2, 3]) x2 = paddle.rand([2, 3])
x2.stop_gradient = False x2.stop_gradient = False
out2 = paddle.quantile(x2, 0.5, axis=None) out2 = paddle.quantile(x2, 0.5, axis=None)
paddle.static.append_backward(out2.sum()) paddle.static.append_backward(out2)
out_empty_list = paddle.quantile(x1, 0.5, axis=[]) out_empty_list = paddle.quantile(x1, 0.5, axis=[])
self.assertEqual(out_empty_list.shape, ()) self.assertEqual(out_empty_list.shape, ())
...@@ -3846,6 +3817,37 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -3846,6 +3817,37 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[5].shape, ()) self.assertEqual(res[5].shape, ())
self.assertEqual(res[5], 1.0) self.assertEqual(res[5], 1.0)
@prog_scope()
def test_nanquantile(self):
    """Check paddle.nanquantile output and gradient shapes in static-graph mode."""
    # 1) x is 0D
    x1 = paddle.rand([])
    x1.stop_gradient = False
    out1 = paddle.nanquantile(x1, 0.5, axis=None)
    paddle.static.append_backward(out1)

    # 2) x is ND with 'nan'
    x2 = paddle.to_tensor([[float('nan'), 2.0, 3.0], [0.0, 1.0, 2.0]])
    x2.stop_gradient = False
    out2 = paddle.nanquantile(x2, 0.5, axis=None)
    paddle.static.append_backward(out2)

    prog = paddle.static.default_main_program()
    res = self.exe.run(
        prog,
        fetch_list=[
            out1,
            x1.grad_name,
            out2,
            x2.grad_name,
        ],
    )

    # Both quantile outputs and the 0D gradient must be 0-D; the ND gradient
    # keeps the input shape.
    self.assertEqual(res[0].shape, ())
    self.assertEqual(res[1].shape, ())

    self.assertEqual(res[2].shape, ())
    self.assertEqual(res[3].shape, (2, 3))
@prog_scope() @prog_scope()
def test_flip(self): def test_flip(self):
x = paddle.rand([]) x = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册