From 78916a7a26563166fea783c1040e782ad0e3bee3 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 24 Aug 2022 19:53:28 +0800 Subject: [PATCH] make tensor_util contains no cuda code (#45256) * make tensor_util contains no cuda code * refine isfinite * revert ut * move isfinite function to its op * fix test * fix compile * std::isnan is not defined for int type on windows * fix windows compile * fix fp16 * fix rocm compile * revert gradient node --- paddle/fluid/framework/CMakeLists.txt | 41 +- paddle/fluid/framework/data_type.h | 22 + .../fluid/framework/downpour_lite_worker.cc | 1 + paddle/fluid/framework/downpour_worker.cc | 1 + paddle/fluid/framework/downpour_worker_opt.cc | 1 + paddle/fluid/framework/operator.cc | 1 + paddle/fluid/framework/tensor_util.cc | 387 ------------------ paddle/fluid/framework/tensor_util.h | 35 +- paddle/fluid/framework/tensor_util_test.cc | 2 +- paddle/fluid/framework/tensor_util_test.cu | 1 + paddle/fluid/operators/isfinite_op.cc | 8 - paddle/fluid/operators/isfinite_op.h | 116 +++++- paddle/fluid/operators/memcpy_d2h_op.h | 1 + paddle/phi/common/float16.h | 4 + paddle/phi/kernels/cpu/amp_kernel.cc | 10 +- paddle/phi/kernels/cpu/isfinite_kernel.cc | 13 - paddle/phi/kernels/funcs/CMakeLists.txt | 2 +- paddle/phi/kernels/funcs/isfinite_functor.h | 83 +++- paddle/phi/kernels/gpu/isfinite_kernel.cu | 13 - .../phi/kernels/impl/isfinite_kernel_impl.h | 27 +- .../selected_rows/impl/isfinite_kernel_impl.h | 18 +- .../kernels/selected_rows/isfinite_kernel.cc | 12 - .../kernels/selected_rows/isfinite_kernel.h | 1 + 23 files changed, 276 insertions(+), 524 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 49b98cf6b51..2dd2b5162f7 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -78,52 +78,31 @@ cc_test( data_type_test SRCS data_type_test.cc DEPS data_type place tensor) -if(WITH_GPU) - if(WIN32) - windows_symbolic(tensor_util SRCS tensor_util.cu) - nv_library( - tensor - SRCS .tensor_util.cu - DEPS place memory data_type device_context dense_tensor) - add_dependencies(tensor tensor_util) - else() - nv_library( - tensor - SRCS tensor_util.cu - DEPS place memory data_type device_context dense_tensor) - endif() -elseif(WITH_ROCM) - hip_library( - tensor - SRCS tensor_util.cu - DEPS place memory data_type device_context dense_tensor) -else() - cc_library( - tensor - SRCS tensor_util.cc - DEPS place memory data_type device_context dense_tensor) -endif() -# target_link(tensor profiler) + +cc_library( + tensor + SRCS tensor_util.cc + DEPS place memory data_type device_context dense_tensor) cc_test( tensor_test SRCS tensor_test.cc - DEPS tensor) + DEPS tensor isfinite_op) if(WITH_GPU) nv_test( tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor) + DEPS tensor dlpack_tensor isfinite_op) elseif(WITH_ROCM) hip_test( tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu - DEPS tensor dlpack_tensor) + DEPS tensor dlpack_tensor isfinite_op) else() cc_test( tensor_util_test SRCS tensor_util_test.cc - DEPS tensor dlpack_tensor) + DEPS tensor dlpack_tensor isfinite_op) endif() cc_test( @@ -204,7 +183,7 @@ cc_test( cc_library( var_type_traits SRCS var_type_traits.cc - DEPS lod_tensor selected_rows_utils framework_proto scope) + DEPS selected_rows_utils framework_proto scope) if(WITH_GPU) target_link_libraries(var_type_traits dynload_cuda) endif() diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 091ef841125..ab63b489a2e 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -83,6 +83,13 @@ struct DataTypeTrait { _ForEachDataTypeHelper_( \ callback, ::paddle::platform::complex, COMPLEX128); +#define _ForEachDataTypeNormal_(callback) \ + _ForEachDataTypeHelper_(callback, float, FP32); \ + _ForEachDataTypeHelper_(callback, double, FP64); \ + _ForEachDataTypeHelper_(callback, int, INT32); \ + _ForEachDataTypeHelper_(callback, int64_t, INT64); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); + // For the use of thrust, as index-type elements can be only integers. #define _ForEachDataTypeTiny_(callback) \ _ForEachDataTypeHelper_(callback, int, INT32); \ @@ -148,6 +155,21 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) { #undef VisitDataTypeCallbackSmall } +// for normal dtype, int, int64, float, float64, float16 +template +inline void VisitDataTypeNormal(proto::VarType::Type type, Visitor visitor) { +#define VisitDataTypeCallbackNormal(cpp_type, proto_type) \ + do { \ + if (type == proto_type) { \ + visitor.template apply(); \ + return; \ + } \ + } while (0) + + _ForEachDataTypeNormal_(VisitDataTypeCallbackNormal); +#undef VisitDataTypeCallbackNormal +} + template inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) { #define VisitIntDataTypeCallback(cpp_type, proto_type) \ diff --git a/paddle/fluid/framework/downpour_lite_worker.cc b/paddle/fluid/framework/downpour_lite_worker.cc index 0d0e7fc468b..bd2c404a6fd 100644 --- a/paddle/fluid/framework/downpour_lite_worker.cc +++ b/paddle/fluid/framework/downpour_lite_worker.cc @@ -15,6 +15,7 @@ limitations under the License. */ #if defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/fleet/metrics.h" +#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/cpu_helper.h" namespace phi { diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index e9c3ddc5ebc..0bd577d2aa6 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/fleet/metrics.h" +#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/cpu_helper.h" namespace phi { diff --git a/paddle/fluid/framework/downpour_worker_opt.cc b/paddle/fluid/framework/downpour_worker_opt.cc index f8af18d9b30..9e4b5e9da54 100644 --- a/paddle/fluid/framework/downpour_worker_opt.cc +++ b/paddle/fluid/framework/downpour_worker_opt.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/cpu_helper.h" namespace paddle { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 35865252629..b5d6a3786c3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index f7f05da6340..ca1a65be7d0 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -651,393 +651,6 @@ void TensorCopySync(const Tensor& src, #endif } -template -struct AnyDTypeVisitor { - Predicate predicate_; - const Tensor& tensor_; - const DevCtx& ctx_; - Tensor* out_; - - AnyDTypeVisitor(Predicate predicate, - const Tensor& tensor, - const DevCtx& ctx, - Tensor* out) - : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} - - template - void apply() const { - auto t = EigenVector::Flatten(tensor_); - auto o = EigenScalar::From(*out_); - // return any of predicate_(t) is true. - o.device(*ctx_.eigen_device()) = predicate_(t).any(); - } -}; - -template -inline void AnyImpl(Predicate predicate, - const framework::Tensor& tensor, - const DevCtx& ctx, - framework::Tensor* out) { - VisitDataType( - framework::TransToProtoVarType(tensor.dtype()), - AnyDTypeVisitor(predicate, tensor, ctx, out)); -} - -template -class AnyVisitor : public std::unary_function { - private: - const framework::Tensor& tensor_; - Predicate predicate_; - - bool GetResultHelper(const framework::Tensor& out, - const platform::Place& place) const { - platform::CPUPlace cpu; - framework::Tensor tmp; - tmp.Resize({1}); - tmp.mutable_data(cpu); - auto ctx = platform::DeviceContextPool::Instance().Get(place); - ctx->Wait(); - TensorCopy(out, cpu, *ctx, &tmp); - ctx->Wait(); - return GetResult(tmp, cpu); - } - - public: - AnyVisitor(const framework::Tensor& tensor, Predicate predicate) - : tensor_(tensor), predicate_(std::move(predicate)) {} - - template - bool operator()(const Place& place) const { - framework::Tensor out; - out.Resize({1}); - out.mutable_data(place); - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); - AnyImpl(predicate_, tensor_, *ctx, &out); - return this->GetResult(out, place); - } - - bool GetResult(const framework::Tensor& out, - const platform::XPUPlace& xpu) const { - return GetResultHelper(out, xpu); - } - - bool GetResult(const framework::Tensor& out, - const platform::MLUPlace& mlu) const { - PADDLE_THROW( - platform::errors::Unimplemented("Not supported on place (%s) ", mlu)); - return true; - } - - bool GetResult(const framework::Tensor& out, - const platform::CUDAPlace& gpu) const { - return GetResultHelper(out, gpu); - } - - bool GetResult(const framework::Tensor& out, - const platform::NPUPlace& npu) const { - PADDLE_THROW( - platform::errors::Unimplemented("Not supported on place (%s) ", npu)); - // return GetResultHelper(out, npu); - } - bool GetResult(const framework::Tensor& out, - const platform::IPUPlace& ipu) const { - PADDLE_THROW( - platform::errors::Unimplemented("Not supported on place (%s) ", ipu)); - } - - bool GetResult(const framework::Tensor& out, - const platform::NPUPinnedPlace& cpu) const { - return *out.data(); - } - - bool GetResult(const framework::Tensor& out, - const platform::CPUPlace& cpu) const { - return *out.data(); - } - - bool GetResult(const framework::Tensor& out, - const platform::CUDAPinnedPlace& cpu) const { - return *out.data(); - } - - bool GetResult(const framework::Tensor& out, - const platform::CustomPlace& custom_dev) const { - PADDLE_THROW(platform::errors::Unimplemented("Not supported on place (%s) ", - custom_dev)); - return false; - } -}; - -template -class AnyOutVisitor : public std::unary_function { - private: - const framework::Tensor& tensor_; - mutable framework::Tensor* out_; - Predicate predicate_; - - public: - AnyOutVisitor(const framework::Tensor& tensor, - Predicate predicate, - framework::Tensor* out) - : tensor_(tensor), out_(out), predicate_(std::move(predicate)) {} - - template - void operator()(const Place& place) const { - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); - out_->Resize({1}); - out_->mutable_data(place); - AnyImpl(predicate_, tensor_, *ctx, out_); - } -}; - -template -inline bool Any(const framework::Tensor& tensor, Predicate predicate) { - AnyVisitor visitor(tensor, predicate); - auto place = tensor.place(); - return platform::VisitPlace(place, visitor); -} - -template -inline void Any(const framework::Tensor& tensor, - Predicate predicate, - framework::Tensor* out) { - AnyOutVisitor visitor(tensor, predicate, out); - auto place = tensor.place(); - platform::VisitPlace(place, visitor); -} - -template -struct AllDTypeVisitor { - Predicate predicate_; - const Tensor& tensor_; - const DevCtx& ctx_; - Tensor* out_; - - AllDTypeVisitor(Predicate predicate, - const Tensor& tensor, - const DevCtx& ctx, - Tensor* out) - : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} - - template - void apply() const { - auto t = EigenVector::Flatten(tensor_); - auto o = EigenVector::Flatten(*out_); - o.device(*ctx_.eigen_device()) = predicate_(t); - } -}; - -template -inline void AllImpl(Predicate predicate, - const framework::Tensor& tensor, - const DevCtx& ctx, - framework::Tensor* out) { - VisitDataType( - framework::TransToProtoVarType(tensor.dtype()), - AllDTypeVisitor(predicate, tensor, ctx, out)); -} - -template -class AllOutVisitor : public std::unary_function { - private: - const framework::Tensor& tensor_; - mutable framework::Tensor* out_; - Predicate predicate_; - - public: - AllOutVisitor(const framework::Tensor& tensor, - Predicate predicate, - framework::Tensor* out) - : tensor_(tensor), out_(out), predicate_(predicate) {} - - template - void operator()(const Place& place) const { - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); - out_->Resize(tensor_.dims()); - out_->mutable_data(place); - AllImpl(predicate_, tensor_, *ctx, out_); - } -}; - -template -inline void All(const framework::Tensor& tensor, - Predicate predicate, - framework::Tensor* out) { - AllOutVisitor visitor(tensor, predicate, out); - auto place = tensor.place(); - platform::VisitPlace(place, visitor); -} - -struct ContainsNANPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isnan()) { - // Cast eigen_vector to vector of bool. true if is inf. - return eigen_vec.isnan(); - } -}; - -bool TensorContainsNAN(const framework::Tensor& tensor) { - ContainsNANPredicate predicate; - return Any(tensor, predicate); -} - -void TensorContainsNAN(const framework::Tensor& tensor, - framework::Tensor* out) { - ContainsNANPredicate predicate; - Any(tensor, predicate, out); -} - -void TensorContainsNANV2(const framework::Tensor& tensor, - framework::Tensor* out) { - ContainsNANPredicate predicate; - All(tensor, predicate, out); -} - -struct ContainsInfPredicate { - template - auto operator()(const T& eigen_vec) const - -> decltype(std::declval().isinf()) { - // Cast eigen_vector to vector of bool. true if is inf. - return eigen_vec.isinf(); - } -}; - -bool TensorContainsInf(const framework::Tensor& tensor) { - ContainsInfPredicate predicate; - return Any(tensor, predicate); -} - -void TensorContainsInf(const framework::Tensor& tensor, - framework::Tensor* out) { - ContainsInfPredicate predicate; - Any(tensor, predicate, out); -} - -void TensorContainsInfV2(const framework::Tensor& tensor, - framework::Tensor* out) { - ContainsInfPredicate predicate; - All(tensor, predicate, out); -} - -// NOTE(dzhwinter): -// Isfinite need a AllVisitor to loop through all the elements. -// We choose two cuda call instead of one allvisitor. The AllVisitor -// should be implemented if the performance hurts. -bool TensorIsfinite(const framework::Tensor& tensor) { - ContainsInfPredicate pred_inf; - ContainsNANPredicate pred_nan; - return !Any(tensor, pred_inf) && !Any(tensor, pred_nan); -} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -static inline void __global__ BothFalse(const T* cmp, T* out, int element_num) { - CUDA_KERNEL_LOOP(i, element_num) { out[i] = (!cmp[i]) && (!out[i]); } -} -#endif - -struct BothFalseVisitor : public std::unary_function { - const framework::Tensor& in_; - mutable framework::Tensor* out_; - BothFalseVisitor(const framework::Tensor& in, framework::Tensor* out) - : in_(in), out_(out) {} - - template - void operator()(const Place& place) const { - VisitorImpl(place); - } - - void VisitorImpl(const platform::XPUPlace& xpu) const { - PADDLE_THROW(platform::errors::Unimplemented("XPUPlace is not supported")); - } - void VisitorImpl(const platform::IPUPlace& ipu) const { - PADDLE_THROW(platform::errors::Unimplemented("IPUPlace is not supported")); - } - - void VisitorImpl(const platform::CUDAPlace& gpu) const { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu); - constexpr int MAX_BLOCK_DIM = 512; - const int MAX_GRID_DIM = ctx->GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM; - int element_num = in_.numel(); - int block_size = (element_num >= MAX_BLOCK_DIM) - ? MAX_BLOCK_DIM - : (1 << static_cast(std::log2(element_num))); - int grid_size = element_num / block_size; - grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size; - BothFalse<<stream()>>>( - in_.data(), out_->mutable_data(gpu), element_num); -#endif - } - - void VisitorImpl(const platform::NPUPlace& npu) const { - // TODO(zhiqiu) - } - - void VisitorImpl(const platform::MLUPlace& mlu) const { - PADDLE_THROW(platform::errors::Unimplemented("MLUPlace is not supported")); - } - - void VisitorImpl(const platform::CPUPlace& cpu) const { - int num = in_.numel(); - const bool* in_ptr = in_.data(); - bool* out_ptr = out_->data(); - for (int i = 0; i < num; ++i) { - bool lhs = !in_ptr[i]; - bool rhs = !out_ptr[i]; - out_ptr[i] = lhs && rhs; - } - } - - void VisitorImpl( - const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const { - int num = in_.numel(); - const bool* in_ptr = in_.data(); - bool* out_ptr = out_->data(); - for (int i = 0; i < num; ++i) { - bool lhs = !in_ptr[i]; - bool rhs = !out_ptr[i]; - out_ptr[i] = lhs && rhs; - } - } - - void VisitorImpl( - const platform::NPUPinnedPlace& cpu /* equals to cpu*/) const { - int num = in_.numel(); - const bool* in_ptr = in_.data(); - bool* out_ptr = out_->data(); - for (int i = 0; i < num; ++i) { - bool lhs = !in_ptr[i]; - bool rhs = !out_ptr[i]; - out_ptr[i] = lhs && rhs; - } - } - - void VisitorImpl(const platform::CustomPlace& custom_dev) const { - PADDLE_THROW( - platform::errors::Unimplemented("CustomPlace is not supported")); - } -}; - -void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) { - framework::Tensor tmp; - TensorContainsInf(tensor, &tmp); - TensorContainsNAN(tensor, out); - BothFalseVisitor visitor(tmp, out); - auto place = tensor.place(); - platform::VisitPlace(place, visitor); -} - -void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out) { - framework::Tensor tmp; - TensorContainsInfV2(tensor, &tmp); - TensorContainsNANV2(tensor, out); - BothFalseVisitor visitor(tmp, out); - auto place = tensor.place(); - platform::VisitPlace(place, visitor); -} - void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 3c9d1284cef..c617441fd69 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -112,16 +112,6 @@ void TensorToVector(const Tensor& src, template void TesnorToVector(const Tensor& src, std::vector* dst); -// copy the result bool to cpu -bool TensorContainsNAN(const framework::Tensor& tensor); -bool TensorContainsInf(const framework::Tensor& tensor); -bool TensorIsfinite(const framework::Tensor& tensor); - -// store the result bool in gpu tensor, async operation. Faster than above ones. -void TensorContainsNAN(const framework::Tensor& tensor, framework::Tensor* out); -void TensorContainsInf(const framework::Tensor& tensor, framework::Tensor* out); -void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out); - void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx); @@ -134,13 +124,6 @@ void TensorFromStream(std::istream& is, const size_t& seek, const std::vector& shape); -// store the bool result tensor in out tensor -void TensorContainsNANV2(const framework::Tensor& tensor, - framework::Tensor* out); -void TensorContainsInfV2(const framework::Tensor& tensor, - framework::Tensor* out); -void TensorIsfiniteV2(const framework::Tensor& tensor, framework::Tensor* out); - // convert dlpack's DLTensor to tensor void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst); @@ -601,6 +584,24 @@ inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { return res; } +template +inline T GetValue(const framework::Tensor* x) { + T value = static_cast(0); + if (!platform::is_cpu_place(x->place())) { + framework::Tensor cpu_x; + framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x); +#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU) + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + const platform::DeviceContext* dev_ctx = pool.Get(x->place()); + dev_ctx->Wait(); +#endif + value = cpu_x.data()[0]; + } else { + value = x->data()[0]; + } + return value; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc index 36be5cde506..c7db2186e5d 100644 --- a/paddle/fluid/framework/tensor_util_test.cc +++ b/paddle/fluid/framework/tensor_util_test.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/tensor_util.h" - #include +#include "paddle/fluid/operators/isfinite_op.h" #include diff --git a/paddle/fluid/framework/tensor_util_test.cu b/paddle/fluid/framework/tensor_util_test.cu index f888b9632c7..53807beab91 100644 --- a/paddle/fluid/framework/tensor_util_test.cu +++ b/paddle/fluid/framework/tensor_util_test.cu @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "gtest/gtest.h" +#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index 77583fd2d30..5c49c5e9451 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -122,14 +122,6 @@ namespace ops = paddle::operators; paddle::framework::EmptyGradOpMaker, \ paddle::framework::EmptyGradOpMaker) -#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \ - REGISTER_OP_CPU_KERNEL( \ - op_type, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel, \ - ops::OverflowKernel); - REGISTER_OP_MAKER(isinf, "isinf(X)"); REGISTER_OP_MAKER(isnan, "isnan(X)"); REGISTER_OP_MAKER(isfinite, "isfinite(X)"); diff --git a/paddle/fluid/operators/isfinite_op.h b/paddle/fluid/operators/isfinite_op.h index 9b4a3e0af2e..427d3569986 100644 --- a/paddle/fluid/operators/isfinite_op.h +++ b/paddle/fluid/operators/isfinite_op.h @@ -21,14 +21,128 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/kernels/isfinite_kernel.h" +#include "paddle/phi/kernels/reduce_all_kernel.h" +#include "paddle/phi/kernels/reduce_any_kernel.h" namespace phi { class DenseTensor; } // namespace phi namespace paddle { -namespace operators { +namespace framework { +// store the result bool in gpu tensor, async operation. Faster than above ones. +void TensorContainsNAN(const framework::Tensor& tensor, framework::Tensor* out); +void TensorContainsInf(const framework::Tensor& tensor, framework::Tensor* out); +void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out); + +// copy the result bool to cpu +bool TensorContainsNAN(const framework::Tensor& tensor); +bool TensorContainsInf(const framework::Tensor& tensor); +bool TensorIsfinite(const framework::Tensor& tensor); + +#define FiniteVisitor(type, reduce_type, device) \ + struct type##Visitor##device { \ + type##Visitor##device(const phi::DenseTensor& in, phi::DenseTensor* out) \ + : in_(in), out_(out) {} \ + template \ + void apply() const { \ + auto place = in_.place(); \ + auto* ctx = static_cast( \ + platform::DeviceContextPool::Instance().Get(place)); \ + Tensor tmp; \ + tmp.Resize(in_.dims()); \ + out_->Resize({1}); \ + std::vector dims(tmp.dims().size()); \ + std::iota(dims.begin(), dims.end(), 0); \ + phi::type##Kernel(*ctx, in_, &tmp); \ + phi::reduce_type##Kernel( \ + *ctx, tmp, dims, false, out_); \ + } \ + const phi::DenseTensor& in_; \ + phi::DenseTensor* out_; \ + }; + +FiniteVisitor(Isnan, Any, CPU); +FiniteVisitor(Isinf, Any, CPU); +FiniteVisitor(Isfinite, All, CPU); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +FiniteVisitor(Isnan, Any, GPU); +FiniteVisitor(Isinf, Any, GPU); +FiniteVisitor(Isfinite, All, GPU); +#endif + +// store the result bool in gpu tensor, async operation. Faster than above ones. +inline void TensorContainsNAN(const framework::Tensor& tensor, + framework::Tensor* out) { + auto place = tensor.place(); + if (platform::is_cpu_place(tensor.place())) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsnanVisitorCPU(tensor, out)); + return; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(place)) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsnanVisitorGPU(tensor, out)); + return; + } +#endif + PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place)); +} +inline void TensorContainsInf(const framework::Tensor& tensor, + framework::Tensor* out) { + auto place = tensor.place(); + if (platform::is_cpu_place(tensor.place())) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsinfVisitorCPU(tensor, out)); + return; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(place)) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsinfVisitorGPU(tensor, out)); + return; + } +#endif + PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place)); +} +inline void TensorIsfinite(const framework::Tensor& tensor, + framework::Tensor* out) { + auto place = tensor.place(); + if (platform::is_cpu_place(tensor.place())) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsfiniteVisitorCPU(tensor, out)); + return; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(place)) { + VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()), + IsfiniteVisitorGPU(tensor, out)); + return; + } +#endif + PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place)); +} +// copy the result bool to cpu +inline bool TensorContainsNAN(const framework::Tensor& tensor) { + Tensor out; + TensorContainsNAN(tensor, &out); + return GetValue(&out); +} +inline bool TensorContainsInf(const framework::Tensor& tensor) { + Tensor out; + TensorContainsInf(tensor, &out); + return GetValue(&out); +} +inline bool TensorIsfinite(const framework::Tensor& tensor) { + Tensor out; + TensorIsfinite(tensor, &out); + return GetValue(&out); +} +} // namespace framework +namespace operators { struct InfinityFunctor { void operator()(const framework::Tensor& tensor, framework::Tensor* out) { framework::TensorContainsInf(tensor, out); diff --git a/paddle/fluid/operators/memcpy_d2h_op.h b/paddle/fluid/operators/memcpy_d2h_op.h index 9be6309fc0c..ce90828f3fc 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.h +++ b/paddle/fluid/operators/memcpy_d2h_op.h @@ -72,6 +72,7 @@ class MemcpyD2HFunctor { framework::LoDTensor &dst) const { // NOLINT if (dst_place_type_ == 1) { framework::TensorCopy(src, platform::CUDAPinnedPlace(), dev_ctx_, &dst); + dev_ctx_.Wait(); } else if (dst_place_type_ == 0) { framework::TensorCopy(src, platform::CPUPlace(), dev_ctx_, &dst); } else { diff --git a/paddle/phi/common/float16.h b/paddle/phi/common/float16.h index 1cdcdef2c12..4b0799a1774 100644 --- a/paddle/phi/common/float16.h +++ b/paddle/phi/common/float16.h @@ -1028,6 +1028,10 @@ inline bool isnan(const phi::dtype::float16& a) { return phi::dtype::isnan(a); } inline bool isinf(const phi::dtype::float16& a) { return phi::dtype::isinf(a); } +inline bool isfinite(const phi::dtype::float16& a) { + return phi::dtype::isfinite(a); +} + template <> struct numeric_limits { static const bool is_specialized = true; diff --git a/paddle/phi/kernels/cpu/amp_kernel.cc b/paddle/phi/kernels/cpu/amp_kernel.cc index d15c91dc8dd..23048ba337d 100644 --- a/paddle/phi/kernels/cpu/amp_kernel.cc +++ b/paddle/phi/kernels/cpu/amp_kernel.cc @@ -21,6 +21,8 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/impl/amp_kernel_impl.h" +#include "paddle/phi/kernels/isfinite_kernel.h" +#include "paddle/phi/kernels/reduce_all_kernel.h" #include "paddle/fluid/framework/tensor_util.h" @@ -85,7 +87,13 @@ void CheckFiniteAndUnscaleKernel(const Context& dev_ctx, auto* out = outs[i]; dev_ctx.template Alloc(out); if (!(*found_inf_data)) { - paddle::framework::TensorIsfinite(*x, &is_finite); + DenseTensor tmp; + tmp.Resize(x->dims()); + phi::IsfiniteKernel(dev_ctx, *x, &tmp); + + std::vector dims(x->dims().size()); + std::iota(dims.begin(), dims.end(), 0); + phi::AllKernel(dev_ctx, tmp, dims, false, &is_finite); *found_inf_data = !(*is_finite_data); } auto eigen_out = EigenVector::Flatten(*out); diff --git a/paddle/phi/kernels/cpu/isfinite_kernel.cc b/paddle/phi/kernels/cpu/isfinite_kernel.cc index 33a7429a22a..e3cc3c83598 100644 --- a/paddle/phi/kernels/cpu/isfinite_kernel.cc +++ b/paddle/phi/kernels/cpu/isfinite_kernel.cc @@ -18,19 +18,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/isfinite_kernel_impl.h" -namespace phi { - -template -inline void IsfiniteKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out) { - dev_ctx.template Alloc(out); - Functor functor; - functor(x, out); -} - -} // namespace phi - PD_REGISTER_KERNEL(isinf, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index afa46f1daca..e21bea2e242 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -9,7 +9,7 @@ math_library(fc_functor DEPS blas jit_kernel_helper) math_library(gpc DEPS phi_enforce) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) -math_library(math_function DEPS blas dense_tensor tensor) +math_library(math_function DEPS blas dense_tensor) math_library(matrix_reduce DEPS dense_tensor) math_library(matrix_inverse DEPS dense_tensor eigen3 blas) math_library(pooling DEPS dense_tensor) diff --git a/paddle/phi/kernels/funcs/isfinite_functor.h b/paddle/phi/kernels/funcs/isfinite_functor.h index c804bee8d4c..1dc4fd57b48 100644 --- a/paddle/phi/kernels/funcs/isfinite_functor.h +++ b/paddle/phi/kernels/funcs/isfinite_functor.h @@ -14,30 +14,83 @@ #pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/phi/common/scalar.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/infermeta/unary.h" - namespace phi { namespace funcs { -struct InfinityV2Functor { - void operator()(const DenseTensor& tensor, DenseTensor* out) { - paddle::framework::TensorContainsInfV2(tensor, out); +template +struct IsNanFunctor { + HOSTDEVICE bool operator()(const T& a) const { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isnan(a); +#else + return std::isnan(a); +#endif + } +}; + +template +struct IsNanFunctor::value>::type> { + HOSTDEVICE bool operator()(const T& a) const { return false; } +}; + +// isnan is defined in namespace std in float16.h, but +// on rocm platform, it still got: +// "error: call to 'isnan' is ambiguous". +// So use phi::dtype::isnan here. +template <> +struct IsNanFunctor { + HOSTDEVICE bool operator()(const phi::dtype::float16& a) const { + return phi::dtype::isnan(a); + } +}; + +template +struct IsInfFunctor { + HOSTDEVICE bool operator()(const T& a) const { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isinf(a); +#else + return std::isinf(a); +#endif + } +}; + +template +struct IsInfFunctor::value>::type> { + HOSTDEVICE bool operator()(const T& a) const { return false; } +}; + +template <> +struct IsInfFunctor { + HOSTDEVICE bool operator()(const phi::dtype::float16& a) const { + return phi::dtype::isinf(a); } }; -struct NANV2Functor { - void operator()(const DenseTensor& tensor, DenseTensor* out) { - paddle::framework::TensorContainsNANV2(tensor, out); +template +struct IsFiniteFunctor { + HOSTDEVICE bool operator()(const T& a) const { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isfinite(a); +#else + return std::isfinite(a); +#endif } }; -struct IsfiniteV2Functor { - void operator()(const DenseTensor& tensor, DenseTensor* out) { - paddle::framework::TensorIsfiniteV2(tensor, out); +template +struct IsFiniteFunctor< + T, + typename std::enable_if::value>::type> { + HOSTDEVICE bool operator()(const T& a) const { return true; } +}; + +template <> +struct IsFiniteFunctor { + HOSTDEVICE bool operator()(const phi::dtype::float16& a) const { + return phi::dtype::isfinite(a); } }; diff --git a/paddle/phi/kernels/gpu/isfinite_kernel.cu b/paddle/phi/kernels/gpu/isfinite_kernel.cu index 17ea2586517..c7dde29101b 100644 --- a/paddle/phi/kernels/gpu/isfinite_kernel.cu +++ b/paddle/phi/kernels/gpu/isfinite_kernel.cu @@ -18,19 +18,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/isfinite_kernel_impl.h" -namespace phi { - -template -inline void IsfiniteKernelImpl(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out) { - dev_ctx.template Alloc(out); - Functor functor; - functor(x, out); -} - -} // namespace phi - PD_REGISTER_KERNEL(isinf, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/impl/isfinite_kernel_impl.h b/paddle/phi/kernels/impl/isfinite_kernel_impl.h index affa85f8a2d..d36a7cb915e 100644 --- a/paddle/phi/kernels/impl/isfinite_kernel_impl.h +++ b/paddle/phi/kernels/impl/isfinite_kernel_impl.h @@ -17,23 +17,24 @@ #include "paddle/phi/kernels/funcs/isfinite_functor.h" #include "paddle/phi/kernels/isfinite_kernel.h" -namespace phi { +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/transform.h" -template -inline void IsfiniteKernelImpl(const Context& ctx, - const DenseTensor& x, - DenseTensor* out); +namespace phi { -#define DEFINE_ISFINITE_KERNEL(isfinite_kernel, functor) \ - template \ - void isfinite_kernel( \ - const Context& ctx, const DenseTensor& x, DenseTensor* out) { \ - IsfiniteKernelImpl(ctx, x, out); \ +#define DEFINE_ISFINITE_KERNEL(isfinite_kernel, functor) \ + template \ + void isfinite_kernel( \ + const Context& ctx, const DenseTensor& x, DenseTensor* out) { \ + auto* out_ptr = ctx.template Alloc(out); \ + funcs::functor unary_func; \ + paddle::platform::Transform trans; \ + trans(ctx, x.data(), x.data() + x.numel(), out_ptr, unary_func); \ } -DEFINE_ISFINITE_KERNEL(IsinfKernel, funcs::InfinityV2Functor) -DEFINE_ISFINITE_KERNEL(IsnanKernel, funcs::NANV2Functor) -DEFINE_ISFINITE_KERNEL(IsfiniteKernel, funcs::IsfiniteV2Functor) +DEFINE_ISFINITE_KERNEL(IsinfKernel, IsInfFunctor) +DEFINE_ISFINITE_KERNEL(IsnanKernel, IsNanFunctor) +DEFINE_ISFINITE_KERNEL(IsfiniteKernel, IsFiniteFunctor) #undef DEFINE_ISFINITE_KERNEL } // namespace phi diff --git a/paddle/phi/kernels/selected_rows/impl/isfinite_kernel_impl.h b/paddle/phi/kernels/selected_rows/impl/isfinite_kernel_impl.h index c53abdf996c..a6436d1d1ef 100644 --- a/paddle/phi/kernels/selected_rows/impl/isfinite_kernel_impl.h +++ b/paddle/phi/kernels/selected_rows/impl/isfinite_kernel_impl.h @@ -19,21 +19,17 @@ namespace phi { -template -inline void IsfiniteSRImpl(const Context& ctx, - const SelectedRows& x, - SelectedRows* out); - -#define DEFINE_ISFINITE_SR(isfinite_sr, functor) \ +#define DEFINE_ISFINITE_SR(isfinite) \ template \ - void isfinite_sr( \ + void isfinite##SR( \ const Context& ctx, const SelectedRows& x, SelectedRows* out) { \ - IsfiniteSRImpl(ctx, x, out); \ + ctx.template Alloc(out); \ + Isinf##Kernel(ctx, x.value(), out->mutable_value()); \ } -DEFINE_ISFINITE_SR(IsinfSR, funcs::InfinityV2Functor) -DEFINE_ISFINITE_SR(IsnanSR, funcs::NANV2Functor) -DEFINE_ISFINITE_SR(IsfiniteSR, funcs::IsfiniteV2Functor) +DEFINE_ISFINITE_SR(Isinf) +DEFINE_ISFINITE_SR(Isnan) +DEFINE_ISFINITE_SR(Isfinite) #undef DEFINE_ISFINITE_SR } // namespace phi diff --git a/paddle/phi/kernels/selected_rows/isfinite_kernel.cc b/paddle/phi/kernels/selected_rows/isfinite_kernel.cc index 630f6bcf835..d68688a7e40 100644 --- a/paddle/phi/kernels/selected_rows/isfinite_kernel.cc +++ b/paddle/phi/kernels/selected_rows/isfinite_kernel.cc @@ -21,18 +21,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/selected_rows/impl/isfinite_kernel_impl.h" -namespace phi { - -template -inline void IsfiniteSRImpl(const Context& dev_ctx, - const SelectedRows& x, - SelectedRows* out) { - dev_ctx.template Alloc(out); - Functor functor; - functor(x.value(), out->mutable_value()); -} -} // namespace phi - PD_REGISTER_KERNEL(isinf_sr, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/selected_rows/isfinite_kernel.h b/paddle/phi/kernels/selected_rows/isfinite_kernel.h index 948d8c89477..447edd65805 100644 --- a/paddle/phi/kernels/selected_rows/isfinite_kernel.h +++ b/paddle/phi/kernels/selected_rows/isfinite_kernel.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/isfinite_kernel.h" namespace phi { -- GitLab