/**
 * \file imperative/src/impl/op_def.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/op_def.h"

#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

#include "megbrain/imperative/ops/opr_attr.h"

#include "./op_trait.h"

// NOTE(review): the template argument lists in this file were reconstructed
// after being lost in a whitespace/markup-mangled copy; confirm each signature
// against the declarations in megbrain/imperative/op_def.h.

namespace mgb {
namespace imperative {

/*!
 * \brief build an OpDef from an existing operator node
 *
 * The trait is looked up by the node's dynamic type; when none is registered
 * the lookup falls back to the trait registered for OprAttr.
 */
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
    OpTrait* trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

//! forward to the `decide_dispatch_mode` entry of \p def's trait
DispatchMode OpDef::decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->decide_dispatch_mode(def, inputs);
}

//! forward to the `apply_on_physical_tensor` entry of \p def's trait;
//! \p inputs is taken by value and moved into the callee
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs) {
    return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}

//! forward to the `infer_output_mem_desc` entry of \p def's trait
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
OpDef::infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}

//! forward to the `execute` entry of \p def's trait; \p inputs and
//! \p workspace are moved into the callee, \p outputs is passed as an lvalue
void OpDef::execute(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace) {
    def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}

//! forward to the `apply_on_device_tensornd` entry of \p def's trait
void OpDef::apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
}

//! forward to the `apply_on_var_node` entry of \p def's trait
VarNodeArray OpDef::apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    return def.trait()->apply_on_var_node(def, inputs);
}

//! forward to the `infer_output_attrs_fallible` entry of \p def's trait
std::tuple<SmallVector<LogicalTensorDesc>, bool>
OpDef::infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

//! forward to the `make_backward_graph` entry of \p def's trait
EncodedSubraph OpDef::make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    return def.trait()->make_backward_graph(
            def, inputs, input_requires_grad, output_has_grad);
}

//! forward to the `props` entry of \p def's trait
std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
    return def.trait()->props(def);
}

/*!
 * \brief render the properties of this op as "{name: value,name: value,}"
 *
 * Each property is followed by a comma, including the last one.
 */
std::string OpDef::to_string() const {
    std::string builder = "{";
    for (auto&& [name, value] : props(*this)) {
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    builder += "}";
    return builder;
}

//! forward to the `hash` entry of this op's trait
size_t OpDef::hash() const {
    return trait()->hash(*this);
}

//! forward to the `is_same_st` entry of this op's trait; \p rhs is downcast
//! with an unchecked static_cast, so it is assumed to be an OpDef
bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

/*!
 * \brief get the trait of this op, resolving it on first use
 *
 * The result of the lookup by dynamic type is stored in m_trait and reused by
 * later calls; a failed lookup is reported through mgb_throw_if with
 * MegBrainError.
 */
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
        mgb_throw_if(
                !m_trait, MegBrainError, "can not find op_trait by %s",
                dyn_typeinfo()->name);
    }
    return m_trait;
}

//! get a copy of the scope string stored on this op
const std::string OpDef::scope() const {
    return m_scope;
}

//! replace the scope string stored on this op
void OpDef::set_scope(const std::string& scope) {
    m_scope = scope;
}

//! get the trait-provided name, prefixed by "<scope>." when a scope is set
const std::string OpDef::make_name() const {
    if (m_scope.empty())
        return trait()->make_name(*this);
    return m_scope + "." + trait()->make_name(*this);
}

/*!
 * \brief produce a textual dump of the subgraph
 *
 * The dump lists the inputs, then one line per expression
 * ("%out, ... = op operand ..."), then the outputs (or "()" when there are
 * none). Operands that refer to a constant are printed through their constant
 * representation, all other variables as "%<id>".
 */
std::string Subgraph::repr() const {
    std::ostringstream buf;

    // print a list of variable ids as "%a, %b, ..."
    auto write_vars = [&buf](const auto& vars) {
        for (size_t i = 0; i < vars.size(); ++i) {
            if (i > 0)
                buf << ", ";
            buf << "%" << vars[i];
        }
    };

    buf << "(";
    write_vars(inputs);
    buf << ") => {\n";

    // a single-element 1-d Float32/Int32 constant is shown by value, every
    // other constant as "%c<id>"
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<dt_float32>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<dt_int32>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };

    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }

    for (auto& [op, ins, outs] : exprs) {
        buf << " ";
        if (outs.size()) {
            write_vars(outs);
            buf << " = ";
        }
        // OprAttr ops are shown by their `type` field, all others by the
        // name of their dynamic type
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
            buf << op->dyn_typeinfo()->name;
        }
        for (size_t i : ins) {
            buf << " ";
            auto it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }

    buf << " ";
    if (outputs.size()) {
        write_vars(outputs);
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}