/**
 * \file imperative/src/impl/proxy_graph_detail.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb {
namespace imperative {
namespace proxy_graph_detail {

/*!
 * \brief Run \p def on \p inputs, with results going to \p outputs.
 *
 * The caller supplies both the input and the output tensors, and nothing
 * is returned. \p outputs are therefore presumably pre-allocated and
 * filled in place. NOTE(review): confirm against the definition.
 *
 * \param def     the op to execute
 * \param inputs  input tensors
 * \param outputs output tensors
 */
void exec(const OpDef& def,
        const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs);

/*!
 * \brief Infer output descriptors of \p def from concrete input tensors.
 *
 * \param def    the op whose outputs are described
 * \param inputs concrete input tensors
 * \return output descriptors; presumably one LogicalTensorDesc per output
 *         of \p def (confirm)
 */
SmallVector<LogicalTensorDesc>
infer_output_attrs(const OpDef& def,
        const SmallVector<TensorPtr>& inputs);

/*!
 * \brief Infer output descriptors of \p def from input descriptors only.
 *
 * Unlike infer_output_attrs(), this variant takes LogicalTensorDesc
 * inputs instead of tensors. NOTE(review): the "fallible" suffix suggests
 * the result may be partial when the input descriptors are not fully
 * known. Confirm what the implementation returns in that case.
 *
 * \param def    the op whose outputs are described
 * \param inputs descriptors of the inputs
 * \return inferred output descriptors
 */
SmallVector<LogicalTensorDesc>
infer_output_attrs_fallible(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs);

/*!
 * \brief Build a BackwardGraphResult for \p def.
 *
 * \param def                 the op to differentiate
 * \param inputs              descriptors of the op's inputs
 * \param input_requires_grad presumably one flag per entry of \p inputs
 *                            (confirm)
 * \param output_has_grad     presumably one flag per output of \p def
 *                            (confirm)
 * \return the constructed backward-graph result
 */
BackwardGraphResult
make_backward_graph(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad);

} // namespace proxy_graph_detail
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}