/**
 * \file dnn/src/naive/check_non_finite/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/naive/check_non_finite/opr_impl.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

namespace {
using namespace megdnn;

#define src_ctype dt_float32
#define wtype dt_int32

void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
    std::function<wtype(size_t, size_t)> func;
    func = [&](size_t l, size_t r) -> wtype {
        if (l + 1 < r) {
            size_t mid = l + (r - l) / 2;
            return func(l, mid) | func(mid, r);
        } else {
30
            return static_cast<wtype>(!std::isfinite(sptr[l]));
31 32 33 34 35 36 37 38 39 40 41
        }
    };

    dptr[0] = func(0, size);
}

}  // namespace

namespace megdnn {
namespace naive {

42
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&,
43 44 45 46
                                               const TensorLayout&) {
    return 0;
}

47
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
48 49 50 51 52 53 54 55 56 57 58 59
                           _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);

    auto handle = static_cast<HandleImpl*>(this->handle());
    MEGDNN_DISPATCH_CPU_KERN(
            handle, reduce_fwd(src.ptr<dt_float32>(), dst.ptr<dt_int32>(),
                               src.layout.total_nr_elems()));
}
}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen