reduce_helper.h 6.0 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/reduce_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/dtype.h"

#include "megdnn/basic_types.h"

namespace megdnn {
namespace reduce {

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct SumOp {
    typedef wtype_ wtype;

    const wtype INIT;

25 26
    RefPtr src;
    RefPtr dst;
27 28
    const size_t B;

29 30 31 32
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    SumOp(const RefPtr& src, const RefPtr& dst, size_t B)
33 34 35 36 37 38 39 40 41
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct MeanOp {
    typedef wtype_ wtype;

    const wtype INIT;

42 43
    RefPtr src;
    RefPtr dst;
44
    const size_t B;
M
Megvii Engine Team 已提交
45

46 47 48
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) {
        dst.ptr<dst_ctype>()[idx] = val / static_cast<wtype>(B);
49
    }
50 51
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    MeanOp(const RefPtr& src, const RefPtr& dst, size_t B)
52 53 54 55 56 57 58 59 60
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct SumSqrOp {
    typedef wtype_ wtype;

    const wtype INIT;

61 62
    RefPtr src;
    RefPtr dst;
63 64
    const size_t B;

65 66 67
    wtype read(uint32_t idx) {
        return static_cast<wtype>(src.ptr<src_ctype>()[idx]) *
               static_cast<wtype>(src.ptr<src_ctype>()[idx]);
68
    }
69 70 71
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
    SumSqrOp(const RefPtr& src, const RefPtr& dst, size_t B)
72 73 74 75 76 77 78 79
            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct ProdOp {
    typedef wtype_ wtype;
    const wtype INIT;

80 81
    RefPtr src;
    RefPtr dst;
82 83
    const size_t B;

84 85 86 87
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return lhs * rhs; }
    ProdOp(const RefPtr& src, const RefPtr& dst, size_t B)
88 89 90 91 92 93 94 95
            : INIT(wtype(1)), src(src), dst(dst), B(B) {}
};

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct MinOp {
    typedef wtype_ wtype;
    const wtype INIT;

96 97
    RefPtr src;
    RefPtr dst;
98 99
    const size_t B;

100 101 102 103
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return std::min(lhs, rhs); }
    MinOp(const RefPtr& src, const RefPtr& dst, size_t B)
104 105 106
            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
};

107 108 109 110 111
template <typename src_ctype, typename dst_ctype>
struct MinOp<src_ctype, dst_ctype, dt_float32> {
    typedef dt_float32 wtype;
    const wtype INIT;

112 113
    RefPtr src;
    RefPtr dst;
114 115
    const size_t B;

116 117 118
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) {
119 120
        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
    }
121
    MinOp(const RefPtr& src, const RefPtr& dst, size_t B)
122 123 124
            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
};

125 126 127 128 129
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct MaxOp {
    typedef wtype_ wtype;
    const wtype INIT;

130 131
    RefPtr src;
    RefPtr dst;
132 133
    const size_t B;

134 135 136 137
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return std::max(lhs, rhs); }
    MaxOp(const RefPtr& src, const RefPtr& dst, size_t B)
138 139 140 141 142 143 144 145
            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};

template <typename src_ctype, typename dst_ctype>
struct MaxOp<src_ctype, dst_ctype, dt_float32> {
    typedef dt_float32 wtype;
    const wtype INIT;

146 147
    RefPtr src;
    RefPtr dst;
148 149
    const size_t B;

150 151 152
    wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) {
153
        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
154
    }
155
    MaxOp(const RefPtr& src, const RefPtr& dst, size_t B)
156 157 158
            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};

159
template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
160
struct CheckNonFiniteOp {
161 162 163
    typedef wtype_ wtype;
    const wtype INIT;

164 165
    RefPtr* srcs;
    RefPtr srcs_total_nr_elems;
166
    RefPtr dst;
167 168
    const size_t B;

169 170 171 172 173 174 175 176 177
    wtype read(uint32_t idx) {
        size_t x = idx / B;
        size_t y = idx % B;
        if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) {
            RefPtr src = srcs[x];
            return !std::isfinite(src.ptr<src_ctype>()[y]);
        }
        return 0;
    }
178 179
    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
    static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
180 181 182 183 184 185 186 187
    CheckNonFiniteOp(
            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
            size_t B)
            : INIT(wtype(0)),
              srcs(srcs),
              srcs_total_nr_elems(srcs_total_nr_elems),
              dst(dst),
              B(B) {}
188 189
};

M
Megvii Engine Team 已提交
190
/**
 * Declaration only; the definition is not in this header.
 *
 * \param shape tensor shape, passed by const reference
 * \param[out] A non-const reference; presumably filled in by the callee
 * \param[out] B non-const reference; presumably filled in by the callee
 * \param[out] C non-const reference; presumably filled in by the callee
 * \param axis an axis index, passed by value
 *
 * NOTE(review): the names suggest \p shape is split around \p axis into
 * (A, B, C) with B being the extent at \p axis -- confirm against the
 * definition.
 */
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);
191 192 193 194

}  // namespace reduce
}  // namespace megdnn
// vim: syntax=cpp.doxygen