diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 829c182c991d71e0a137437f9dc70fcf08358ca9..3b3271fc5b936e65b60930f43ea5c4f6f8448941 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -420,6 +420,61 @@ inline void Any(const framework::Tensor& tensor, Predicate predicate,
   platform::VisitPlace(place, visitor);
 }
 
+template <typename Predicate, typename DevCtx>
+struct AllDTypeVisitor {
+  Predicate predicate_;
+  const Tensor& tensor_;
+  const DevCtx& ctx_;
+  Tensor* out_;
+
+  AllDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
+                  Tensor* out)
+      : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
+
+  template <typename T>
+  void apply() const {
+    auto t = EigenVector<T>::Flatten(tensor_);
+    auto o = EigenVector<bool>::Flatten(*out_);
+    o.device(*ctx_.eigen_device()) = predicate_(t);
+  }
+};
+
+template <typename Predicate, typename DevCtx>
+inline void AllImpl(Predicate predicate, const framework::Tensor& tensor,
+                    const DevCtx& ctx, framework::Tensor* out) {
+  VisitDataType(tensor.type(), AllDTypeVisitor<Predicate, DevCtx>(
+                                   predicate, tensor, ctx, out));
+}
+
+template <typename Predicate>
+class AllOutVisitor : public boost::static_visitor<> {
+ private:
+  const framework::Tensor& tensor_;
+  mutable framework::Tensor* out_;
+  Predicate predicate_;
+
+ public:
+  AllOutVisitor(const framework::Tensor& tensor, Predicate predicate,
+                framework::Tensor* out)
+      : tensor_(tensor), out_(out), predicate_(predicate) {}
+
+  template <typename Place>
+  void operator()(const Place& place) const {
+    auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+    out_->Resize(tensor_.dims());
+    out_->mutable_data<bool>(place);
+    AllImpl(predicate_, tensor_, *ctx, out_);
+  }
+};
+
+template <typename Predicate>
+inline void All(const framework::Tensor& tensor, Predicate predicate,
+                framework::Tensor* out) {
+  AllOutVisitor<Predicate> visitor(tensor, predicate, out);
+  auto place = tensor.place();
+  platform::VisitPlace(place, visitor);
+}
+
 struct ContainsNANPredicate {
   template <typename T>
   auto operator()(const T& eigen_vec) const
@@ -440,6 +495,12 @@ void TensorContainsNAN(const framework::Tensor& tensor,
   Any(tensor, predicate, out);
 }
 
+void TensorContainsNANV2(const framework::Tensor& tensor,
+                         framework::Tensor* out) {
+  ContainsNANPredicate predicate;
+  All(tensor, predicate, out);
+}
+
 struct ContainsInfPredicate {
   template <typename T>
   auto operator()(const T& eigen_vec) const
@@ -460,6 +521,12 @@ void TensorContainsInf(const framework::Tensor& tensor,
   Any(tensor, predicate, out);
 }
 
+void TensorContainsInfV2(const framework::Tensor& tensor,
+                         framework::Tensor* out) {
+  ContainsInfPredicate predicate;
+  All(tensor, predicate, out);
+}
+
 // NOTE(dzhwinter):
 // Isfinite need a AllVisitor to loop through all the elements.
 // We choose two cuda call instead of one allvisitor. The AllVisitor
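For orientation (not part of the patch): the existing Any-based helpers reduce the whole tensor to a single bool, while the All-based V2 helpers above keep one bool per element; the finite check added further down combines the NaN and Inf masks with BothFalse. A minimal NumPy sketch of that element-wise behaviour — the helper names here are illustrative only:

import numpy as np

def contains_nan_v2(x):
    # one bool per element, same shape as x (mirrors TensorContainsNANV2)
    return np.isnan(x)

def contains_inf_v2(x):
    # one bool per element, same shape as x (mirrors TensorContainsInfV2)
    return np.isinf(x)

def isfinite_v2(x):
    # "both false": an element is finite iff it is neither NaN nor Inf
    # (mirrors TensorIsfiniteV2 + BothFalse)
    return ~contains_nan_v2(x) & ~contains_inf_v2(x)

x = np.array([1.0, float("inf"), float("nan"), -2.5])
print(contains_nan_v2(x))  # [False False  True False]
print(contains_inf_v2(x))  # [False  True False False]
print(isfinite_v2(x))      # [ True False False  True]
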
@@ -472,8 +539,8 @@ bool TensorIsfinite(const framework::Tensor& tensor) {
 
 #ifdef PADDLE_WITH_CUDA
 template <typename T>
-static inline void __global__ BothFalse(const T* cmp, T* out) {
-  out[0] = (!cmp[0]) && (!out[0]);
+static inline void __global__ BothFalse(const T* cmp, T* out, int element_num) {
+  CUDA_KERNEL_LOOP(i, element_num) { out[i] = (!cmp[i]) && (!out[i]); }
 }
 #endif
 
@@ -495,22 +562,40 @@ struct BothFalseVisitor : public boost::static_visitor<> {
   void VisitorImpl(const platform::CUDAPlace& gpu) const {
 #ifdef PADDLE_WITH_CUDA
     auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu);
-    BothFalse<bool><<<1, 1, 0, ctx->stream()>>>(in_.data<bool>(),
-                                                out_->mutable_data<bool>(gpu));
+    constexpr int MAX_BLOCK_DIM = 512;
+    const int MAX_GRID_DIM = ctx->GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
+    int element_num = in_.numel();
+    int block_size = (element_num >= MAX_BLOCK_DIM)
+                         ? MAX_BLOCK_DIM
+                         : (1 << static_cast<int>(std::log2(element_num)));
+    int grid_size = element_num / block_size;
+    grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
+    BothFalse<bool><<<grid_size, block_size, 0, ctx->stream()>>>(
+        in_.data<bool>(), out_->mutable_data<bool>(gpu), element_num);
 #endif
   }
 
   void VisitorImpl(const platform::CPUPlace& cpu) const {
-    bool lhs = !in_.data<bool>()[0];
-    bool rhs = !out_->mutable_data<bool>(cpu)[0];
-    out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
+    int num = in_.numel();
+    const bool* in_ptr = in_.data<bool>();
+    bool* out_ptr = out_->data<bool>();
+    for (int i = 0; i < num; ++i) {
+      bool lhs = !in_ptr[i];
+      bool rhs = !out_ptr[i];
+      out_ptr[i] = lhs && rhs;
+    }
   }
 
   void VisitorImpl(
       const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const {
-    bool lhs = !in_.data<bool>()[0];
-    bool rhs = !out_->mutable_data<bool>(cpu)[0];
-    out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
+    int num = in_.numel();
+    const bool* in_ptr = in_.data<bool>();
+    bool* out_ptr = out_->data<bool>();
+    for (int i = 0; i < num; ++i) {
+      bool lhs = !in_ptr[i];
+      bool rhs = !out_ptr[i];
+      out_ptr[i] = lhs && rhs;
+    }
   }
 };
 
@@ -523,6 +608,15 @@ void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) {
   platform::VisitPlace(place, visitor);
 }
 
+void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out) {
+  framework::Tensor tmp;
+  TensorContainsInfV2(tensor, &tmp);
+  TensorContainsNANV2(tensor, out);
+  BothFalseVisitor visitor(tmp, out);
+  auto place = tensor.place();
+  platform::VisitPlace(place, visitor);
+}
+
 void TensorToStream(std::ostream& os, const Tensor& tensor,
                     const platform::DeviceContext& dev_ctx) {
   {  // the 1st field, uint32_t version
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index c71327da64042aed85f1247f3c31de3e66a588ba..fce0142b41d3ae9b2a6fcd4f16d38b0492fbd806 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -76,6 +76,13 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
                       const platform::DeviceContext& dev_ctx,
                       const size_t& seek, const std::vector<int64_t>& shape);
 
+// store the bool result tensor in out tensor
+void TensorContainsNANV2(const framework::Tensor& tensor,
+                         framework::Tensor* out);
+void TensorContainsInfV2(const framework::Tensor& tensor,
+                         framework::Tensor* out);
+void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out);
+
 // convert dlpack's DLTensor to tensor
 void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);
 
diff --git a/paddle/fluid/operators/isfinite_v2_op.cc b/paddle/fluid/operators/isfinite_v2_op.cc
new file mode 100644
index
0000000000000000000000000000000000000000..72da43e3bc63c1c585fe19d703892c23ce7b0ec2 --- /dev/null +++ b/paddle/fluid/operators/isfinite_v2_op.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/isfinite_v2_op.h" +#include +#include +#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/platform/float16.h" + +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +class OverflowV2Op : public framework::OperatorWithKernel { + public: + OverflowV2Op(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "isfinitev2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "isfinitev2"); + UnaryOpUnchangedInferShape(ctx); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + int dtype = -1; + auto *x_var = ctx.InputVar("X"); + if (x_var->IsType()) { + dtype = x_var->Get().type(); + } else if (x_var->IsType()) { + dtype = x_var->Get().value().type(); + } else { + PADDLE_THROW(plat::errors::InvalidArgument( + "Cannot find the input data type by all input data")); + } + return framework::OpKernelType(framework::proto::VarType::Type(dtype), + ctx.GetPlace()); + } +}; + +class OverflowV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input tensors of overflowv2 operator."); + AddOutput("Out", + "(Tensor) The output tensor of overflowv2 operator. " + "Same size compare to input tensor"); + AddComment(string::Sprintf(R"DOC( +Overflow %s operator. + +$$Out = %s(X)$$ + +Check whether each element of X is Inf or Nan, return the bool result of each +element of X as a tensor. 
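Illustration only, not part of the patch: once the Python aliases added later in this change are in place, the element-wise behaviour the DOC above describes can be checked directly against NumPy. This assumes paddle.isinf, paddle.isnan and paddle.isfinite dispatch to these V2 kernels, as the rest of the patch wires up.

import numpy as np
import paddle

paddle.disable_static()
x_np = np.array([1.0, float("inf"), float("nan")], dtype="float32")
x = paddle.to_tensor(x_np)

# Each check returns a bool tensor with the same shape as X.
print(paddle.isinf(x).numpy())     # [False  True False], same as np.isinf(x_np)
print(paddle.isnan(x).numpy())     # [False False  True], same as np.isnan(x_np)
print(paddle.isfinite(x).numpy())  # [ True False False], same as np.isfinite(x_np)
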
+ +%s +)DOC", + GetName(), GetComments())); + } + + protected: + virtual std::string GetName() const = 0; + virtual std::string GetComments() const = 0; +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +#define REGISTER_V2OP_MAKER(op_type, comment) \ + namespace paddle { \ + namespace operators { \ + class _##op_type##OverflowV2OpMaker \ + : public ::paddle::operators::OverflowV2OpMaker { \ + protected: \ + std::string GetName() const { return #op_type; } \ + std::string GetComments() const { return comment; } \ + }; \ + } \ + } \ + REGISTER_OPERATOR( \ + op_type, ops::OverflowV2Op, ops::_##op_type##OverflowV2OpMaker, \ + paddle::framework::EmptyGradOpMaker, \ + paddle::framework::EmptyGradOpMaker) + +#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \ + REGISTER_OP_CPU_KERNEL( \ + op_type, ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel); + +REGISTER_V2OP_MAKER(isinf_v2, "isinfv2(X)"); +REGISTER_V2OP_MAKER(isnan_v2, "isnanv2(X)"); +REGISTER_V2OP_MAKER(isfinite_v2, "isfinitev2(X)"); + +REGISTER_OVERFLOW_CPU_KERNEL(isinf_v2, InfinityV2Functor); +REGISTER_OVERFLOW_CPU_KERNEL(isnan_v2, NANV2Functor); +REGISTER_OVERFLOW_CPU_KERNEL(isfinite_v2, IsfiniteV2Functor); diff --git a/paddle/fluid/operators/isfinite_v2_op.cu b/paddle/fluid/operators/isfinite_v2_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a6d818d0501e60dfffc8995075bb7f0369788fd --- /dev/null +++ b/paddle/fluid/operators/isfinite_v2_op.cu @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/isfinite_v2_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \ + REGISTER_OP_CUDA_KERNEL( \ + op_type, ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel); + +REGISTER_OVERFLOW_CUDA_KERNEL(isinf_v2, InfinityV2Functor); +REGISTER_OVERFLOW_CUDA_KERNEL(isnan_v2, NANV2Functor); +REGISTER_OVERFLOW_CUDA_KERNEL(isfinite_v2, IsfiniteV2Functor); diff --git a/paddle/fluid/operators/isfinite_v2_op.h b/paddle/fluid/operators/isfinite_v2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9f0aa63ce80248ee9f7839890f611b9d5293789e --- /dev/null +++ b/paddle/fluid/operators/isfinite_v2_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/isfinite_op.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +struct InfinityV2Functor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorContainsInfV2(tensor, out); + } +}; + +struct NANV2Functor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorContainsNANV2(tensor, out); + } +}; + +struct IsfiniteV2Functor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorIsfiniteV2(tensor, out); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/nll_loss_op.cc b/paddle/fluid/operators/nll_loss_op.cc index e99ccd31714787306358d9b19b31a62ff21d5dab..f0b5f4a466a0049c53d51d8610cf115d8bfe0295 100644 --- a/paddle/fluid/operators/nll_loss_op.cc +++ b/paddle/fluid/operators/nll_loss_op.cc @@ -55,8 +55,8 @@ class NLLLossOp : public framework::OperatorWithKernel { "Input(Weight) should be a 1D tensor.")); PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0], platform::errors::InvalidArgument( - "Input(Weight) Tensor's size should match" - "to the class numer.")); + "Input(Weight) Tensor's size should match " + "to the the total number of classes.")); } } if (x_dims.size() == 2) { diff --git a/paddle/fluid/operators/nll_loss_op.h b/paddle/fluid/operators/nll_loss_op.h index 92f3d169f3f6a3be1009d84ebd87c82691eb9f0c..e93d5792205900635093e5f18d715e4607f73cda 100644 --- a/paddle/fluid/operators/nll_loss_op.h +++ b/paddle/fluid/operators/nll_loss_op.h @@ -91,7 +91,7 @@ static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data, } PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, platform::errors::InvalidArgument( - "label should nor be out of bounds.")); + "label should not be out of bounds.")); const auto cur_weight = weight_data ? weight_data[cur_label] : static_cast(1); out_data[index] = -x_data[i * sample_size + cur_label * map_size + @@ -117,7 +117,7 @@ static void nll_loss_2D(T* out_data, T* total_weight_data, const T* x_data, } PADDLE_ENFORCE_EQ(cur_label >= 0 && cur_label < n_classes, true, platform::errors::InvalidArgument( - "label should nor be out of bounds.")); + "label should not be out of bounds.")); const auto cur_weight = weight_data ? 
weight_data[cur_label] : static_cast(1); total_weight_val += cur_weight; diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index bb020b58f44e70e493b9e4fd8b57f7f72a9f6cde..4e1e04043ad7d2fd72bfe891b755a2503c2096b3 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -91,7 +91,7 @@ from .tensor.logic import equal #DEFINE_ALIAS from .tensor.logic import greater_equal #DEFINE_ALIAS from .tensor.logic import greater_than #DEFINE_ALIAS from .tensor.logic import is_empty #DEFINE_ALIAS -from .tensor.logic import isfinite #DEFINE_ALIAS +#from .tensor.logic import isfinite #DEFINE_ALIAS from .tensor.logic import less_equal #DEFINE_ALIAS from .tensor.logic import less_than #DEFINE_ALIAS from .tensor.logic import logical_and #DEFINE_ALIAS @@ -193,6 +193,9 @@ from .tensor.math import addmm #DEFINE_ALIAS from .tensor.math import clip #DEFINE_ALIAS from .tensor.math import trace #DEFINE_ALIAS from .tensor.math import kron #DEFINE_ALIAS +from .tensor.math import isfinite #DEFINE_ALIAS +from .tensor.math import isinf #DEFINE_ALIAS +from .tensor.math import isnan #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS from .tensor.random import standard_normal from .tensor.random import normal diff --git a/python/paddle/fluid/tests/unittests/test_isfinite_v2_op.py b/python/paddle/fluid/tests/unittests/test_isfinite_v2_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8a868e751f0567e6387b0e9471f0382c9456bcb6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_isfinite_v2_op.py @@ -0,0 +1,161 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np + + +def run_static(x_np, dtype, op_str, use_gpu=False): + paddle.enable_static() + startup_program = fluid.Program() + main_program = fluid.Program() + place = paddle.CPUPlace() + if use_gpu and fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + with fluid.program_guard(main_program, startup_program): + x = paddle.data(name='x', shape=x_np.shape, dtype=dtype) + res = getattr(paddle.tensor, op_str)(x) + exe.run(startup_program) + static_result = exe.run(main_program, + feed={'x': x_np}, + fetch_list=[res]) + return static_result + + +def run_dygraph(x_np, op_str, use_gpu=True): + place = paddle.CPUPlace() + if use_gpu and fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + x = paddle.to_variable(x_np) + dygraph_result = getattr(paddle.tensor, op_str)(x) + return dygraph_result + + +def np_data_generator(low, high, np_shape, type, sv_list, op_str, *args, + **kwargs): + x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) + # x_np.shape[0] >= len(sv_list) + if type in ['float16', 'float32', 'float64']: + for i, v in enumerate(sv_list): + x_np[i] = v + ori_shape = x_np.shape + x_np = x_np.reshape((np.product(ori_shape), )) + np.random.shuffle(x_np) + x_np = x_np.reshape(ori_shape) + result_np = getattr(np, op_str)(x_np) + return x_np, result_np + + +TEST_META_DATA = [ + { + 'low': 0.1, + 'high': 1, + 'np_shape': [8, 17, 5, 6, 7], + 'type': 'float16', + 'sv_list': [np.inf, np.nan] + }, + { + 'low': 0.1, + 'high': 1, + 'np_shape': [11, 17], + 'type': 'float32', + 'sv_list': [np.inf, np.nan] + }, + { + 'low': 0.1, + 'high': 1, + 'np_shape': [2, 3, 4, 5], + 'type': 'float64', + 'sv_list': [np.inf, np.nan] + }, + { + 'low': 0, + 'high': 100, + 'np_shape': [11, 17, 10], + 'type': 'int32', + 'sv_list': [np.inf, np.nan] + }, + { + 'low': 0, + 'high': 999, + 'np_shape': [132], + 'type': 'int64', + 'sv_list': [np.inf, np.nan] + }, +] + + +def test(test_case, op_str, use_gpu=False): + for meta_data in TEST_META_DATA: + meta_data = dict(meta_data) + meta_data['op_str'] = op_str + x_np, result_np = np_data_generator(**meta_data) + static_result = run_static(x_np, meta_data['type'], op_str, use_gpu) + dygraph_result = run_dygraph(x_np, op_str, use_gpu) + test_case.assertTrue((static_result == result_np).all()) + test_case.assertTrue((dygraph_result.numpy() == result_np).all()) + + +class TestCPUNormal(unittest.TestCase): + def test_inf(self): + test(self, 'isinf') + + def test_nan(self): + test(self, 'isnan') + + def test_finite(self): + test(self, 'isfinite') + + +class TestCUDANormal(unittest.TestCase): + def test_inf(self): + test(self, 'isinf', True) + + def test_nan(self): + test(self, 'isnan', True) + + def test_finite(self): + test(self, 'isfinite', True) + + +class TestError(unittest.TestCase): + def test_bad_input(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + + def test_isinf_bad_x(): + x = [1, 2, 3] + result = paddle.tensor.isinf(x) + + self.assertRaises(TypeError, test_isinf_bad_x) + + def test_isnan_bad_x(): + x = [1, 2, 3] + result = paddle.tensor.isnan(x) + + self.assertRaises(TypeError, test_isnan_bad_x) + + def test_isfinite_bad_x(): + x = [1, 2, 3] + result = paddle.tensor.isfinite(x) + + self.assertRaises(TypeError, test_isfinite_bad_x) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py 
b/python/paddle/tensor/__init__.py index 35634ba701391ed99464085b70a93acfd5370709..0fed32a1676759bd94961af0a8949d035ec48c8f 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -58,7 +58,7 @@ from .logic import equal #DEFINE_ALIAS from .logic import greater_equal #DEFINE_ALIAS from .logic import greater_than #DEFINE_ALIAS from .logic import is_empty #DEFINE_ALIAS -from .logic import isfinite #DEFINE_ALIAS +#from .logic import isfinite #DEFINE_ALIAS from .logic import less_equal #DEFINE_ALIAS from .logic import less_than #DEFINE_ALIAS from .logic import logical_and #DEFINE_ALIAS @@ -161,6 +161,9 @@ from .math import addmm #DEFINE_ALIAS from .math import clip #DEFINE_ALIAS from .math import trace #DEFINE_ALIAS from .math import kron #DEFINE_ALIAS +from .math import isfinite #DEFINE_ALIAS +from .math import isinf #DEFINE_ALIAS +from .math import isnan #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS from .random import standard_normal from .random import normal diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 67c225510dbcac72f9debe5febb53228dd228754..77639e8da466bcfb88f81d2d905a66d374a6d6c1 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -126,7 +126,10 @@ __all__ = [ 'addmm', 'clip', 'trace', - 'kron' + 'kron', + 'isfinite', + 'isinf', + 'isnan' ] # yapf: enable. @@ -1669,6 +1672,100 @@ def cumsum(x, axis=None, dtype=None, name=None): _cum_sum_ = generate_layer_fn('cumsum') return _cum_sum_(**kwargs) +def isfinite(x, name=None): + """ + + Return whether every element of input tensor is finite number or not. + + Args: + x (Tensor): The input tensor, it's data type should be float16, float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + `Tensor`, the bool result which shows every element of `x` whether it is finite number or not. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + paddle.disable_static() + x_np = np.array([float('-inf'), -2, 3.6, float('inf'), 0, float('-nan'), float('nan')]) + x = paddle.to_tensor(x_np) + out = paddle.tensor.isfinite(x) + print(out.numpy()) # [False True True False True False False] + """ + if in_dygraph_mode(): + return core.ops.isfinite_v2(x) + helper = LayerHelper("isfinite_v2", **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isfinite') + out = helper.create_variable_for_type_inference('bool') + helper.append_op(type="isfinite_v2", inputs={"X": x}, outputs={"Out": out}) + return out + +def isinf(x, name=None): + """ + + Return whether every element of input tensor is `+/-INF` or not. + + Args: + x (Tensor): The input tensor, it's data type should be float16, float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + `Tensor`, the bool result which shows every element of `x` whether it is `+/-INF` or not. + + Examples: + .. 
code-block:: python + + import paddle + import numpy as np + paddle.disable_static() + x_np = np.array([float('-inf'), -2, 3.6, float('inf'), 0, float('-nan'), float('nan')]) + x = paddle.to_tensor(x_np) + out = paddle.tensor.isinf(x) + print(out.numpy()) # [ True False False True False False False] + """ + if in_dygraph_mode(): + return core.ops.isinf_v2(x) + helper = LayerHelper("isinf_v2", **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isinf') + out = helper.create_variable_for_type_inference(dtype='bool') + helper.append_op(type="isinf_v2", inputs={"X": x}, outputs={"Out": out}) + return out + +def isnan(x, name=None): + """ + + Return whether every element of input tensor is `NaN` or not. + + Args: + x (Tensor): The input tensor, it's data type should be float16, float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + `Tensor`, the bool result which shows every element of `x` whether it is `NaN` or not. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + paddle.disable_static() + x_np = np.array([float('-inf'), -2, 3.6, float('inf'), 0, float('-nan'), float('nan')]) + x = paddle.to_tensor(x_np) + out = paddle.tensor.isnan(x) + print(out.numpy()) # [False False False False False True True] + """ + if in_dygraph_mode(): + return core.ops.isnan_v2(x) + helper = LayerHelper("isnan_v2", **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isnan') + out = helper.create_variable_for_type_inference(dtype='bool') + helper.append_op(type="isnan_v2", inputs={"X": x}, outputs={"Out": out}) + return out + + def prod(x, axis=None, keepdim=False, dtype=None, name=None): """ Compute the product of tensor elements over the given axis. @@ -1694,7 +1791,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): Raises: ValueError: The :attr:`dtype` must be float32, float64, int32 or int64. TypeError: The type of :attr:`axis` must be int, list or tuple. - + Examples: .. code-block:: python