// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <numeric>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/isfinite_kernel.h"
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/kernels/reduce_any_kernel.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
// Store the result bool in a tensor on the input's device (async on GPU);
// faster than the synchronous versions below that copy the result to CPU.
void TensorContainsNAN(const phi::DenseTensor& tensor, phi::DenseTensor* out);
void TensorContainsInf(const phi::DenseTensor& tensor, phi::DenseTensor* out);
void TensorIsfinite(const phi::DenseTensor& tensor, phi::DenseTensor* out);

// copy the result bool to cpu
bool TensorContainsNAN(const phi::DenseTensor& tensor);
bool TensorContainsInf(const phi::DenseTensor& tensor);
bool TensorIsfinite(const phi::DenseTensor& tensor);

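// FiniteVisitor(type, reduce_type, device) expands to a visitor struct named
// {type}Visitor{device} (e.g. IsnanVisitorCPU). Its apply<T>() runs the
// elementwise phi::{type}Kernel on the input, then collapses the boolean
// result over all dimensions with phi::{reduce_type}Kernel into a single
// bool stored in out_. The visitor is meant to be dispatched over the input
// dtype via VisitDataTypeNormal, as done in the helpers further below.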
#define FiniteVisitor(type, reduce_type, device)                             \
  struct type##Visitor##device {                                             \
    type##Visitor##device(const phi::DenseTensor& in, phi::DenseTensor* out) \
        : in_(in), out_(out) {}                                              \
    template <typename T>                                                    \
    void apply() const {                                                     \
      auto place = in_.place();                                              \
      auto* ctx = static_cast<phi::device##Context*>(                        \
          platform::DeviceContextPool::Instance().Get(place));               \
      phi::DenseTensor tmp;                                                  \
      tmp.Resize(in_.dims());                                                \
      out_->Resize({1});                                                     \
      std::vector<int64_t> dims(tmp.dims().size());                          \
      std::iota(dims.begin(), dims.end(), 0);                                \
      phi::type##Kernel<T, phi::device##Context>(*ctx, in_, &tmp);           \
      phi::reduce_type##Kernel<bool, phi::device##Context>(                  \
          *ctx, tmp, dims, false, out_);                                     \
    }                                                                        \
    const phi::DenseTensor& in_;                                             \
    phi::DenseTensor* out_;                                                  \
  };

FiniteVisitor(Isnan, Any, CPU);
FiniteVisitor(Isinf, Any, CPU);
FiniteVisitor(Isfinite, All, CPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
FiniteVisitor(Isnan, Any, GPU);
FiniteVisitor(Isinf, Any, GPU);
FiniteVisitor(Isfinite, All, GPU);
#endif

// Store the result bool in a tensor on the input's device (async on GPU);
// faster than the synchronous versions below that copy the result to CPU.
inline void TensorContainsNAN(const phi::DenseTensor& tensor,
                              phi::DenseTensor* out) {
  auto place = tensor.place();
  if (platform::is_cpu_place(tensor.place())) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsnanVisitorCPU(tensor, out));
    return;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsnanVisitorGPU(tensor, out));
    return;
  }
#endif
  PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place));
}
inline void TensorContainsInf(const phi::DenseTensor& tensor,
                              phi::DenseTensor* out) {
  auto place = tensor.place();
  if (platform::is_cpu_place(tensor.place())) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsinfVisitorCPU(tensor, out));
    return;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsinfVisitorGPU(tensor, out));
    return;
  }
#endif
  PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place));
}
inline void TensorIsfinite(const phi::DenseTensor& tensor,
                           phi::DenseTensor* out) {
  auto place = tensor.place();
  if (platform::is_cpu_place(tensor.place())) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsfiniteVisitorCPU(tensor, out));
    return;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place)) {
    VisitDataTypeNormal(TransToProtoVarType(tensor.dtype()),
                        IsfiniteVisitorGPU(tensor, out));
    return;
  }
#endif
  PADDLE_THROW(platform::errors::Unimplemented("Not supported on %s.", place));
}

// copy the result bool to cpu
inline bool TensorContainsNAN(const phi::DenseTensor& tensor) {
  phi::DenseTensor out;
  TensorContainsNAN(tensor, &out);
  return GetValue<bool>(&out);
}
inline bool TensorContainsInf(const phi::DenseTensor& tensor) {
  phi::DenseTensor out;
  TensorContainsInf(tensor, &out);
  return GetValue<bool>(&out);
}
inline bool TensorIsfinite(const phi::DenseTensor& tensor) {
  phi::DenseTensor out;
  TensorIsfinite(tensor, &out);
  return GetValue<bool>(&out);
}
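
// Illustrative usage sketch (not part of this header; `grads` stands for a
// hypothetical caller-side phi::DenseTensor):
//
//   phi::DenseTensor flag;
//   framework::TensorContainsNAN(grads, &flag);  // async, result stays on device
//
//   if (framework::TensorIsfinite(grads)) {      // sync, copies the bool to CPU
//     // all elements are finite
//   }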
}  // namespace framework
namespace operators {
struct InfinityFunctor {
  void operator()(const phi::DenseTensor& tensor, phi::DenseTensor* out) {
    framework::TensorContainsInf(tensor, out);
  }
};

struct NANFunctor {
  void operator()(const phi::DenseTensor& tensor, phi::DenseTensor* out) {
    framework::TensorContainsNAN(tensor, out);
  }
};

struct IsfiniteFunctor {
  void operator()(const phi::DenseTensor& tensor, phi::DenseTensor* out) {
    framework::TensorIsfinite(tensor, out);
  }
};

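// OverflowKernel applies the Functor above to Input(X): it accepts either a
// phi::DenseTensor or a phi::SelectedRows (in which case the underlying value
// tensor is checked) and writes the boolean result into Output("Out").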
template <typename DeviceContext, typename T, typename Functor>
class OverflowKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto* x = ctx.InputVar("X");
    auto* out = ctx.Output<phi::DenseTensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    Functor functor;
    if (x->IsType<phi::DenseTensor>()) {
      auto* in = ctx.Input<phi::DenseTensor>("X");
      functor(*in, out);
    } else if (x->IsType<phi::SelectedRows>()) {
      auto& in = ctx.Input<phi::SelectedRows>("X")->value();
      functor(in, out);
    } else {
      PADDLE_ENFORCE_EQ(true,
                        false,
                        platform::errors::InvalidArgument(
                            "The input type mismatch, the type of Input(X) "
                            "must be phi::DenseTensor or "
                            "SelectedRows, please check your input."));
    }
  }
};
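
// Rough sketch of how this kernel is typically registered in the matching
// isfinite_op.cc (illustrative only; the actual op names and dtype list may
// differ):
//
//   REGISTER_OP_CPU_KERNEL(
//       isinf,
//       paddle::operators::OverflowKernel<phi::CPUContext, float,
//                                         paddle::operators::InfinityFunctor>);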

}  // namespace operators
}  // namespace paddle