From 15309fde2c50a485fd120f749661ea16a6c75232 Mon Sep 17 00:00:00 2001
From: Yang Yu
Date: Wed, 27 Dec 2017 14:45:04 +0800
Subject: [PATCH] Add API for HasNAN HasInf

---
 paddle/framework/tensor_util.h   | 96 ++++++++++++++++++++++++++++++++
 paddle/platform/device_context.h | 20 +++++++
 paddle/platform/place.h          | 28 +++++++++-
 3 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h
index ea4e4f22e..5c7822814 100644
--- a/paddle/framework/tensor_util.h
+++ b/paddle/framework/tensor_util.h
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/framework/data_type.h"
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
@@ -205,5 +208,98 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
                src_ptr, size);
 }
 
+template <typename Predicate, typename DevCtx>
+struct AnyDTypeVisitor {
+  Predicate predicate_;
+  const Tensor& tensor_;
+  const DevCtx& ctx_;
+  Tensor* out_;
+
+  AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
+                  Tensor* out)
+      : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
+
+  template <typename T>
+  void operator()() const {
+    auto t = EigenVector<T>::Flatten(tensor_);
+    auto o = EigenScalar<bool>::From(*out_);
+    o.device(*ctx_.eigen_device()) = predicate_(t).any();
+  }
+};
+
+template <typename Predicate, typename DevCtx>
+inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
+                    const DevCtx& ctx, framework::Tensor* out) {
+  VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
+                                               predicate, tensor, ctx, out));
+}
+
+template <typename Predicate>
+struct AnyVisitor : public boost::static_visitor<bool> {
+  const framework::Tensor& tensor_;
+  Predicate predicate_;
+
+  AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
+      : tensor_(tensor), predicate_(std::move(predicate)) {}
+
+  template <typename Place>
+  bool operator()(const Place& place) const {
+    framework::Tensor out;
+    out.Resize({1});
+    out.mutable_data<bool>(place);
+    auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+    AnyImpl(predicate_, tensor_, *ctx, &out);
+    return this->GetResult(out, place);
+  }
+
+  bool GetResult(const framework::Tensor& out,
+                 const platform::CUDAPlace& gpu) const {
+    platform::CPUPlace cpu;
+    framework::Tensor tmp;
+    tmp.Resize({1});
+    tmp.mutable_data<bool>(cpu);
+    platform::DeviceContextPool::Instance().Get(gpu)->Wait();
+    CopyFrom(out, cpu, &tmp);
+    platform::DeviceContextPool::Instance().Get(gpu)->Wait();
+    return GetResult(tmp, cpu);
+  }
+
+  bool GetResult(const framework::Tensor& out,
+                 const platform::CPUPlace& cpu) const {
+    return *out.data<bool>();
+  }
+};
+
+template <typename Predicate>
+inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
+  AnyVisitor<Predicate> visitor(tensor, predicate);
+  auto place = tensor.place();
+  return platform::VisitPlace(place, visitor);
+}
+
+struct HasNanPredicate {
+  template <typename T>
+  auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isnan()) {
+    return eigen_vec.isnan();
+  }
+};
+
+inline bool HasNan(const framework::Tensor& tensor) {
+  HasNanPredicate predicate;
+  return Any(tensor, predicate);
+}
+
+struct HasInfPredicate {
+  template <typename T>
+  auto operator()(T eigen_vec) const -> decltype(std::declval<T>().isinf()) {
+    return eigen_vec.isinf();
+  }
+};
+
+inline bool HasInf(const framework::Tensor& tensor) {
+  HasInfPredicate predicate;
+  return Any(tensor, predicate);
+}
+
 }  // namespace framework
 }  // namespace paddle
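For reference, a minimal usage sketch of the new HasNan/HasInf entry points (not part of the patch). It assumes a CPU build in which the global DeviceContextPool has already been initialized; the helper name CheckForBadValues and the tensor contents are illustrative only.

#include <limits>

#include "paddle/framework/tensor_util.h"

void CheckForBadValues() {
  paddle::framework::Tensor t;
  paddle::platform::CPUPlace cpu;
  t.Resize({3});
  float* buf = t.mutable_data<float>(cpu);
  buf[0] = 1.0f;
  buf[1] = std::numeric_limits<float>::quiet_NaN();
  buf[2] = std::numeric_limits<float>::infinity();

  // Any() dispatches on the tensor's place, applies the predicate
  // element-wise and reduces with any(); HasNan/HasInf wrap it.
  // Assumes the DeviceContextPool has been set up by framework startup.
  bool has_nan = paddle::framework::HasNan(t);  // true: buf[1] is NaN
  bool has_inf = paddle::framework::HasInf(t);  // true: buf[2] is Inf
  (void)has_nan;
  (void)has_inf;
}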
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index dfef2c16d..fd441d27f 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
   std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
 };
 
+template <typename Place>
+struct DefaultDeviceContextType;
+
+template <>
+struct DefaultDeviceContextType<platform::CPUPlace> {
+  using TYPE = CPUDeviceContext;
+};
+
 #ifdef PADDLE_WITH_CUDA
 
 class EigenCudaStreamDevice;
@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle_;
 };
 
+template <>
+struct DefaultDeviceContextType<platform::CUDAPlace> {
+  using TYPE = CUDADeviceContext;
+};
+
 class CUDNNDeviceContext : public CUDADeviceContext {
  public:
   explicit CUDNNDeviceContext(CUDAPlace place);
@@ -125,6 +138,13 @@ class DeviceContextPool {
   /*! \brief  Return handle of single device context. */
   const platform::DeviceContext* Get(const platform::Place& place);
 
+  template <typename Place>
+  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
+      const Place& place) {
+    return reinterpret_cast<
+        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
+  }
+
  private:
   static DeviceContextPool* pool;
   constexpr static int LEFT_SHIFT = 8;
diff --git a/paddle/platform/place.h b/paddle/platform/place.h
index d25eaa689..76b5c502c 100644
--- a/paddle/platform/place.h
+++ b/paddle/platform/place.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <iostream>
-
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/variant.h"
 
 namespace paddle {
@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
 
 std::ostream &operator<<(std::ostream &, const Place &);
 
+template <typename Visitor>
+struct PlaceVisitorWrapper
+    : public boost::static_visitor<typename Visitor::result_type> {
+  const Visitor &visitor_;
+  explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {}
+
+  typename Visitor::result_type operator()(const CPUPlace &cpu) const {
+    return visitor_(cpu);
+  }
+
+  typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
+#ifdef PADDLE_WITH_CUDA
+    return visitor_(cuda);
+#else
+    PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device");
+    return typename Visitor::result_type();
+#endif
+  }
+};
+
+template <typename Visitor>
+typename Visitor::result_type VisitPlace(const Place &place,
+                                         const Visitor &visitor) {
+  return boost::apply_visitor(PlaceVisitorWrapper<Visitor>(visitor), place);
+}
+
 }  // namespace platform
 }  // namespace paddle
-- 
GitLab
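Similarly, a sketch of how the new VisitPlace/PlaceVisitorWrapper helpers and DeviceContextPool::GetByPlace fit together (not part of the patch). The visitor type DescribePlaceVisitor and the free function DescribePlace are hypothetical; the contract imposed by PlaceVisitorWrapper is only that the visitor expose a result_type and provide operator() overloads for CPUPlace and CUDAPlace.

#include <string>

#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"

struct DescribePlaceVisitor {
  // PlaceVisitorWrapper reads Visitor::result_type to set its own result type.
  using result_type = std::string;

  std::string operator()(const paddle::platform::CPUPlace &cpu) const {
    // GetByPlace returns the concrete CPUDeviceContext (looked up through
    // DefaultDeviceContextType) rather than the DeviceContext base class.
    // Assumes the DeviceContextPool has already been initialized.
    auto *ctx =
        paddle::platform::DeviceContextPool::Instance().GetByPlace(cpu);
    (void)ctx;
    return "CPUPlace";
  }

  std::string operator()(const paddle::platform::CUDAPlace &cuda) const {
    return "CUDAPlace(" + std::to_string(cuda.device) + ")";
  }
};

std::string DescribePlace(const paddle::platform::Place &place) {
  // In a CPU-only build, visiting a CUDAPlace makes PlaceVisitorWrapper
  // hit PADDLE_THROW instead of calling the visitor.
  return paddle::platform::VisitPlace(place, DescribePlaceVisitor());
}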