未验证 提交 35de47b3 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[cherry-pick 2.5][Zero-Dim] paddle.nanmedian/count_nonzero/logspace support...

[cherry-pick 2.5][Zero-Dim] paddle.nanmedian/count_nonzero/logspace support 0D, add some 0D case (#54649)

* [Zero-Dim] add 0D test case (#54581)

* [Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor (#54500)

* [Zero-Dim] paddle.nanmedian support 0D Tensor

* fix CI
上级 cf64aa0b
...@@ -2162,32 +2162,32 @@ void LogspaceInferMeta(const MetaTensor& start, ...@@ -2162,32 +2162,32 @@ void LogspaceInferMeta(const MetaTensor& start,
MetaTensor* out) { MetaTensor* out) {
auto s_dims = start.dims(); auto s_dims = start.dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(s_dims.size() == 1) && (s_dims[0] == 1), phi::product(s_dims),
true, 1,
phi::errors::InvalidArgument("The shape of Input(Start) must be [1]," phi::errors::InvalidArgument("The size of Input(Start) must be 1,"
"but received input shape is [%s].", "but received input size is %s.",
s_dims)); phi::product(s_dims)));
auto e_dims = stop.dims(); auto e_dims = stop.dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(e_dims.size() == 1) && (e_dims[0] == 1), phi::product(e_dims),
true, true,
phi::errors::InvalidArgument("The shape of Input(Stop) must be [1]," phi::errors::InvalidArgument("The size of Input(Stop) must be 1,"
"but received input shape is [%s].", "but received input size is %s.",
e_dims)); phi::product(e_dims)));
auto num_dims = number.dims(); auto num_dims = number.dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(num_dims.size() == 1) && (num_dims[0] == 1), phi::product(num_dims),
true, true,
phi::errors::InvalidArgument("The shape of Input(Num) must be [1]," phi::errors::InvalidArgument("The size of Input(Num) must be 1,"
"but received input shape is [%s].", "but received input size is %s.",
num_dims)); phi::product(num_dims)));
auto b_dims = base.dims(); auto b_dims = base.dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(phi::product(b_dims),
(b_dims.size() == 1) && (b_dims[0] == 1), true,
true, phi::errors::InvalidArgument(
phi::errors::InvalidArgument("The shape of Input(Base) must be [1]," "The size of Input(Base) must be 1,"
"but received input shape is [%s].", "but received input size is phi::product(b_dims).",
b_dims)); phi::product(b_dims)));
out->set_dims(phi::make_ddim({-1})); out->set_dims(phi::make_ddim({-1}));
out->set_dtype(dtype); out->set_dtype(dtype);
} }
......
...@@ -2260,37 +2260,47 @@ void NanmedianInferMeta(const MetaTensor& x, ...@@ -2260,37 +2260,47 @@ void NanmedianInferMeta(const MetaTensor& x,
for (int64_t i = 0; i < x_rank; i++) { for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1); out_dim.push_back(1);
} }
} else {
out_dim.push_back(1);
} }
} else { } else {
std::vector<int64_t> cleaned_axis; std::vector<int64_t> formated_axis;
for (auto& axis : axis_list) { for (auto& axis : axis_list) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, each element of the axis "
"can only be -1, 0, None"));
} else {
PADDLE_ENFORCE_LT(axis,
x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
PADDLE_ENFORCE_GE(axis,
-x_rank,
errors::InvalidArgument(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d.",
x_rank,
axis));
}
if (axis < 0) axis += x_rank; if (axis < 0) axis += x_rank;
PADDLE_ENFORCE_LT(
axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis), std::find(formated_axis.begin(), formated_axis.end(), axis),
cleaned_axis.end(), formated_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.", errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis))); static_cast<int>(axis)));
cleaned_axis.push_back(axis); formated_axis.push_back(axis);
} }
for (int64_t i = 0; i < x_rank; i++) { for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) == if (std::find(formated_axis.begin(), formated_axis.end(), i) ==
cleaned_axis.end()) { formated_axis.end()) {
out_dim.push_back(x_dim[i]); out_dim.push_back(x_dim[i]);
} else if (keep_dim) { } else if (keep_dim) {
out_dim.push_back(1); out_dim.push_back(1);
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi { namespace phi {
...@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes UNUSED, DenseTensor* x_grad) {
DenseTensor* x_grad, T* dx_data = dev_ctx.template Alloc<T>(x_grad);
T* x_grad_ptr) { if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero; phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0)); set_zero(dev_ctx, x_grad, static_cast<T>(0));
if (!x_grad_ptr) return;
const int64_t* m_ptr = median_index.data<int64_t>(); const int64_t* m_data = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>(); const T* dout_data = out_grad.data<T>();
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
int64_t rank = x_dim.size(); int64_t rank = x_dim.size();
int64_t stride = x_dim[rank - 1]; int64_t stride = x_dim[rank - 1];
int64_t pre_dim = numel / stride; int64_t pre_dim = numel / stride;
int64_t i = 0; int64_t i = 0;
int64_t offset = 0; int64_t offset = 0;
T div_factor = static_cast<T>(2.0);
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
if (m_ptr[2 * i] >= 0) { if (m_data[2 * i] >= 0) {
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) { if (m_data[2 * i] == m_data[2 * i + 1]) {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; dx_data[offset + m_data[2 * i]] = dout_data[i];
} else { } else {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor; dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast<T>(2.0);
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor; dx_data[offset + m_data[2 * i + 1]] =
dout_data[i] / static_cast<T>(2.0);
} }
} }
offset += stride; offset += stride;
} }
} }
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
}
}
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx, void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes, const IntArray& axes,
bool keep_dim UNUSED, bool keepdim UNUSED,
DenseTensor* x_grad) { DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>( DenseTensor tmp_x;
dev_ctx, input, median_index, out_grad, axes, x_grad); auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
} }
} // namespace phi } // namespace phi
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h" #include "paddle/phi/kernels/top_k_kernel.h"
namespace phi { namespace phi {
...@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t pre_dim, int64_t pre_dim,
T* o_ptr, T* o_ptr,
int64_t* m_ptr) { int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor sort_out; DenseTensor sort_out;
DenseTensor sort_indices; DenseTensor sort_indices;
auto sort_dim = x.dims(); auto sort_dim = x.dims();
...@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t offset = 0; int64_t offset = 0;
int64_t i = 0; int64_t i = 0;
bool is_ori_odd = stride & 1; bool is_ori_odd = stride & 1;
if (should_ignore_nan) { if (ignore_nan) {
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
offset = i * sort_k; offset = i * sort_k;
if (nan_counts[i] == stride) { if (nan_counts[i] == stride) {
...@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx, ...@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx, void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
T* o_ptr, DenseTensor* out,
int64_t* m_ptr, DenseTensor* median_index) {
bool ignore_nan) { const T* x_data = x.data<T>();
bool should_ignore_nan = ignore_nan; T* out_data = dev_ctx.template Alloc<T>(out);
const T* x_ptr = x.data<T>(); int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
...@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
int64_t max_valid_num = 0; int64_t max_valid_num = 0;
std::vector<int64_t> nan_counts; std::vector<int64_t> nan_counts;
if (should_ignore_nan) { bool ignore_nan = true;
if (ignore_nan) {
int64_t total_nan_num = 0; int64_t total_nan_num = 0;
std::vector<T> col_vec; std::vector<T> col_vec;
col_vec.reserve(stride); col_vec.reserve(stride);
...@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
for (int64_t i = 0; i < pre_dim; i++) { for (int64_t i = 0; i < pre_dim; i++) {
col_vec.clear(); col_vec.clear();
col_vec.insert( col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); col_vec.begin(), x_data + i * stride, x_data + (i + 1) * stride);
nan_counts[i] = nan_counts[i] =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val)); return std::isnan(static_cast<float>(val));
...@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elems are nan // all elems are nan
if (total_nan_num == numel) { if (total_nan_num == numel) {
for (i = 0; i < pre_dim; i++) { for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0]; out_data[i] = std::numeric_limits<T>::quiet_NaN();
m_ptr[2 * i] = -1; m_data[2 * i] = -1;
m_ptr[2 * i + 1] = -1; m_data[2 * i + 1] = -1;
} }
return; return;
} }
should_ignore_nan = total_nan_num > 0; ignore_nan = total_nan_num > 0;
} }
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
CalcMedianFunc<T, Context>(dev_ctx, CalcMedianFunc<T, Context>(dev_ctx,
x, x,
nan_counts, nan_counts,
should_ignore_nan, ignore_nan,
sort_k, sort_k,
stride, stride,
pre_dim, pre_dim,
o_ptr, out_data,
m_ptr); m_data);
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
DenseTensor* out,
DenseTensor* median_index,
bool ignore_nan) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
out->Resize(out->dims());
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx, ...@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim UNUSED, bool keepdim UNUSED,
DenseTensor* out, DenseTensor* out,
DenseTensor* median_index) { DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true); DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
} }
} // namespace phi } // namespace phi
......
...@@ -15,9 +15,51 @@ ...@@ -15,9 +15,51 @@
#pragma once #pragma once
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"
namespace phi { namespace phi {
namespace funcs {
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.resize(rank);
int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}
input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
template <typename T, typename Context> template <typename T, typename Context>
void PreprocessMedianKernel(const Context& dev_ctx, void PreprocessMedianKernel(const Context& dev_ctx,
...@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx, ...@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
x->Resize(make_ddim(reshape)); x->Resize(make_ddim(reshape));
} }
} // namespace funcs
} // namespace phi } // namespace phi
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
namespace phi { namespace phi {
...@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) { ...@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) {
} }
template <typename T> template <typename T>
__global__ void KernelNanmedianGrad(const T* x_ptr, __global__ void KernelNanmedianGrad(const T* x_data,
const int64_t* medians_ptr, const int64_t* medians_ptr,
const T* out_grad_ptr, const T* out_grad_ptr,
T* x_grad_ptr, T* dx_data,
int64_t stride, int64_t stride,
int64_t pre_dim, int64_t pre_dim) {
T div_factor) {
CUDA_KERNEL_LOOP(index, pre_dim) { CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t offset = index * stride; int64_t offset = index * stride;
printf("index: %d\n", index);
printf("medians_ptr[2 * index]: %d\n", medians_ptr[2 * index]);
printf("medians_ptr[2 * index+1]: %d\n", medians_ptr[2 * index + 1]);
if (medians_ptr[2 * index] >= 0) { if (medians_ptr[2 * index] >= 0) {
if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) { if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) {
x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; dx_data[offset + medians_ptr[2 * index]] = out_grad_ptr[index];
} else { } else {
x_grad_ptr[offset + medians_ptr[2 * index]] = dx_data[offset + medians_ptr[2 * index]] =
out_grad_ptr[index] / div_factor; out_grad_ptr[index] / static_cast<T>(2.0);
x_grad_ptr[offset + medians_ptr[2 * index + 1]] = dx_data[offset + medians_ptr[2 * index + 1]] =
out_grad_ptr[index] / div_factor; out_grad_ptr[index] / static_cast<T>(2.0);
} }
} }
} }
...@@ -57,14 +60,16 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -57,14 +60,16 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
DenseTensor* x_grad, DenseTensor* x_grad) {
T* x_grad_ptr) { T* dx_data = dev_ctx.template Alloc<T>(x_grad);
if (!dx_data) return;
phi::funcs::SetConstant<Context, T> set_zero; phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0)); set_zero(dev_ctx, x_grad, static_cast<T>(0));
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
const T* x_ptr = x.data<T>(); const T* x_data = x.data<T>();
const int64_t* m_ptr = median_index.data<int64_t>(); const int64_t* m_data = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>(); const T* out_grad_ptr = out_grad.data<T>();
int64_t numel = x.numel(); int64_t numel = x.numel();
...@@ -73,42 +78,38 @@ void CalcMedianGradKernel(const Context& dev_ctx, ...@@ -73,42 +78,38 @@ void CalcMedianGradKernel(const Context& dev_ctx,
int64_t stride = x_dim[x_rank - 1]; int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride; int64_t pre_dim = numel / stride;
T div_factor = static_cast<T>(2.0);
KernelNanmedianGrad<T> KernelNanmedianGrad<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor); x_data, m_data, out_grad_ptr, dx_data, stride, pre_dim);
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr);
}
} }
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx, void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& x,
const DenseTensor& median_index, const DenseTensor& median_index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const IntArray& axes, const IntArray& axes,
bool keep_dim, bool keepdim UNUSED,
DenseTensor* x_grad) { DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>( DenseTensor tmp_x;
dev_ctx, input, median_index, out_grad, axes, x_grad); auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, x_grad);
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
DenseTensor tmp_x_grad;
tmp_x_grad.Resize(x_grad->dims());
CalcMedianGradKernel<T, Context>(
dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
dev_ctx.template Alloc<T>(x_grad);
funcs::PostprocessMedianGradKernel<T, Context>(
dev_ctx, &tmp_x_grad, axes, x_grad);
}
} }
} // namespace phi } // namespace phi
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h" #include "paddle/phi/kernels/top_k_kernel.h"
namespace phi { namespace phi {
...@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr, ...@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr,
template <typename T, typename Context> template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx, void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
bool ignore_nan,
DenseTensor* out, DenseTensor* out,
int64_t* m_ptr) { DenseTensor* median_index) {
bool should_ignore_nan = ignore_nan;
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t numel = x.numel(); int64_t numel = x.numel();
auto x_dim = x.dims(); auto x_dim = x.dims();
int64_t x_rank = x_dim.size(); int64_t x_rank = x_dim.size();
...@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx,
DenseTensor nan_counts, nan_stat; DenseTensor nan_counts, nan_stat;
int64_t* nan_counts_ptr; int64_t* nan_counts_ptr;
int64_t max_valid_num = 0; int64_t max_valid_num = 0;
if (should_ignore_nan) {
bool ignore_nan = true;
if (ignore_nan) {
nan_counts.Resize(phi::make_ddim({pre_dim})); nan_counts.Resize(phi::make_ddim({pre_dim}));
dev_ctx.template Alloc<int64_t>(&nan_counts); dev_ctx.template Alloc<int64_t>(&nan_counts);
nan_counts_ptr = nan_counts.data<int64_t>(); nan_counts_ptr = nan_counts.data<int64_t>();
...@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
KernelNanCounts<T><<<GET_BLOCKS(numel), KernelNanCounts<T><<<GET_BLOCKS(numel),
PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS,
pre_dim * sizeof(int64_t), pre_dim * sizeof(int64_t),
stream>>>(x_ptr, stream>>>(x_data,
numel, numel,
pre_dim, pre_dim,
stride, stride,
...@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elements are nan values // all elements are nan values
T nan_val = std::numeric_limits<T>::quiet_NaN(); T nan_val = std::numeric_limits<T>::quiet_NaN();
if (nan_stat_cpu_ptr[0] == numel) { if (nan_stat_cpu_ptr[0] == numel) {
FullLikeKernel<T, Context>(dev_ctx, x, nan_val, x.dtype(), out); phi::funcs::SetConstant<Context, T> set_nan;
set_nan(dev_ctx, out, nan_val);
phi::funcs::SetConstant<Context, int64_t> set_negatvie;
set_negatvie(dev_ctx, median_index, static_cast<int64_t>(-1));
return; return;
} }
should_ignore_nan = nan_stat_cpu_ptr[0] > 0; ignore_nan = nan_stat_cpu_ptr[0] > 0;
max_valid_num = nan_stat_cpu_ptr[1]; max_valid_num = nan_stat_cpu_ptr[1];
} }
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
bool is_ori_odd = stride & 1; bool is_ori_odd = stride & 1;
DenseTensor sort_out, sort_indices; DenseTensor sort_out, sort_indices;
...@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx,
T div_factor = static_cast<T>(2.0); T div_factor = static_cast<T>(2.0);
T nan_val = std::numeric_limits<T>::quiet_NaN(); T nan_val = std::numeric_limits<T>::quiet_NaN();
if (should_ignore_nan) { if (ignore_nan) {
CalcNanmedianKernel<T> CalcNanmedianKernel<T>
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr, sort_out_ptr,
sort_indices_ptr, sort_indices_ptr,
nan_counts_ptr, nan_counts_ptr,
m_ptr, m_data,
o_ptr, out_data,
is_ori_odd, is_ori_odd,
pre_dim, pre_dim,
max_valid_num, max_valid_num,
...@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>( <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr, sort_out_ptr,
sort_indices_ptr, sort_indices_ptr,
m_ptr, m_data,
o_ptr, out_data,
div_factor, div_factor,
is_ori_odd, is_ori_odd,
pre_dim, pre_dim,
...@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx, ...@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx,
} }
} }
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
bool ignore_nan,
DenseTensor* out,
DenseTensor* median_index) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, ignore_nan, out, m_ptr);
out->Resize(out->dims());
}
template <typename T, typename Context> template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx, void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx, ...@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool keepdim, bool keepdim,
DenseTensor* out, DenseTensor* out,
DenseTensor* median_index) { DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, true, out, median_index); DenseTensor tmp_x;
auto rank = x.dims().size();
if ((axes.size() == 0) || rank <= 1) {
tmp_x = x;
tmp_x.Resize({x.numel()});
} else {
funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
}
ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
} }
} // namespace phi } // namespace phi
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
namespace phi {
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.reserve(rank);
trans_back.resize(rank);
int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}
input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
} // namespace phi
...@@ -64,8 +64,8 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): ...@@ -64,8 +64,8 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64) y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64)
distance = paddle.nn.functional.pairwise_distance(x, y) distance = paddle.nn.functional.pairwise_distance(x, y)
print(distance) print(distance)
# Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=True, # Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [4.99999860, 4.99999860]) # [4.99999860, 4.99999860])
""" """
if in_dynamic_mode(): if in_dynamic_mode():
......
...@@ -394,15 +394,15 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): ...@@ -394,15 +394,15 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):
Args: Args:
start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \ start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \
the sequence. It is a scalar, or a Tensor of shape [1] with input data \ the sequence. It is a scalar, or a 0-D Tensor of shape [] with input data \
type int32, int64, float32 or float64. type int32, int64, float32 or float64.
stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \ stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \
sequence. It is a scalar, or a Tensor of shape [1] with input data \ sequence. It is a scalar, or a 0-D Tensor of shape [] with input data \
type int32, int64, float32 or float64. type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \ num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \
It is an int scalar, or a Tensor of shape [1] with data type int32. It is an int scalar, or a 0-D Tensor of shape [] with data type int32.
base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \ base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \
It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \ It is a scalar, or a 0-D Tensor of shape [] with input data type int32, int64, \
float32 or float64. float32 or float64.
dtype(np.dtype|str, optional): The data type of output tensor, it could be \ dtype(np.dtype|str, optional): The data type of output tensor, it could be \
int32, int64, float32 or float64. Default: if None, the data type is float32. \ int32, int64, float32 or float64. Default: if None, the data type is float32. \
......
...@@ -1615,7 +1615,7 @@ def count_nonzero(x, axis=None, keepdim=False, name=None): ...@@ -1615,7 +1615,7 @@ def count_nonzero(x, axis=None, keepdim=False, name=None):
# x is a 2-D Tensor: # x is a 2-D Tensor:
x = paddle.to_tensor([[0., 1.1, 1.2], [0., 0., 1.3], [0., 0., 0.]]) x = paddle.to_tensor([[0., 1.1, 1.2], [0., 0., 1.3], [0., 0., 0.]])
out1 = paddle.count_nonzero(x) out1 = paddle.count_nonzero(x)
# [3] # 3
out2 = paddle.count_nonzero(x, axis=0) out2 = paddle.count_nonzero(x, axis=0)
# [0, 1, 2] # [0, 1, 2]
out3 = paddle.count_nonzero(x, axis=0, keepdim=True) out3 = paddle.count_nonzero(x, axis=0, keepdim=True)
...@@ -1636,17 +1636,8 @@ def count_nonzero(x, axis=None, keepdim=False, name=None): ...@@ -1636,17 +1636,8 @@ def count_nonzero(x, axis=None, keepdim=False, name=None):
# [1, 3, 5] # [1, 3, 5]
""" """
if axis is not None: if isinstance(axis, int):
if isinstance(axis, int): axis = [axis]
axis = [axis]
dims = len(x.shape)
for i in range(len(axis)):
if not isinstance(axis[i], int) or not (
axis[i] < dims and axis[i] >= -dims
):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
bool_tensor = paddle.cast(x, 'bool') bool_tensor = paddle.cast(x, 'bool')
int_tensor = paddle.cast(bool_tensor, 'int64') int_tensor = paddle.cast(bool_tensor, 'int64')
......
...@@ -255,7 +255,7 @@ def numel(x, name=None): ...@@ -255,7 +255,7 @@ def numel(x, name=None):
return out return out
def nanmedian(x, axis=None, keepdim=True, name=None): def nanmedian(x, axis=None, keepdim=False, name=None):
r""" r"""
Compute the median along the specified axis, while ignoring NaNs. Compute the median along the specified axis, while ignoring NaNs.
...@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
in the output Tensor. If ``keepdim`` is True, the dimensions of in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is True. the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None). name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
...@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]]) x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
y1 = x.nanmedian() y1 = x.nanmedian()
# y1 is [[2.]] # y1 is 2.
y2 = x.nanmedian(0) y2 = x.nanmedian(0)
# y2 is [[0., 1.5, 2.5]] # y2 is [0., 1.5, 2.5]
y3 = x.nanmedian(0, keepdim=False) y3 = x.nanmedian(0, keepdim=True)
# y3 is [0., 1.5, 2.5] # y3 is [[0., 1.5, 2.5]]
y4 = x.nanmedian((0, 1)) y4 = x.nanmedian((0, 1))
# y4 is [[2.]] # y4 is 2.
""" """
if not isinstance(x, Variable): if not isinstance(x, Variable):
raise TypeError("In median, the input x should be a Tensor.") raise TypeError("In median, the input x should be a Tensor.")
...@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
if isinstance(axis, (list, tuple)) and len(axis) == 0: if isinstance(axis, (list, tuple)) and len(axis) == 0:
raise ValueError("Axis list should not be empty.") raise ValueError("Axis list should not be empty.")
dims = len(x.shape)
if axis is None: if axis is None:
axis = [] axis = []
elif isinstance(axis, tuple): elif isinstance(axis, tuple):
...@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None): ...@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
if not isinstance(axis, list):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
for i in range(len(axis)):
if not isinstance(axis[i], int) or not (
axis[i] < dims and axis[i] >= -dims
):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if axis[i] < 0:
axis[i] += dims
if len(axis) != len(set(axis)):
raise ValueError("Axis has duplicated elements.")
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.nanmedian(x, axis, keepdim) return _C_ops.nanmedian(x, axis, keepdim)
else: else:
......
...@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase): ...@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase):
pd_res = paddle.nanmedian( pd_res = paddle.nanmedian(
paddle.to_tensor(data), keepdim=keep_dim paddle.to_tensor(data), keepdim=keep_dim
) )
assert np_res.shape == pd_res.numpy().shape
np.testing.assert_allclose( np.testing.assert_allclose(
np_res, pd_res.numpy(), rtol=1e-05, equal_nan=True np_res, pd_res.numpy(), rtol=1e-05, equal_nan=True
) )
...@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase): ...@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase):
x_np[0, :] = np.nan x_np[0, :] = np.nan
x_np[1, :3] = np.nan x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan x_np[2, 3:] = np.nan
x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, keepdim=True)
dx = paddle.grad(y, x_tensor)[0].numpy()
np_grad = np.zeros(shape)
np_grad[1, 3] = 0.5
np_grad[3, 2] = 0.5
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_axis(self):
paddle.disable_static(place=self.place)
shape = (4, 5)
x_np = np.random.uniform(-1, 1, shape).astype(np.float64)
x_np[0, :] = np.nan
x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan
x_np_sorted = np.sort(x_np) x_np_sorted = np.sort(x_np)
nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1) nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1)
np_grad = np.zeros(shape) np_grad = np.zeros(shape)
...@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase): ...@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase):
np_grad[i, j] = 1 if is_odd else 0.5 np_grad[i, j] = 1 if is_odd else 0.5
x_tensor = paddle.to_tensor(x_np, stop_gradient=False) x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, axis=1, keepdim=True) y = paddle.nanmedian(x_tensor, axis=1)
dx = paddle.grad(y, x_tensor)[0].numpy() dx = paddle.grad(y, x_tensor)[0].numpy()
np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True) np.testing.assert_allclose(np_grad, dx, rtol=1e-05, equal_nan=True)
def test_check_grad_0d(self):
paddle.disable_static(place=self.place)
x = paddle.rand([])
x.stop_gradient = False
y = paddle.nanmedian(x)
y.backward()
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))
x = paddle.to_tensor(float('nan'), stop_gradient=False)
y = paddle.nanmedian(x)
y.backward()
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(0.0))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册