/**
 * \file dnn_op_helper.h
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 *
 */

#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"

#include <memory>

using namespace megdnn;

namespace mgb {
namespace imperative {

/*!
 * \brief RAII helper for safely creating and calling DNN operators.
 *
 * On asynchronous CPU comp nodes, kernels are executed later by a dispatcher
 * thread, so the operator (and any workspace buffer handed to it) must not be
 * destroyed before all dispatched kernels have finished. The destructor
 * enqueues the operator deletion on the same dispatch queue so it is ordered
 * after every pending kernel; the workspace's backing tensor is kept alive as
 * a member for the same reason.
 */
template<typename Opr>
struct DnnOprCaller {
    CompNode cn;                //!< comp node the operator is bound to
    DeviceTensorND dev_tensor;  //!< backing storage for \ref workspace
    Workspace workspace;        //!< workspace handed to the operator
    std::unique_ptr<Opr> op;    //!< the megdnn operator instance

    //! Create a DNN operator of type \p Opr on the megdnn handle of \p cn.
    DnnOprCaller(CompNode cn): cn(cn) {
        auto&& handle = MegDNNHandle::get(
                                CompNodeEnv::from_comp_node(cn)).handle();
        op = handle->create_operator<Opr>();
    }

    /*!
     * \brief Allocate a device buffer sized for \p layout and wrap it as a
     * megdnn workspace.
     *
     * The buffer is owned by this caller (via \ref dev_tensor) and stays
     * valid for the caller's lifetime; a subsequent call replaces it.
     */
    megdnn::Workspace create_workspace(const TensorLayout& layout) {
        dev_tensor = Tensor::make(layout, cn)->dev_tensor();
        workspace = megdnn::Workspace(dev_tensor.raw_ptr(),
                                      dev_tensor.storage().size());
        return workspace;
    }

    ~DnnOprCaller() {
        using DT = CompNode::DeviceType;
        // For async CPU comp nodes (anything but default_cpu), deleting the
        // operator here on the issuing thread could race with kernels still
        // queued on the dispatcher; enqueue the deletion instead so it runs
        // after all previously dispatched kernels. Other device types release
        // the operator synchronously via the unique_ptr.
        if (cn.device_type() == DT::CPU && cn != CompNode::default_cpu()) {
            CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [p = op.release()] { delete p; }
            );
        }
    }
};

} // namespace imperative
} // namespace mgb