diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index 21d588cc01f322b44c029d7cc9f95b8f5262a864..3547de0a4d7b7f7e5974bfd733a36deb561c10ba 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/macros.h" @@ -46,10 +47,22 @@ class MaxPool { template class AvgPool { + using MT = typename details::MPTypeTrait::Type; + MT intermediate_res; + public: - DEVICE inline T initial() { return static_cast(0); } - DEVICE inline void compute(const T& x, T* y) { *y += x; } - DEVICE inline void finalize(const T& pool_field, T* y) { *y /= pool_field; } + DEVICE inline T initial() { + intermediate_res = static_cast(0.0f); + return static_cast(0); + } + + DEVICE inline void compute(const T& x, T* y) { + intermediate_res += static_cast(x); + } + + DEVICE inline void finalize(const T& pool_field, T* y) { + *y = static_cast(intermediate_res / (static_cast(pool_field))); + } }; template