未验证 提交 3d4d995f 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor (#54500)

* [Zero-Dim] paddle.nanmedian support 0D Tensor

* fix CI
上级 ca59c72b
......@@ -2323,37 +2323,47 @@ void NanmedianInferMeta(const MetaTensor& x,
for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1);
}
} else {
out_dim.push_back(1);
}
} else {
std::vector<int64_t> cleaned_axis;
std::vector<int64_t> formated_axis;
for (auto& axis : axis_list) {
if (axis < 0) axis += x_rank;
PADDLE_ENFORCE_LT(
axis,
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, each element of the axis "
"can only be -1, 0, None"));
} else {
PADDLE_ENFORCE_LT(axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
x_dim));
axis));
PADDLE_ENFORCE_GE(axis,
-x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
}
if (axis < 0) axis += x_rank;
PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis),
cleaned_axis.end(),
std::find(formated_axis.begin(), formated_axis.end(), axis),
formated_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis)));
cleaned_axis.push_back(axis);
formated_axis.push_back(axis);
}
for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) ==
cleaned_axis.end()) {
if (std::find(formated_axis.begin(), formated_axis.end(), i) ==
formated_axis.end()) {
out_dim.push_back(x_dim[i]);
} else if (keep_dim) {
out_dim.push_back(1);
......
......@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi {
......@@ -26,31 +26,31 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes UNUSED,
DenseTensor* x_grad,
T* x_grad_ptr) {
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
if (!x_grad_ptr) return;
const int64_t* m_ptr = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>();
const int64_t* m_data = median_index.data<int64_t>();
const T* dout_data = out_grad.data<T>();
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t rank = x_dim.size();
int64_t stride = x_dim[rank - 1];
int64_t pre_dim = numel / stride;
int64_t i = 0;
int64_t offset = 0;
T div_factor = static_cast<T>(2.0);
for (i = 0; i < pre_dim; i++) {
if (m_ptr[2 * i] >= 0) {
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i];
if (m_data[2 * i] >= 0) {
if (m_data[2 * i] == m_data[2 * i + 1]) {
dx_data[offset + m_data[2 * i]] = dout_data[i];
} else {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor;
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor;
dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast<T>(2.0);
dx_data[offset + m_data[2 * i + 1]] =
dout_data[i] / static_cast<T>(2.0);
}
}
offset += stride;
......@@ -58,35 +58,32 @@ void CalcMedianGradKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keepdim UNUSED,
DenseTensor* x_grad) {
DenseTensor tmp_x;
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
}
}
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim UNUSED,
DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>(
dev_ctx, input, median_index, out_grad, axes, x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
}
} // namespace phi
......
......@@ -16,7 +16,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace phi {
......@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t pre_dim,
T* o_ptr,
int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor sort_out;
DenseTensor sort_indices;
auto sort_dim = x.dims();
......@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t offset = 0;
int64_t i = 0;
bool is_ori_odd = stride & 1;
if (should_ignore_nan) {
if (ignore_nan) {
for (i = 0; i < pre_dim; i++) {
offset = i * sort_k;
if (nan_counts[i] == stride) {
......@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx,
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
T* o_ptr,
int64_t* m_ptr,
bool ignore_nan) {
bool should_ignore_nan = ignore_nan;
const T* x_ptr = x.data<T>();
DenseTensor* out,
DenseTensor* median_index) {
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
int64_t numel = x.numel();
auto x_dim = x.dims();
......@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
int64_t max_valid_num = 0;
std::vector<int64_t> nan_counts;
if (should_ignore_nan) {
bool ignore_nan = true;
if (ignore_nan) {
int64_t total_nan_num = 0;
std::vector<T> col_vec;
col_vec.reserve(stride);
......@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
for (int64_t i = 0; i < pre_dim; i++) {
col_vec.clear();
col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride);
col_vec.begin(), x_data + i * stride, x_data + (i + 1) * stride);
nan_counts[i] =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val));
......@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elems are nan
if (total_nan_num == numel) {
for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0];
m_ptr[2 * i] = -1;
m_ptr[2 * i + 1] = -1;
out_data[i] = std::numeric_limits<T>::quiet_NaN();
m_data[2 * i] = -1;
m_data[2 * i + 1] = -1;
}
return;
}
should_ignore_nan = total_nan_num > 0;
ignore_nan = total_nan_num > 0;
}
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
CalcMedianFunc<T, Context>(dev_ctx,
x,
nan_counts,
should_ignore_nan,
ignore_nan,
sort_k,
stride,
pre_dim,
o_ptr,
m_ptr);
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
DenseTensor* out,
DenseTensor* median_index,
bool ignore_nan) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
out->Resize(out->dims());
out_data,
m_data);
}
template <typename T, typename Context>
......@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim UNUSED,
DenseTensor* out,
DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true);
DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
}
} // namespace phi
......
......@@ -15,9 +15,51 @@
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"
namespace phi {
namespace funcs {
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.resize(rank);
int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}
input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
template <typename T, typename Context>
void PreprocessMedianKernel(const Context& dev_ctx,
......@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
x->Resize(make_ddim(reshape));
}
} // namespace funcs
} // namespace phi
......@@ -20,7 +20,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi {
......@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) {
}
template <typename T>
__global__ void KernelNanmedianGrad(const T* x_ptr,
__global__ void KernelNanmedianGrad(const T* x_data,
const int64_t* medians_ptr,
const T* out_grad_ptr,
T* x_grad_ptr,
T* dx_data,
int64_t stride,
int64_t pre_dim,
T div_factor) {
int64_t pre_dim) {
CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t offset = index * stride;
printf("index: %d\n", index);
printf("medians_ptr[2 * index]: %d\n", medians_ptr[2 * index]);
printf("medians_ptr[2 * index+1]: %d\n", medians_ptr[2 * index + 1]);
if (medians_ptr[2 * index] >= 0) {
if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) {
x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index];
dx_data[offset + medians_ptr[2 * index]] = out_grad_ptr[index];
} else {
x_grad_ptr[offset + medians_ptr[2 * index]] =
out_grad_ptr[index] / div_factor;
x_grad_ptr[offset + medians_ptr[2 * index + 1]] =
out_grad_ptr[index] / div_factor;
dx_data[offset + medians_ptr[2 * index]] =
out_grad_ptr[index] / static_cast<T>(2.0);
dx_data[offset + medians_ptr[2 * index + 1]] =
out_grad_ptr[index] / static_cast<T>(2.0);
}
}
}
......@@ -57,14 +60,17 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
DenseTensor* x_grad,
T* x_grad_ptr) {
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
VLOG(0) << "x_grad->dims(): " << x_grad->dims();
auto stream = dev_ctx.stream();
const T* x_ptr = x.data<T>();
const int64_t* m_ptr = median_index.data<int64_t>();
const T* x_data = x.data<T>();
const int64_t* m_data = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>();
int64_t numel = x.numel();
......@@ -73,42 +79,38 @@ void CalcMedianGradKernel(const Context& dev_ctx,
int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride;
T div_factor = static_cast<T>(2.0);
KernelNanmedianGrad<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor);
x_data, m_data, out_grad_ptr, dx_data, stride, pre_dim);
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keepdim UNUSED,
DenseTensor* x_grad) {
DenseTensor tmp_x;
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr);
}
}
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim,
DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>(
dev_ctx, input, median_index, out_grad, axes, x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
}
} // namespace phi
......
......@@ -20,7 +20,7 @@
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace phi {
......@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr,
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
bool ignore_nan,
DenseTensor* out,
int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor* median_index) {
auto stream = dev_ctx.stream();
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
......@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx,
DenseTensor nan_counts, nan_stat;
int64_t* nan_counts_ptr;
int64_t max_valid_num = 0;
if (should_ignore_nan) {
bool ignore_nan = true;
if (ignore_nan) {
nan_counts.Resize(phi::make_ddim({pre_dim}));
dev_ctx.template Alloc<int64_t>(&nan_counts);
nan_counts_ptr = nan_counts.data<int64_t>();
......@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
KernelNanCounts<T><<<GET_BLOCKS(numel),
PADDLE_CUDA_NUM_THREADS,
pre_dim * sizeof(int64_t),
stream>>>(x_ptr,
stream>>>(x_data,
numel,
pre_dim,
stride,
......@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elements are nan values
T nan_val = std::numeric_limits<T>::quiet_NaN();
if (nan_stat_cpu_ptr[0] == numel) {
FullLikeKernel<T, Context>(dev_ctx, x, nan_val, x.dtype(), out);
phi::funcs::SetConstant<Context, T> set_nan;
set_nan(dev_ctx, out, nan_val);
phi::funcs::SetConstant<Context, int64_t> set_negatvie;
set_negatvie(dev_ctx, median_index, static_cast<int64_t>(-1));
return;
}
should_ignore_nan = nan_stat_cpu_ptr[0] > 0;
ignore_nan = nan_stat_cpu_ptr[0] > 0;
max_valid_num = nan_stat_cpu_ptr[1];
}
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
bool is_ori_odd = stride & 1;
DenseTensor sort_out, sort_indices;
......@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx,
T div_factor = static_cast<T>(2.0);
T nan_val = std::numeric_limits<T>::quiet_NaN();
if (should_ignore_nan) {
if (ignore_nan) {
CalcNanmedianKernel<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr,
sort_indices_ptr,
nan_counts_ptr,
m_ptr,
o_ptr,
m_data,
out_data,
is_ori_odd,
pre_dim,
max_valid_num,
......@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr,
sort_indices_ptr,
m_ptr,
o_ptr,
m_data,
out_data,
div_factor,
is_ori_odd,
pre_dim,
......@@ -246,34 +251,22 @@ void ProcessMedianKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
bool ignore_nan,
bool keepdim,
DenseTensor* out,
DenseTensor* median_index) {
DenseTensor x;
auto rank = input.dims().size();
DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, ignore_nan, out, m_ptr);
out->Resize(out->dims());
}
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
bool keepdim,
DenseTensor* out,
DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, true, out, median_index);
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
}
} // namespace phi
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
namespace phi {
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.reserve(rank);
trans_back.resize(rank);
int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}
input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
} // namespace phi
......@@ -255,7 +255,7 @@ def numel(x, name=None):
return out
def nanmedian(x, axis=None, keepdim=True, name=None):
def nanmedian(x, axis=None, keepdim=False, name=None):
r"""
Compute the median along the specified axis, while ignoring NaNs.
......@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is True.
the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
......@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
y1 = x.nanmedian()
# y1 is [[2.]]
# y1 is 2.
y2 = x.nanmedian(0)
# y2 is [[0., 1.5, 2.5]]
# y2 is [0., 1.5, 2.5]
y3 = x.nanmedian(0, keepdim=False)
# y3 is [0., 1.5, 2.5]
y3 = x.nanmedian(0, keepdim=True)
# y3 is [[0., 1.5, 2.5]]
y4 = x.nanmedian((0, 1))
# y4 is [[2.]]
# y4 is 2.
"""
if not isinstance(x, Variable):
raise TypeError("In median, the input x should be a Tensor.")
......@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
if isinstance(axis, (list, tuple)) and len(axis) == 0:
raise ValueError("Axis list should not be empty.")
dims = len(x.shape)
if axis is None:
axis = []
elif isinstance(axis, tuple):
......@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
elif isinstance(axis, int):
axis = [axis]
if not isinstance(axis, list):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
for i in range(len(axis)):
if not isinstance(axis[i], int) or not (
axis[i] < dims and axis[i] >= -dims
):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if axis[i] < 0:
axis[i] += dims
if len(axis) != len(set(axis)):
raise ValueError("Axis has duplicated elements.")
if in_dynamic_mode():
return _C_ops.nanmedian(x, axis, keepdim)
else:
......
......@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase):
pd_res = paddle.nanmedian(
paddle.to_tensor(data), keepdim=keep_dim
)
assert np_res.shape == pd_res.numpy().shape
np.testing.assert_allclose(
np_res, pd_res.numpy(), rtol=1e-05, equal_nan=True
)
......@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase):
x_np[0, :] = np.nan
x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan
x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, keepdim=True)
dx = paddle.grad(y, x_tensor)[0].numpy()
np_grad = np.zeros(shape)
np_grad[1, 3] = 0.5
np_grad[3, 2] = 0.5
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_axis(self):
paddle.disable_static(place=self.place)
shape = (4, 5)
x_np = np.random.uniform(-1, 1, shape).astype(np.float64)
x_np[0, :] = np.nan
x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan
x_np_sorted = np.sort(x_np)
nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1)
np_grad = np.zeros(shape)
......@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase):
np_grad[i, j] = 1 if is_odd else 0.5
x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, axis=1, keepdim=True)
y = paddle.nanmedian(x_tensor, axis=1)
dx = paddle.grad(y, x_tensor)[0].numpy()
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_0d(self):
paddle.disable_static(place=self.place)
x = paddle.rand([])
x.stop_gradient = False
y = paddle.nanmedian(x)
y.backward()
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))
x = paddle.to_tensor(float('nan'), stop_gradient=False)
y = paddle.nanmedian(x)
y.backward()
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(0.0))
if __name__ == "__main__":
unittest.main()
......@@ -179,6 +179,8 @@ reduce_api_list = [
paddle.mean,
paddle.nansum,
paddle.nanmean,
paddle.median,
paddle.nanmedian,
paddle.min,
paddle.max,
paddle.amin,
......@@ -202,7 +204,7 @@ class TestReduceAPI(unittest.TestCase):
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out = api(x, axis=None)
out.retain_grads()
out.backward()
......@@ -212,7 +214,8 @@ class TestReduceAPI(unittest.TestCase):
if api not in [paddle.count_nonzero]:
np.testing.assert_allclose(out.numpy(), x.numpy())
out_empty_list = api(x, [])
if api not in [paddle.median, paddle.nanmedian]:
out_empty_list = api(x, axis=[])
self.assertEqual(out_empty_list, out)
self.assertEqual(out_empty_list.shape, [])
......@@ -222,12 +225,12 @@ class TestReduceAPI(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), np.array(1.0))
np.testing.assert_allclose(out.grad.numpy(), np.array(1.0))
out1 = api(x, 0)
out1 = api(x, axis=0)
self.assertEqual(out1.shape, [])
self.assertEqual(out1, out)
out1.backward()
out2 = api(x, -1)
out2 = api(x, axis=-1)
self.assertEqual(out2.shape, [])
self.assertEqual(out2, out)
out2.backward()
......@@ -236,13 +239,28 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
# 2) x is ND, reduce to 0D
# 2) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, axis=0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])
# 3) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out = api(x, axis=None)
out.retain_grads()
out.backward()
......@@ -251,20 +269,20 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
# 3) x is 1D, axis=0, reduce to 0D
# 4) x is ND, reduce to 0D, keepdim=True
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([5])
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, 0)
out = api(x, keepdim=True)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.shape, [1, 1])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])
self.assertEqual(out.grad.shape, [1, 1])
self.assertEqual(x.grad.shape, [3, 5])
paddle.enable_static()
......@@ -283,16 +301,17 @@ class TestReduceAPI(unittest.TestCase):
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out = api(x, axis=None)
paddle.static.append_backward(out)
out_empty_list = api(x, None)
if api not in [paddle.median, paddle.nanmedian]:
out_empty_list = api(x, axis=[])
self.assertEqual(out_empty_list.shape, ())
out1 = api(x, 0)
out1 = api(x, axis=0)
self.assertEqual(out1.shape, ())
out2 = api(x, -1)
out2 = api(x, axis=-1)
self.assertEqual(out2.shape, ())
fetch_list = [x, out]
......@@ -317,7 +336,7 @@ class TestReduceAPI(unittest.TestCase):
else:
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out = api(x, axis=None)
paddle.static.append_backward(out)
fetch_list = [out]
......@@ -336,7 +355,7 @@ class TestReduceAPI(unittest.TestCase):
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
out = api(x, axis=0)
paddle.static.append_backward(out)
fetch_list = [out]
......@@ -1200,54 +1219,6 @@ class TestSundryAPI(unittest.TestCase):
out = paddle.argmax(x, keepdim=True)
self.assertEqual(out.shape, [1, 1])
def test_median(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.median(x, 0)
out2 = paddle.median(x, -1)
out3 = paddle.median(x, None)
out1.backward()
out2.backward()
out3.backward()
self.assertEqual(out1.shape, [])
np.testing.assert_allclose(out1, x)
self.assertEqual(out2.shape, [])
np.testing.assert_allclose(out2, x)
self.assertEqual(out3.shape, [])
np.testing.assert_allclose(out3, x)
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 3.0)
# 2) x is 1D
x = paddle.rand([5])
x.stop_gradient = False
out = paddle.median(x, 0)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [5])
# 3) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.median(x, None)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
# 4) x is ND, keepdim=True
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.median(x, keepdim=True)
out.backward()
self.assertEqual(out.shape, [1, 1])
self.assertEqual(x.grad.shape, [3, 5])
def test_kthvalue(self):
# 1) x is 0D
x = paddle.randn([])
......@@ -1535,6 +1506,40 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad, 1.0)
self.assertEqual(x.grad.shape, [2, 3])
def test_nanquantile(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out = paddle.quantile(x, 0.5, axis=None)
out.retain_grads()
out.backward()
out_empty_list = paddle.quantile(x, 0.5, axis=[])
self.assertEqual(out_empty_list, out)
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out, x)
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad, 1.0)
self.assertEqual(out.grad.shape, [])
self.assertEqual(out.grad, 1.0)
# 2) x is ND with 'nan'
x = paddle.to_tensor([[float('nan'), 2.0, 3.0], [0.0, 1.0, 2.0]])
x.stop_gradient = False
out = paddle.quantile(x, 0.5, axis=None)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(out.grad, 1.0)
self.assertEqual(x.grad.shape, [2, 3])
def test_flip(self):
x = paddle.rand([])
x.stop_gradient = False
......@@ -3442,40 +3447,6 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_allclose(res[2], 0.0)
self.assertEqual(res[3].shape, ())
@prog_scope()
def test_median(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out = paddle.median(x)
paddle.static.append_backward(out)
# 2) x is ND
x1 = paddle.rand([3, 5])
x1.stop_gradient = False
out1 = paddle.median(x1)
paddle.static.append_backward(out1)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
x,
out,
x.grad_name,
out1,
x1.grad_name,
],
)
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[1], res[0])
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[2], 1.0)
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, (3, 5))
@prog_scope()
def test_kthvalue(self):
# 1) x is 0D
......@@ -3813,12 +3784,12 @@ class TestSundryAPIStatic(unittest.TestCase):
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.quantile(x1, 0.5, axis=None)
paddle.static.append_backward(out1.sum())
paddle.static.append_backward(out1)
x2 = paddle.rand([2, 3])
x2.stop_gradient = False
out2 = paddle.quantile(x2, 0.5, axis=None)
paddle.static.append_backward(out2.sum())
paddle.static.append_backward(out2)
out_empty_list = paddle.quantile(x1, 0.5, axis=[])
self.assertEqual(out_empty_list.shape, ())
......@@ -3846,6 +3817,37 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[5].shape, ())
self.assertEqual(res[5], 1.0)
@prog_scope()
def test_nanquantile(self):
# 1) x is 0D
x1 = paddle.rand([])
x1.stop_gradient = False
out1 = paddle.nanquantile(x1, 0.5, axis=None)
paddle.static.append_backward(out1)
# 2) x is ND with 'nan'
x2 = paddle.to_tensor([[float('nan'), 2.0, 3.0], [0.0, 1.0, 2.0]])
x2.stop_gradient = False
out2 = paddle.nanquantile(x2, 0.5, axis=None)
print(out2)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
out1,
x1.grad_name,
out2,
x2.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2, 3))
@prog_scope()
def test_flip(self):
x = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册