Unverified  Commit f87fa3c0  Authored by: T thunder95  Committed by: GitHub

[PaddlePaddle Hackathon 2] No.15 Add new API Nanmedian (#42385)

* nanmedian op

* Fix a bug in the CUDA kernel

* Fix count_if incompatibility on other hardware platforms

* Fix incompatibility on some CPU hardware

* Fix incompatibility on some CPU hardware

* Fix the isnan check

* Handle older NumPy versions that do not support all-NaN input

* Handle older NumPy versions that do not support all-NaN input

* fix code example

* fix api comment error

* Revise the backward-propagation logic and the C++ handling logic

* Address review suggestions

* typo pre_dim

* update en docs, test=document_fix

* remove numpy in en doc, test=document_fix

* add r,test=document_fix

* Add the API to __all__

* follow advice from chenwhql
Parent 5df92262
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class NanmedianOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), "
"the input feature data of NanmedianOp, dtype should be "
"int32, int64, float16, float32 or float64.");
AddOutput(
"MedianIndex",
"Stores the index positions of the median values. The calculation "
"differs depending on whether the number of valid elements is odd "
"or even. Along the axis, two elements contribute to the median "
"value of each row; if the number of valid elements is odd, both "
"indices are the same.")
.AsIntermediate()
.AsExtra();
AddOutput("Out",
"(Tensor),"
" the output of NanmedianOp, whose dtype is the same as X");
AddAttr<bool>("keepdim",
"(bool, default true) "
"If true, retain the reduced axis with length 1.")
.SetDefault(true);
AddAttr<std::vector<int>>("axis",
"(std::vector<int>). List of integers,"
" indicating the dimensions to calculate medians")
.SetDefault({});
AddComment(R"DOC(
Nanmedian operator
This operator is an extension of the median operation and specifically
supports NaN values in the input.
If all elements of the input are NaN, it returns NaN as well.
If no element of the input is NaN, this op is identical to the median op.
If the count of valid elements is an even number, the average of the two
middle elements is returned as the median.
This operator also supports computation along multiple axes.
)DOC");
}
};
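// Illustrative example of the semantics documented above (hypothetical values):
// for a row {NaN, 1.0, 3.0, 4.0} the valid elements are {1.0, 3.0, 4.0}; the
// count is odd, so the median is 3.0 and MedianIndex stores the same position
// twice. For a row {NaN, 1.0, 2.0, 3.0, 4.0} the valid elements are
// {1.0, 2.0, 3.0, 4.0}; the count is even, so the median is
// (2.0 + 3.0) / 2 = 2.5 and MedianIndex stores the positions of 2.0 and 3.0.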
template <typename T>
class NanmedianGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("nanmedian_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("MedianIndex", this->Output("MedianIndex"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
class NanmedianGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor,
PD_INFER_META(phi::NanmedianInferMeta));
REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker,
ops::NanmedianGradMaker<paddle::framework::OpDesc>,
ops::NanmedianGradMaker<paddle::imperative::OpBase>,
NanmedianInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, NanmedianGradInferShapeFunctor,
PD_INFER_META(phi::NanmedianGradInferMeta));
REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp,
NanmedianGradInferShapeFunctor);
@@ -433,6 +433,17 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
}
}
void NanmedianGradInferMeta(const MetaTensor& x,
const MetaTensor& median_index,
const MetaTensor& out_grad,
const IntArray& axes,
bool keep_dim,
MetaTensor* x_grad) {
auto x_dims = x.dims();
x_grad->set_dims(x_dims);
x_grad->set_dtype(x.dtype());
}
void NllLossGradInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& weight,
......
@@ -191,6 +191,13 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
const MetaTensor& out_grad,
std::vector<MetaTensor*> ins_grad);
void NanmedianGradInferMeta(const MetaTensor& x,
const MetaTensor& median_index,
const MetaTensor& out_grad,
const IntArray& axes,
bool keep_dim,
MetaTensor* x_grad);
void NllLossGradInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
......
@@ -1246,6 +1246,65 @@ void MultinomialInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT64);
}
void NanmedianInferMeta(const MetaTensor& x,
const IntArray& axes,
bool keep_dim,
MetaTensor* out,
MetaTensor* median_index) {
std::vector<int64_t> axis_list = axes.GetData();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
out->set_dtype(x.dtype());
median_index->set_dtype(DataType::INT64);
median_index->set_dims(make_ddim({x.numel() * 2}));
std::vector<int32_t> out_dim;
if (axis_list.empty()) {
if (keep_dim) {
for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1);
}
} else {
out_dim.push_back(1);
}
} else {
std::vector<int64_t> cleaned_axis;
for (auto& axis : axis_list) {
if (axis < 0) axis += x_rank;
PADDLE_ENFORCE_LT(
axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim));
PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis),
cleaned_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis)));
cleaned_axis.push_back(axis);
}
for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) ==
cleaned_axis.end()) {
out_dim.push_back(x_dim[i]);
} else if (keep_dim) {
out_dim.push_back(1);
}
}
}
out->set_dims(make_ddim(out_dim));
}
void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
......
@@ -178,6 +178,13 @@ void MultinomialInferMeta(const MetaTensor& x,
int num_samples,
bool replacement,
MetaTensor* out);
void NanmedianInferMeta(const MetaTensor& x,
const IntArray& axes,
bool keep_dim,
MetaTensor* out,
MetaTensor* median_index);
void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad,
T* x_grad_ptr) {
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
if (!x_grad_ptr) return;
const int64_t* m_ptr = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>();
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t rank = x_dim.size();
int64_t stride = x_dim[rank - 1];
int64_t pre_dim = numel / stride;
int64_t i = 0;
int64_t offset = 0;
T div_factor = static_cast<T>(2.0);
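// MedianIndex holds two positions per reduced row. If both positions are equal
// (odd number of valid elements), the whole output gradient flows to that
// single element; otherwise each of the two middle elements receives half of
// it. A leading index of -1 marks an all-NaN row, which receives no gradient.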
for (i = 0; i < pre_dim; i++) {
if (m_ptr[2 * i] >= 0) {
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i];
} else {
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor;
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor;
}
}
offset += stride;
}
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
}
}
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim,
DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>(
dev_ctx, input, median_index, out_grad, axes, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(nanmedian_grad,
CPU,
ALL_LAYOUT,
phi::NanmedianGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nanmedian_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace phi {
template <typename T, typename Context>
void CalcMedianFunc(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& nan_counts,
bool ignore_nan,
int64_t sort_k,
int64_t stride,
int64_t pre_dim,
T* o_ptr,
int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
DenseTensor sort_out;
DenseTensor sort_indices;
auto sort_dim = x.dims();
int64_t rank = sort_dim.size();
sort_dim[rank - 1] = sort_k;
sort_out.Resize(sort_dim);
sort_indices.Resize(sort_dim);
dev_ctx.template Alloc<T>(&sort_out);
T* sort_out_ptr = sort_out.data<T>();
dev_ctx.template Alloc<int64_t>(&sort_indices);
int64_t* sort_indices_ptr = sort_indices.data<int64_t>();
TopkKernel<T, Context>(
dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices);
T div_factor = static_cast<T>(2.0);
int64_t offset = 0;
int64_t i = 0;
bool is_ori_odd = stride & 1;
if (should_ignore_nan) {
for (i = 0; i < pre_dim; i++) {
offset = i * sort_k;
if (nan_counts[i] == stride) {
m_ptr[i * 2] = -1;
m_ptr[i * 2 + 1] = -1;
o_ptr[i] = sort_out_ptr[offset];
} else {
int64_t nan_k = nan_counts[i] > 0
? static_cast<int64_t>(stride - nan_counts[i])
: sort_k;
int64_t row_pos = static_cast<int64_t>(nan_k >> 1);
int64_t pos = offset + row_pos;
if (nan_k & 1) {
m_ptr[2 * i] = sort_indices_ptr[pos];
m_ptr[2 * i + 1] = sort_indices_ptr[pos];
o_ptr[i] = sort_out_ptr[pos];
} else {
m_ptr[2 * i] =
row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
m_ptr[2 * i + 1] = sort_indices_ptr[pos];
T m_val_left =
row_pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
T m_val_right = sort_out_ptr[pos];
o_ptr[i] = (m_val_left + m_val_right) / div_factor;
}
}
}
} else {
if (is_ori_odd) {
for (i = 0; i < pre_dim; i++) {
offset = i * sort_k;
int64_t pos = offset + sort_k - 1;
o_ptr[i] = sort_out_ptr[pos];
m_ptr[2 * i] = sort_indices_ptr[pos];
m_ptr[2 * i + 1] = sort_indices_ptr[pos];
}
} else {
for (i = 0; i < pre_dim; i++) {
offset = i * sort_k;
int64_t pos = offset + sort_k - 1;
m_ptr[2 * i] =
sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
m_ptr[2 * i + 1] = sort_indices_ptr[pos];
T m_val_left = sort_k > 1 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
T m_val_right = sort_out_ptr[pos];
o_ptr[i] = (m_val_left + m_val_right) / div_factor;
}
}
}
}
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
T* o_ptr,
int64_t* m_ptr,
bool ignore_nan) {
bool should_ignore_nan = ignore_nan;
const T* x_ptr = x.data<T>();
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride;
int64_t i = 0;
int64_t max_valid_num = 0;
std::vector<int64_t> nan_counts;
if (should_ignore_nan) {
int64_t total_nan_num = 0;
std::vector<T> col_vec;
col_vec.reserve(stride);
col_vec.resize(stride);
nan_counts.clear();
nan_counts.reserve(pre_dim);
nan_counts.resize(pre_dim);
for (int64_t i = 0; i < pre_dim; i++) {
col_vec.clear();
col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride);
nan_counts[i] =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val));
});
total_nan_num += nan_counts[i];
if (stride - nan_counts[i] > max_valid_num)
max_valid_num = stride - nan_counts[i];
}
// all elements are NaN
if (total_nan_num == numel) {
for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0];
m_ptr[2 * i] = -1;
m_ptr[2 * i + 1] = -1;
}
return;
}
should_ignore_nan = total_nan_num > 0;
}
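// With NaNs present, the top-k (smallest-first) slice must cover the largest
// count of valid elements found in any row so that every row's median lies
// inside it; without NaNs, only the first (stride >> 1) + 1 smallest elements
// are needed to locate the median.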
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
CalcMedianFunc<T, Context>(dev_ctx,
x,
nan_counts,
should_ignore_nan,
sort_k,
stride,
pre_dim,
o_ptr,
m_ptr);
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
DenseTensor* out,
DenseTensor* median_index,
bool ignore_nan) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
out->Resize(out->dims());
}
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
bool keepdim,
DenseTensor* out,
DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true);
}
} // namespace phi
PD_REGISTER_KERNEL(nanmedian,
CPU,
ALL_LAYOUT,
phi::NanmedianKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
template <typename T>
__global__ void KernelNanmedianGrad(const T* x_ptr,
const int64_t* medians_ptr,
const T* out_grad_ptr,
T* x_grad_ptr,
int64_t stride,
int64_t pre_dim,
T div_factor) {
CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t offset = index * stride;
if (medians_ptr[2 * index] >= 0) {
if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) {
x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index];
} else {
x_grad_ptr[offset + medians_ptr[2 * index]] =
out_grad_ptr[index] / div_factor;
x_grad_ptr[offset + medians_ptr[2 * index + 1]] =
out_grad_ptr[index] / div_factor;
}
}
}
}
template <typename T, typename Context>
void CalcMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
DenseTensor* x_grad,
T* x_grad_ptr) {
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
auto stream = dev_ctx.stream();
const T* x_ptr = x.data<T>();
const int64_t* m_ptr = median_index.data<int64_t>();
const T* out_grad_ptr = out_grad.data<T>();
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride;
T div_factor = static_cast<T>(2.0);
KernelNanmedianGrad<
T><<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor);
}
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
DenseTensor* x_grad) {
auto rank = x.dims().size();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (axes.size() && (rank > 1)) {
DenseTensor tmp_x_grad(*x_grad);
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr);
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
} else {
CalcMedianGradKernel<T, Context>(
dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr);
}
}
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim,
DenseTensor* x_grad) {
BaseMedianGradKernel<T, Context>(
dev_ctx, input, median_index, out_grad, axes, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(nanmedian_grad,
GPU,
ALL_LAYOUT,
phi::NanmedianGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
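// Counts NaN values per reduced row: each block accumulates per-row counts in
// shared memory, then atomically merges them into nan_counts; nan_total[0]
// collects the total number of NaNs and nan_total[1] the maximum count of
// valid elements over all rows.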
template <typename T>
__global__ void KernelNanCounts(const T* input,
const int numel,
const int64_t pre_dim,
const int64_t stride,
T min_val,
int64_t* nan_total,
int64_t* nan_counts) {
extern __shared__ int64_t buf[];
for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) {
buf[i] = 0;
nan_counts[i] = 0;
}
if (threadIdx.x == 0) {
nan_total[0] = 0;
nan_total[1] = 0;
}
__syncthreads();
CUDA_KERNEL_LOOP(index, numel) {
const T x = input[index];
if (isnan(static_cast<float>(x))) {
auto bin = static_cast<int64_t>(index / stride);
paddle::platform::CudaAtomicAdd(&buf[bin], 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) {
paddle::platform::CudaAtomicAdd(&nan_counts[i], buf[i]);
paddle::platform::CudaAtomicAdd(&nan_total[0], buf[i]);
paddle::platform::CudaAtomicMax(&nan_total[1], stride - buf[i]);
}
}
template <typename T>
__global__ void CalcMedianKernel(const T* sort_out_ptr,
const int64_t* sort_indices_ptr,
int64_t* median_val,
T* output,
T div_factor,
const bool is_odd,
const int64_t pre_dim,
const int64_t stride) {
CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t pos = static_cast<int64_t>((index + 1) * stride) - 1;
if (is_odd) {
median_val[index * 2] = sort_indices_ptr[pos];
median_val[index * 2 + 1] = sort_indices_ptr[pos];
output[index] = sort_out_ptr[pos];
} else {
median_val[index * 2] =
pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
median_val[index * 2 + 1] = sort_indices_ptr[pos];
T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
T median_val_right = sort_out_ptr[pos];
output[index] = (median_val_left + median_val_right) / div_factor;
}
}
}
template <typename T>
__global__ void CalcNanmedianKernel(const T* sort_out_ptr,
const int64_t* sort_indices_ptr,
int64_t* nan_counts,
int64_t* median_val,
T* output,
const bool is_odd,
const int64_t pre_dim,
const int64_t max_valid_num,
const int64_t stride,
const T div_factor,
const T nan_val) {
CUDA_KERNEL_LOOP(index, pre_dim) {
int64_t pos = static_cast<int64_t>(index * max_valid_num);
int64_t nan_cnt = nan_counts[index];
if (nan_cnt == stride) {
median_val[index * 2] = -1;
median_val[index * 2 + 1] = -1;
output[index] = nan_val;
} else {
int64_t nan_k =
nan_cnt > 0 ? static_cast<int64_t>(stride - nan_cnt) : max_valid_num;
int64_t row_pos = static_cast<int64_t>(nan_k >> 1);
pos += row_pos;
if (nan_k & 1) {
median_val[index * 2] = sort_indices_ptr[pos];
median_val[index * 2 + 1] = sort_indices_ptr[pos];
output[index] = sort_out_ptr[pos];
} else {
median_val[index * 2] =
pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
median_val[index * 2 + 1] = sort_indices_ptr[pos];
T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
T median_val_right = sort_out_ptr[pos];
output[index] = (median_val_left + median_val_right) / div_factor;
}
}
}
}
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
bool ignore_nan,
DenseTensor* out,
int64_t* m_ptr) {
bool should_ignore_nan = ignore_nan;
auto stream = dev_ctx.stream();
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride;
int64_t i = 0;
DenseTensor nan_counts, nan_stat;
int64_t* nan_counts_ptr;
int64_t max_valid_num = 0;
if (should_ignore_nan) {
nan_counts.Resize(phi::make_ddim({pre_dim}));
dev_ctx.template Alloc<int64_t>(&nan_counts);
nan_counts_ptr = nan_counts.data<int64_t>();
nan_stat.Resize(phi::make_ddim({2}));
int64_t* nan_stat_mem = dev_ctx.template Alloc<int64_t>(&nan_stat);
int64_t* nan_stat_ptr = nan_stat.data<int64_t>();
KernelNanCounts<T><<<GET_BLOCKS(numel),
PADDLE_CUDA_NUM_THREADS,
pre_dim * sizeof(int64_t),
stream>>>(x_ptr,
numel,
pre_dim,
stride,
std::numeric_limits<T>::min(),
nan_stat_ptr,
nan_counts_ptr);
auto nan_stat_mem_cpu =
paddle::memory::Alloc(phi::CPUPlace(), sizeof(int64_t) * 2);
int64_t* nan_stat_cpu_ptr =
reinterpret_cast<int64_t*>(nan_stat_mem_cpu->ptr());
paddle::memory::Copy(phi::CPUPlace(),
nan_stat_cpu_ptr,
dev_ctx.GetPlace(),
nan_stat_mem,
sizeof(int64_t) * 2,
stream);
// all elements are NaN values
T nan_val = std::numeric_limits<T>::quiet_NaN();
if (nan_stat_cpu_ptr[0] == numel) {
FullLikeKernel<T, Context>(dev_ctx, x, nan_val, x.dtype(), out);
return;
}
should_ignore_nan = nan_stat_cpu_ptr[0] > 0;
max_valid_num = nan_stat_cpu_ptr[1];
}
int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
bool is_ori_odd = stride & 1;
DenseTensor sort_out, sort_indices;
auto sort_dim = x.dims();
int64_t rank = sort_dim.size();
sort_dim[rank - 1] = sort_k;
sort_out.Resize(sort_dim);
sort_indices.Resize(sort_dim);
dev_ctx.template Alloc<T>(&sort_out);
T* sort_out_ptr = sort_out.data<T>();
dev_ctx.template Alloc<int64_t>(&sort_indices);
int64_t* sort_indices_ptr = sort_indices.data<int64_t>();
TopkKernel<T, Context>(
dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices);
T div_factor = static_cast<T>(2.0);
T nan_val = std::numeric_limits<T>::quiet_NaN();
if (should_ignore_nan) {
CalcNanmedianKernel<
T><<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr,
sort_indices_ptr,
nan_counts_ptr,
m_ptr,
o_ptr,
is_ori_odd,
pre_dim,
max_valid_num,
stride,
div_factor,
nan_val);
} else {
CalcMedianKernel<
T><<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
sort_out_ptr,
sort_indices_ptr,
m_ptr,
o_ptr,
div_factor,
is_ori_odd,
pre_dim,
sort_k);
}
}
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& axes,
bool ignore_nan,
DenseTensor* out,
DenseTensor* median_index) {
DenseTensor x;
auto rank = input.dims().size();
if ((axes.size() == 0) || rank <= 1) {
x = input;
x.Resize({input.numel()});
} else {
PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
}
int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
ProcessMedianKernel<T, Context>(dev_ctx, x, ignore_nan, out, m_ptr);
out->Resize(out->dims());
}
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
bool keepdim,
DenseTensor* out,
DenseTensor* median_index) {
BaseMedianKernel<T, Context>(dev_ctx, x, axes, true, out, median_index);
}
} // namespace phi
PD_REGISTER_KERNEL(nanmedian,
GPU,
ALL_LAYOUT,
phi::NanmedianKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
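// Reverses the axis permutation applied by PreprocessMedianKernel: the gradient
// computed on the transposed, flattened layout is reshaped back to the permuted
// dims and then transposed into the original layout of x.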
template <typename T, typename Context>
void PostprocessMedianGradKernel(const Context& dev_ctx,
DenseTensor* input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input->dims();
auto rank = input_dim.size();
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
std::vector<int> trans_back;
std::vector<int> reshape_back;
trans_back.reserve(rank);
trans_back.resize(rank);
int offset = 0;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
reshape_back.push_back(input_dim[i]);
trans_back[i] = offset;
offset += 1;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
trans_back[i] = offset;
reshape_back.push_back(input_dim[i]);
offset += 1;
}
}
input->Resize(make_ddim(reshape_back));
funcs::TransCompute<Context, T>(
static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
}
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& median_index,
const DenseTensor& out_grad,
const IntArray& axes,
bool keep_dim,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
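// Moves the reduction axes to the end: the input is transposed so that the
// requested axes become the trailing dimensions, which are then flattened into
// a single last dimension so the median can always be taken along axis -1.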
template <typename T, typename Context>
void PreprocessMedianKernel(const Context& dev_ctx,
const DenseTensor& input,
const IntArray& raw_axes,
DenseTensor* x) {
auto input_dim = input.dims();
auto rank = input_dim.size();
std::vector<int> perm;
std::vector<int64_t> reshape;
std::vector<int64_t> axes = raw_axes.GetData();
int64_t axes_size = static_cast<int>(axes.size());
for (int64_t i = 0; i < axes_size; i++) {
if (axes[i] < 0) {
axes[i] += rank;
}
}
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
perm.push_back(i);
reshape.push_back(input_dim[i]);
}
}
int64_t post_numel = 1;
for (int64_t i = 0; i < rank; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
perm.push_back(i);
post_numel *= input_dim[i];
}
}
reshape.push_back(post_numel);
DDim trans_dim(input_dim);
int ndims = perm.size();
for (int i = 0; i < ndims; i++) {
trans_dim[i] = input_dim[perm[i]];
}
x->Resize(trans_dim);
dev_ctx.template Alloc<T>(x);
funcs::TransCompute<Context, T>(ndims, dev_ctx, input, x, perm);
x->Resize(make_ddim(reshape));
}
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
bool keep_dim,
DenseTensor* out,
DenseTensor* medians);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"nanmedian", {"X"}, {"axis", "keepdim"}, {"Out", "MedianIndex"});
}
KernelSignature NanmedianGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nanmedian_grad",
{"X", "MedianIndex", "Out@GRAD"},
{"axis", "keepdim"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(nanmedian, phi::NanmedianOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nanmedian_grad, phi::NanmedianGradOpArgumentMapping);
@@ -331,6 +331,7 @@ from .tensor.stat import std # noqa: F401
from .tensor.stat import var # noqa: F401
from .tensor.stat import numel # noqa: F401
from .tensor.stat import median # noqa: F401
from .tensor.stat import nanmedian # noqa: F401
from .tensor.stat import quantile # noqa: F401
from .tensor.stat import nanquantile # noqa: F401
from .device import get_cudnn_version # noqa: F401
@@ -498,6 +499,7 @@ __all__ = [ # noqa
'load',
'numel',
'median',
'nanmedian',
'quantile',
'nanquantile',
'no_grad',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
np.random.seed(102)
class TestNanmedian(unittest.TestCase):
def setUp(self):
single_axis_shape = (120)
multi_axis_shape = (2, 3, 4, 5)
self.fake_data = {
"single_axis_normal":
np.random.uniform(-1, 1, single_axis_shape).astype(np.float32),
"multi_axis_normal":
np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32),
"single_axis_all_nan": np.full(single_axis_shape, np.nan),
"multi_axis_all_nan": np.full(multi_axis_shape, np.nan),
}
single_partial_nan = self.fake_data["single_axis_normal"].copy()
single_partial_nan[single_partial_nan > 0] = np.nan
multi_partial_nan = self.fake_data["multi_axis_normal"].copy()
multi_partial_nan[multi_partial_nan > 0] = np.nan
self.fake_data["single_axis_partial_nan"] = single_partial_nan
self.fake_data["multi_axis_partial_nan"] = multi_partial_nan
row_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32)
row_data[:, :, :, 0] = np.nan
row_data[:, :, :2, 1] = np.nan
row_data[:, :, 2:, 2] = np.nan
self.fake_data["row_nan_even"] = row_data
self.fake_data["row_nan_float64"] = row_data.astype(np.float64)
self.fake_data["row_nan_int64"] = row_data.astype(np.int64)
self.fake_data["row_nan_int32"] = row_data.astype(np.int32)
col_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32)
col_data[:, :, 0, :] = np.nan
col_data[:, :, 1, :3] = np.nan
col_data[:, :, 2, 3:] = np.nan
self.fake_data["col_nan_odd"] = col_data
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.axis_candidate_list = [
None, 0, 2, -1, -2, (1, 2), [0, -1], [0, 1, 3], (1, 2, 3),
[0, 2, 1, 3]
]
def test_api_static(self):
data = self.fake_data["col_nan_odd"]
paddle.enable_static()
np_res = np.nanmedian(data, keepdims=True)
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', data.shape)
out1 = paddle.nanmedian(x, keepdim=True)
out2 = paddle.tensor.nanmedian(x, keepdim=True)
out3 = paddle.tensor.stat.nanmedian(x, keepdim=True)
axis = np.arange(len(data.shape)).tolist()
out4 = paddle.nanmedian(x, axis=axis, keepdim=True)
out5 = paddle.nanmedian(x, axis=tuple(axis), keepdim=True)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': data},
fetch_list=[out1, out2, out3, out4, out5])
for out in res:
self.assertTrue(np.allclose(np_res, out, equal_nan=True))
def test_api_dygraph(self):
paddle.disable_static(self.place)
def clean_axis_numpy(axis, shape_len):
if isinstance(axis, tuple):
axis = list(axis)
if isinstance(axis, list):
for k in range(len(axis)):
if axis[k] < 0:
axis[k] += shape_len
axis = set(axis)
return axis
def test_data_case(data):
for keep_dim in [False, True]:
if np.isnan(data).all() and keep_dim:
np_ver = np.version.version.split('.')
if int(np_ver[0]) < 1 or int(np_ver[1]) <= 20:
print(
"This NumPy version does not support an all-NaN input when keepdims is True"
)
continue
np_res = np.nanmedian(data, keepdims=keep_dim)
pd_res = paddle.nanmedian(
paddle.to_tensor(data), keepdim=keep_dim)
self.assertTrue(
np.allclose(
np_res, pd_res.numpy(), equal_nan=True))
def test_axis_case(data, axis):
pd_res = paddle.nanmedian(
paddle.to_tensor(data), axis=axis, keepdim=False)
axis = clean_axis_numpy(axis, len(data.shape))
np_res = np.nanmedian(data, axis=axis, keepdims=False)
self.assertTrue(np.allclose(np_res, pd_res.numpy(), equal_nan=True))
for name, data in self.fake_data.items():
test_data_case(data)
for axis in self.axis_candidate_list:
test_axis_case(self.fake_data["row_nan_even"], axis)
test_axis_case(self.fake_data["col_nan_odd"], axis)
paddle.enable_static()
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data("X", [10, 12])
def test_dtype():
x2 = paddle.fluid.data('X2', [10, 12], 'bool')
paddle.nanmedian(x2)
def test_empty_axis():
paddle.nanmedian(x, axis=[], keepdim=True)
def test_axis_not_in_range():
paddle.nanmedian(x, axis=3, keepdim=True)
def test_duplicated_axis():
paddle.nanmedian(x, axis=[1, -1], keepdim=True)
self.assertRaises(TypeError, test_dtype)
self.assertRaises(ValueError, test_empty_axis)
self.assertRaises(ValueError, test_axis_not_in_range)
self.assertRaises(ValueError, test_duplicated_axis)
def test_dygraph(self):
paddle.disable_static(place=self.place)
with paddle.fluid.dygraph.guard():
data = self.fake_data["col_nan_odd"]
out = paddle.nanmedian(paddle.to_tensor(data), keepdim=True)
np_res = np.nanmedian(data, keepdims=True)
self.assertTrue(np.allclose(np_res, out, equal_nan=True))
paddle.enable_static()
def test_check_grad(self):
paddle.disable_static(place=self.place)
shape = (4, 5)
x_np = np.random.uniform(-1, 1, shape).astype(np.float64)
x_np[0, :] = np.nan
x_np[1, :3] = np.nan
x_np[2, 3:] = np.nan
x_np_sorted = np.sort(x_np)
nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1)
np_grad = np.zeros((shape))
for i in range(shape[0]):
valid_cnts = shape[1] - nan_counts[i]
if valid_cnts == 0:
continue
mid = int(valid_cnts / 2)
targets = [x_np_sorted[i, mid]]
is_odd = valid_cnts % 2
if not is_odd and mid > 0:
targets.append(x_np_sorted[i, mid - 1])
for j in range(shape[1]):
if x_np[i, j] in targets:
np_grad[i, j] = 1 if is_odd else 0.5
x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.nanmedian(x_tensor, axis=1, keepdim=True)
dx = paddle.grad(y, x_tensor)[0].numpy()
self.assertTrue(np.allclose(np_grad, dx, equal_nan=True))
if __name__ == "__main__":
unittest.main()
@@ -263,6 +263,7 @@ from .stat import std # noqa: F401
from .stat import var # noqa: F401
from .stat import numel # noqa: F401
from .stat import median # noqa: F401
from .stat import nanmedian # noqa: F401
from .stat import quantile # noqa: F401
from .stat import nanquantile # noqa: F401
@@ -448,6 +449,7 @@ tensor_method_func = [ #noqa
'var',
'numel',
'median',
'nanmedian',
'quantile',
'nanquantile',
'is_complex',
......
@@ -241,6 +241,103 @@ def numel(x, name=None):
return out
def nanmedian(x, axis=None, keepdim=True, name=None):
r"""
Compute the median along the specified axis, while ignoring NaNs.
If the count of valid (non-NaN) elements is even,
the average of the two middle elements is returned as the median.
Args:
x (Tensor): The input Tensor; its data type can be int32, int64, float16, float32 or float64.
axis (None|int|list|tuple, optional):
The axis along which to compute the median. ``axis`` should be an int or a list of int.
``axis`` should be in range [-D, D), where D is the number of dimensions of ``x`` .
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
If ``axis`` is None, the median is computed over all elements of ``x``. Default is None.
keepdim (bool, optional): Whether to keep the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor are the same as those of ``x`` except in the reduced
dimensions (each has size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed along ``axis`` . Default is True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the median values along ``axis`` of ``x``, ignoring NaNs. The output dtype is the same as ``x``.
Examples:
.. code-block:: python
:name: nanmedian-example
import paddle
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
y1 = x.nanmedian()
# y1 is [[2.]]
y2 = x.nanmedian(0)
# y2 is [[0., 1.5, 2.5]]
y3 = x.nanmedian(0, keepdim=False)
# y3 is [0., 1.5, 2.5]
y4 = x.nanmedian((0, 1))
# y4 is [[2.]]
"""
if not isinstance(x, Variable):
raise TypeError("In nanmedian, the input x should be a Tensor.")
if isinstance(axis, (list, tuple)) and len(axis) == 0:
raise ValueError("Axis list should not be empty.")
dims = len(x.shape)
if axis is None:
axis = []
elif isinstance(axis, tuple):
axis = list(axis)
elif isinstance(axis, int):
axis = [axis]
if not isinstance(axis, list):
raise ValueError(
"Axis should be None, an int, or a list of int; each element should be in range [-rank(x), rank(x))."
)
for i in range(len(axis)):
if not isinstance(axis[i], int) or not (axis[i] < dims and
axis[i] >= -dims):
raise ValueError(
"Axis should be None, an int, or a list of int; each element should be in range [-rank(x), rank(x))."
)
if axis[i] < 0:
axis[i] += dims
if len(axis) != len(set(axis)):
raise ValueError("Axis has duplicated elements.")
if _in_legacy_dygraph():
median_index, out = _C_ops.nanmedian(x, 'axis', axis, 'keepdim',
keepdim)
return out
check_variable_and_dtype(
x, 'X', ['int32', 'int64', 'float16', 'float32', 'float64'],
'nanmedian')
helper = LayerHelper('nanmedian', **locals())
attrs = {'axis': axis, 'keepdim': keepdim}
out = helper.create_variable_for_type_inference(x.dtype)
medians = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='nanmedian',
inputs={'X': x},
outputs={'Out': out,
'MedianIndex': medians},
attrs=attrs)
return out
def median(x, axis=None, keepdim=False, name=None):
"""
Compute the median along the specified axis.
......
@@ -824,7 +824,7 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_mean_op', 'test_is_tensor', 'test_run_program_op',
'test_cuda_random_seed', 'test_linear_interp_op',
'test_fuse_all_reduce_pass', 'tensor_util_test', 'test_median',
'test_linear', 'test_imperative_qat_amp',
'test_nanmedian', 'test_linear', 'test_imperative_qat_amp',
'test_truncated_gaussian_random_op', 'test_lstm_cudnn_op',
'copy_same_tensor_test', 'test_squeeze2_op',
'naive_best_fit_allocator_test', 'test_model', 'test_py_reader_combination',
......