#pragma once #include #include #include "megdnn/basic_types.h" #include "megdnn/dtype.h" #include "src/common/utils.h" using namespace megdnn; /* anonymous namespace */ namespace { using Mode = Reduce::Mode; /* Reduce Trait */ template struct Trait; template struct Trait { static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; template const ctype Trait::INIT = ctype(0); template struct Trait { static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t B) { return x / (ctype)B; } }; template const ctype Trait::INIT = ctype(0); template struct Trait { static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x) { return x * x; } static ctype write(ctype x, size_t) { return x; } }; template const ctype Trait::INIT = ctype(0); template struct Trait { static const ctype INIT; static ctype apply(ctype x, ctype y) { return x * y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; template const ctype Trait::INIT = ctype(1); template struct Trait { static ctype apply(ctype x, ctype y) { return x < y ? x : y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; template <> struct Trait { using ctype = dt_float32; static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; template struct Trait { static ctype apply(ctype x, ctype y) { return x > y ? x : y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; template <> struct Trait { using ctype = dt_float32; static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; } static ctype visit(ctype x) { return x; } static ctype write(ctype x, size_t) { return x; } }; /* NormOp */ template struct NormOp; template <> struct NormOp { typedef dt_float32 ctype; static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x, dt_float32 p) { return powf(fabs(x), p); } static ctype write(ctype x, size_t, dt_float32 p) { return powf(x, 1.f / p); } }; #if !MEGDNN_DISABLE_FLOAT16 template <> struct NormOp { typedef dt_float16 ctype; static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x, dt_float32 p) { return half_float::pow(half_float::abs(x), half_float::half(p)); } static ctype write(ctype x, size_t, dt_float32 p) { return half_float::pow(x, half_float::half(1.f / p)); } }; #endif template struct NormZeroOp; template <> struct NormZeroOp { typedef dt_float32 ctype; static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x) { return x - 0.f < 0.00001f ? 0.f : 1.f; } static ctype write(ctype x, size_t) { return x; } }; #if !MEGDNN_DISABLE_FLOAT16 template <> struct NormZeroOp { typedef dt_float16 ctype; static const ctype INIT; static ctype apply(ctype x, ctype y) { return x + y; } static ctype visit(ctype x) { return x - half_float::half(0.f) < half_float::half(0.00001f) ? half_float::half(0.f) : half_float::half(1.f); } static ctype write(ctype x, size_t) { return x; } }; #endif } // namespace