/**
 * \file dnn/src/cuda/check_non_finite/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/cuda/check_non_finite/opr_impl.h"
#include "src/cuda/reduce_helper.cuh"

#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include "src/common/reduce_helper_device.h"

namespace megdnn {
namespace cuda {

using device_reduce::CheckNonFiniteOp;
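// Inputs are processed in fixed-size chunks of at most this many elements.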
#define total_nr_elems_max 2048
template <typename T>
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
    // Dtype-specific helper: computes the workspace bytes for the chunked
    // reduction once m_size (the total chunk count) is known.
    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
    megdnn_assert(m_size > 0);
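    // Besides the reduction scratch space, the workspace holds one chunk base
    // pointer and one chunk length per chunk.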
    WorkspaceBundle bundle(
            nullptr, {
                             sizeof(T*) * m_size,
                             sizeof(size_t) * m_size,
                     });
    return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
           bundle.total_size_in_bytes();
}

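// Public workspace query: counts total_nr_elems_max-sized chunks across all
// inputs, then defers to the dtype-specific helper above.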
size_t CheckNonFiniteImpl::get_workspace_in_bytes(
        const TensorNDArray& srcs, const TensorLayout&) {
    m_size = 0;
    for (const auto& src : srcs) {
        m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
    }
    if (srcs.begin()->layout.dtype == dtype::Float32()) {
        return _get_workspace_in_bytes<dt_float32>();
    } else if (srcs.begin()->layout.dtype == dtype::Float16()) {
        return _get_workspace_in_bytes<dt_float16>();
    } else {
        megdnn_log_warn("only fp16 and fp32 are supported, fallback to fp32");
        return _get_workspace_in_bytes<dt_float32>();
    }
}

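// exec dispatches on the dtype of the first input tensor; Float16 support is
// compiled in only when DNN_INC_FLOAT16 is defined.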
void CheckNonFiniteImpl::exec(
        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    if (srcs.begin()->layout.dtype == dtype::Float32()) {
        _exec<dt_float32>(srcs, dst, workspace);
    }
#ifdef DNN_INC_FLOAT16
    else if (srcs.begin()->layout.dtype == dtype::Float16()) {
        _exec<dt_float16>(srcs, dst, workspace);
    }
#endif
}

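// Chunked implementation: build per-chunk pointer/length tables on the host,
// upload them, then run a single device-side reduction over all chunks.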
template <typename T>
void CheckNonFiniteImpl::_exec(
        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    check_exec(srcs, dst, workspace.size);
    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
    auto stream = cuda_stream(this->handle());
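    // Both workspace bundles share one layout: [0] chunk base pointers,
    // [1] per-chunk element counts.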
    SmallVector<size_t> workspace_sizes{
            sizeof(T*) * m_size,
            sizeof(size_t) * m_size,
    };
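    // Fill the metadata in a host-side staging buffer first, then mirror it
    // into the device workspace below.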
    WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
            workspace_gpu(nullptr, workspace_sizes);
    auto total_workspace_size = workspace_cpu.total_size_in_bytes();
    void* workspace_cpu_raw = malloc(total_workspace_size);
    megdnn_assert_internal(workspace_cpu_raw);
    void* workspace_gpu_raw = workspace.raw_ptr;
    workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
    workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);

    auto srcs_cpu = static_cast<T**>(workspace_cpu.get(0));
    auto srcs_gpu = static_cast<T**>(workspace_gpu.get(0));
    auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
    auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));

    // Split each source tensor into chunks of at most total_nr_elems_max
    // elements, recording each chunk's base pointer and element count.
    size_t i = 0;
    for (const auto& src : srcs) {
        size_t src_nr_elems = src.layout.total_nr_elems();
        size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
        for (size_t j = 0; j < nr_elems; ++j, ++i) {
            srcs_cpu[i] = src.ptr<T>() + j * total_nr_elems_max;
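            // The final chunk of a tensor may be shorter than the chunk size.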
            if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
                srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
            } else {
                srcs_total_nr_elems_cpu[i] = total_nr_elems_max;
            }
        }
    }
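    // Copy the chunk metadata from host to device asynchronously.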
    for (size_t i = 0; i < workspace_cpu.nr_workspace(); ++i) {
        cuda_check(cudaMemcpyAsync(
                workspace_gpu.get(i), workspace_cpu.get(i), workspace_cpu.get_size(i),
                cudaMemcpyHostToDevice, stream));
    }
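    // Free the host staging buffer once the stream has consumed the copies.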
    cuda_check(cudaStreamAddCallback(
            stream, callback_free, static_cast<void*>(workspace_cpu_raw), 0));

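    // Run one reduction over all chunks; its scratch space starts immediately
    // after the two metadata sub-workspaces in the device buffer.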
    return run_reduce<Op, false>(
            static_cast<dt_int32*>(
                    (void*)((char*)workspace_gpu_raw +
                            workspace_gpu.total_size_in_bytes())),
            1, m_size * total_nr_elems_max, 1, stream,
            Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
               total_nr_elems_max, static_cast<T>(param().scale)));
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen