/**
 * \file dnn/src/common/reduce_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/dtype.h"
#include "megdnn/basic_types.h"

// NOTE(review): std::min / std::max / std::isnan / std::isfinite are used below
// but <algorithm> and <cmath> are not included directly in this header;
// presumably they arrive through the megdnn headers above -- confirm.

namespace megdnn {
namespace reduce {

// NOTE(review): in this copy of the file the template parameter lists and the
// explicit template arguments (e.g. on `ptr()`, `static_cast(...)` and
// `DTypeTrait::max()`) are not visible. The comments below describe only what
// the visible code does; confirm the parameter lists against the upstream header.
//
// Every functor in this namespace exposes the same members:
//   - wtype : the type passed to and returned from read()/apply()/write()
//   - INIT  : constant set once in the constructor's initializer list
//   - read(idx)       : fetch one input value by flat index
//   - write(idx, val) : store one value by flat index
//   - apply(lhs, rhs) : static binary combine step

/*!
 * \brief Functor whose combine step is addition.
 *
 * INIT is wtype(0). read() returns the element of `src` at `idx`, write()
 * stores `val` into `dst` at `idx` unchanged, and apply() returns lhs + rhs.
 * `B` is stored by the constructor but not read by any member of this struct.
 */
template struct SumOp {
    typedef wtype_ wtype;

    //! starting value: wtype(0)
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: lhs + rhs
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    SumOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor that adds like SumOp but divides by `B` when writing.
 *
 * INIT is wtype(0) and apply() returns lhs + rhs. write() stores
 * `val / static_cast(B)`, so the division happens once per written value
 * rather than in apply().
 * NOTE(review): B == 0 is not checked here; the effect of the division then
 * depends on wtype -- confirm that callers never pass B == 0.
 */
template struct MeanOp {
    typedef wtype_ wtype;

    //! starting value: wtype(0)
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    //! divisor applied in write()
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val divided by B into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val / static_cast(B); }
    //! combine step: lhs + rhs
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    MeanOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor that adds the squares of the input elements.
 *
 * INIT is wtype(0). read() loads the element of `src` at `idx` twice, casts
 * each load and returns their product; apply() returns lhs + rhs; write()
 * stores `val` unchanged. `B` is stored but not read by any member here.
 */
template struct SumSqrOp {
    typedef wtype_ wtype;

    //! starting value: wtype(0)
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the (cast) element of src at idx multiplied by itself
    wtype read(uint32_t idx) {
        return static_cast(src.ptr()[idx]) * static_cast(src.ptr()[idx]);
    }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: lhs + rhs
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    SumSqrOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor whose combine step is multiplication.
 *
 * INIT is wtype(1). read() and write() pass values through unchanged and
 * apply() returns lhs * rhs. `B` is stored but not read by any member here.
 */
template struct ProdOp {
    typedef wtype_ wtype;

    //! starting value: wtype(1)
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: lhs * rhs
    static wtype apply(wtype lhs, wtype rhs) { return lhs * rhs; }
    ProdOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(1)), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor whose combine step is std::min.
 *
 * INIT is the value returned by DTypeTrait::max() converted to wtype.
 * read() and write() pass values through unchanged.
 * `B` is stored but not read by any member here.
 */
template struct MinOp {
    typedef wtype_ wtype;

    //! starting value: wtype(DTypeTrait::max())
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: std::min(lhs, rhs)
    static wtype apply(wtype lhs, wtype rhs) { return std::min(lhs, rhs); }
    MinOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(DTypeTrait::max())), src(src), dst(dst), B(B) {}
};

/*!
 * \brief MinOp variant whose wtype is fixed to dt_float32.
 *
 * NOTE(review): presumably a partial specialization of MinOp; the
 * specialization argument list is not visible here -- confirm.
 *
 * Differs from the generic MinOp only in apply(): lhs is returned when it is
 * NaN or when lhs < rhs, otherwise rhs is returned. Because `lhs < rhs` is
 * false when rhs is NaN, the result is NaN whenever either operand is NaN.
 */
template struct MinOp {
    typedef dt_float32 wtype;

    //! starting value: wtype(DTypeTrait::max())
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: NaN-keeping minimum (see struct comment)
    static wtype apply(wtype lhs, wtype rhs) {
        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
    }
    MinOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(DTypeTrait::max())), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor whose combine step is std::max.
 *
 * INIT is the value returned by DTypeTrait::min() converted to wtype.
 * read() and write() pass values through unchanged.
 * `B` is stored but not read by any member here.
 */
template struct MaxOp {
    typedef wtype_ wtype;

    //! starting value: wtype(DTypeTrait::min())
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: std::max(lhs, rhs)
    static wtype apply(wtype lhs, wtype rhs) { return std::max(lhs, rhs); }
    MaxOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(DTypeTrait::min())), src(src), dst(dst), B(B) {}
};

/*!
 * \brief MaxOp variant whose wtype is fixed to dt_float32.
 *
 * NOTE(review): presumably a partial specialization of MaxOp; the
 * specialization argument list is not visible here -- confirm.
 *
 * Differs from the generic MaxOp only in apply(): lhs is returned when it is
 * NaN or when lhs > rhs, otherwise rhs is returned. Because `lhs > rhs` is
 * false when rhs is NaN, the result is NaN whenever either operand is NaN.
 */
template struct MaxOp {
    typedef dt_float32 wtype;

    //! starting value: wtype(DTypeTrait::min())
    const wtype INIT;

    RefPtr src;
    RefPtr dst;
    const size_t B;

    //! returns the element of src at position idx
    wtype read(uint32_t idx) { return src.ptr()[idx]; }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: NaN-keeping maximum (see struct comment)
    static wtype apply(wtype lhs, wtype rhs) {
        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
    }
    MaxOp(const RefPtr& src, const RefPtr& dst, size_t B)
            : INIT(wtype(DTypeTrait::min())), src(src), dst(dst), B(B) {}
};

/*!
 * \brief Functor that flags non-finite values across an array of sources.
 *
 * Unlike the functors above, the input is an array of RefPtr (`srcs`) plus a
 * per-source element count (`srcs_total_nr_elems`). read() splits the flat
 * index into a source number `idx / B` and an offset `idx % B`; offsets at or
 * beyond that source's element count yield 0. apply() combines with bitwise
 * OR and INIT is wtype(0).
 * NOTE(review): B == 0 would make read() divide by zero -- confirm callers
 * guarantee B > 0.
 */
template struct CheckNonFiniteOp {
    typedef wtype_ wtype;

    //! starting value: wtype(0)
    const wtype INIT;

    //! array of sources, indexed by idx / B in read()
    RefPtr* srcs;
    //! per-source element counts, indexed by idx / B in read()
    RefPtr srcs_total_nr_elems;
    RefPtr dst;
    //! divisor used to split a flat index into (source, offset)
    const size_t B;

    /*!
     * \brief Test one element for non-finiteness.
     * \return the negated std::isfinite result for element `idx % B` of
     *         source `idx / B`, or 0 when that offset is not below the
     *         source's element count.
     */
    wtype read(uint32_t idx) {
        size_t x = idx / B;
        size_t y = idx % B;
        if (y < srcs_total_nr_elems.ptr()[x]) {
            RefPtr src = srcs[x];
            return !std::isfinite(src.ptr()[y]);
        }
        return 0;
    }
    //! stores val into dst at position idx
    void write(uint32_t idx, wtype val) { dst.ptr()[idx] = val; }
    //! combine step: bitwise OR
    static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
    CheckNonFiniteOp(
            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
            size_t B)
            : INIT(wtype(0)),
              srcs(srcs),
              srcs_total_nr_elems(srcs_total_nr_elems),
              dst(dst),
              B(B) {}
};

/*!
 * \brief Compute A, B and C from a tensor shape and an axis (declaration only).
 *
 * Results are returned through the reference parameters A, B and C.
 * NOTE(review): presumably A/B/C are the products of the dimensions before,
 * at, and after `axis`; the definition is not in this file -- confirm.
 */
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);

}  // namespace reduce
}  // namespace megdnn
// vim: syntax=cpp.doxygen