/**
 * \file imperative/src/impl/op_def.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

// NOTE(review): this file had been collapsed onto a few physical lines and had
// lost every `<...>` span (template argument lists and one system include).
// The line structure is restored so that `//` comments no longer swallow code.
// Template arguments that follow from this file alone (const_cast/static_cast
// targets, the const_reps map, try_cast_final<OprAttr>, ptr<float/int32_t>,
// <sstream>) are restored as such; the SmallVector element types, the `props`
// pair type and the OpMethResultCache arguments are reconstructed and must be
// confirmed against the declarations in op_def.h / op_trait.h.

#include "megbrain/imperative/op_def.h"

#include <sstream>
#include <string>
#include <unordered_map>

#include "megbrain/imperative/ops/opr_attr.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

/*!
 * \brief build an OpDef from a computing-graph operator node
 *
 * The trait is looked up by the node's dynamic typeinfo; when none is
 * registered for that type, the trait registered for OprAttr is used instead.
 */
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
    OpTrait* trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

// The static entry points below only forward to the same-named method of the
// OpTrait associated with \p def (see OpDef::trait()).

//! forward to the trait's decide_dispatch_mode
DispatchMode OpDef::decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->decide_dispatch_mode(def, inputs);
}

//! forward to the trait's apply_on_physical_tensor; \p inputs is moved along
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs) {
    return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}

//! forward to the trait's infer_output_mem_desc
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::
        infer_output_mem_desc(
                const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
                const SmallVector<MemoryDesc>& inputs_mems) {
    return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}

//! forward to the trait's execute; \p inputs and \p workspace are moved,
//! \p outputs is passed on as an lvalue
void OpDef::execute(
        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}

//! forward to the trait's apply_on_device_tensornd
void OpDef::apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
}

//! forward to the trait's apply_on_var_node
VarNodeArray OpDef::apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return def.trait()->apply_on_var_node(def, inputs);
}

//! forward to the trait's infer_output_attrs_fallible
std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

/*!
 * \brief get the backward graph of \p def, computing it at most once per key
 *
 * Results are kept in a thread_local cache keyed by the op (as obtained from
 * shared_from_this()), \p inputs and the two flag vectors; the trait is only
 * consulted when the key is not present yet.
 */
EncodedSubgraph OpDef::make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using BackwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local BackwardGraphCache cache;
    // shared_from_this() is called through a non-const reference, hence the
    // const_cast; def itself is not modified here
    decltype(cache)::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(),
            inputs,
            {input_requires_grad, output_has_grad}};
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
                                                def, inputs, input_requires_grad,
                                                output_has_grad)})
                       .first;
    }
    return iter->second;
}

//! forward to the trait's props
std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
    return def.trait()->props(def);
}

/*!
 * \brief get the forward graph of \p def, computing it at most once per key
 *
 * Same caching scheme as make_backward_graph, with a thread_local cache keyed
 * by the op and \p inputs only.
 */
EncodedSubgraph OpDef::make_forward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local ForwardGraphCache cache;
    decltype(cache)::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(), inputs};
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                       .first;
    }
    return iter->second;
}

//! render as `<name>{<prop>: <value>,...}` using the trait's name and props
std::string OpDef::to_string() const {
    std::string builder = trait()->make_name(*this) + "{";
    for (auto&& [name, value] : props(*this)) {
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    return builder + "}";
}

//! forward to the trait's hash
size_t OpDef::hash() const {
    return trait()->hash(*this);
}

//! forward to the trait's is_same_st; \p rhs is cast to OpDef without a
//! runtime check
bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

/*!
 * \brief the OpTrait registered for this op's dynamic type
 *
 * Resolved on first use and stored in m_trait; throws MegBrainError when no
 * trait is registered for dyn_typeinfo().
 */
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
        mgb_throw_if(
                !m_trait, MegBrainError, "can not find op_trait by %s",
                dyn_typeinfo()->name);
    }
    return m_trait;
}

//! get a copy of the scope string
const std::string OpDef::scope() const {
    return m_scope;
}

//! replace the scope string
void OpDef::set_scope(const std::string& scope) {
    m_scope = scope;
}

//! the trait-provided name, prefixed by `<scope>.` when the scope is non-empty
const std::string OpDef::make_name() const {
    if (m_scope.empty())
        return trait()->make_name(*this);
    return m_scope + "." + trait()->make_name(*this);
}

// one allocator slot per thread
static thread_local OpDef::allocator_t local_allocator;

//! install the allocator for the calling thread; asserts that none has been
//! installed on this thread before
void OpDef::set_allocator(allocator_t allocator) {
    mgb_assert(!local_allocator, "allocator has been set before");
    local_allocator = std::move(allocator);
}

//! allocate through the calling thread's allocator; no check is made here
//! that set_allocator() has been called on this thread
DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
    return local_allocator(device, size);
}

/*!
 * \brief textual dump of the subgraph
 *
 * Prints the input list, one line per expression (`outs = op ins`) and the
 * output list. Variables are written as `%<id>`; an input that refers to a
 * constant is written using the representation produced by fmt_const below.
 */
std::string Subgraph::repr() const {
    std::ostringstream buf;
    buf << "(";
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (i > 0)
            buf << ", ";
        buf << "%" << inputs[i];
    }
    buf << ") => {\n";
    // a constant of shape (1,) with dtype Float32 or Int32 is shown by value;
    // every other constant is shown as `%c<id>`
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<float>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<int32_t>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };
    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }
    for (auto& [op, ins, outs] : exprs) {
        buf << " ";
        if (outs.size()) {
            for (size_t i = 0; i < outs.size(); ++i) {
                if (i > 0)
                    buf << ", ";
                buf << "%" << outs[i];
            }
            buf << " = ";
        }
        // OprAttr ops are shown by their `type` field, all others by make_name()
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
            buf << op->make_name();
        }
        for (size_t i : ins) {
            buf << " ";
            auto it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }
    buf << " ";
    if (outputs.size()) {
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (i > 0)
                buf << ", ";
            buf << "%" << outputs[i];
        }
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

//! true iff the subgraph has exactly one expression whose inputs and outputs
//! equal those of the subgraph
bool Subgraph::is_single() const {
    if (exprs.size() != 1) {
        return false;
    }
    auto& expr = exprs.at(0);
    return expr.inputs == inputs && expr.outputs == outputs;
}

//! the op of the only expression when is_single() holds, nullptr otherwise
std::shared_ptr<OpDef> Subgraph::as_single() const {
    if (is_single()) {
        return exprs.at(0).op;
    } else {
        return nullptr;
    }
}

//! not implemented: unconditionally fails the assertion
// NOTE(review): there is no return statement; this relies on
// mgb_assert(false, ...) never returning - confirm.
bool Subgraph::operator==(const Subgraph& rhs) const {
    mgb_assert(false, "Not Implemented");
}

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}