/**
 * \file src/jit/impl/executor_opr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/jit/executor_opr.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/helper.h"
#include "megbrain/jit/compiler.h"
#include "megbrain/jit/param_elem_visitor.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/utils/hash.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#if MGB_JIT

using namespace mgb;
using namespace jit;

using CPFlag = Compiler::Property::Flag;

/* =================== Fusion ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITExecutor);

//! Build an executor around \p internal_graph, fed by \p inputs.
//!
//! Besides registering inputs/outputs, the constructor decides which inputs
//! may be broadcast, records which special oprs (reduce / dimshuffle) occur in
//! the internal graph, and asserts that the internal graph output depends on
//! every placeholder.
JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
                         const VarNodeArray& inputs,
                         const OperatorNodeConfig& config)
        : Super(internal_graph->output()->owner_graph(), config,
                ssprintf("JIT-Fusion{%zu}",
                         internal_graph->placeholders().size()),
                inputs),
          m_internal_graph{internal_graph},
          m_compiler{Compiler::get(*inputs[0]->owner_graph(),
                                   inputs[0]->comp_node())} {
    for (auto inp : inputs) {
        add_input({inp});
    }
    m_input_broadcastable.resize(inputs.size());
    auto&& placeholders = m_internal_graph->placeholders();
    mgb_assert(placeholders.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        // a placeholder must never be the output of the internal graph itself
        mgb_assert(placeholders[i]->output(0) != internal_graph->output());
        // NOTE(review): the template argument of same_type<>() was lost in
        // this copy of the file; MarkNoBroadcastElemwise (opr/utility.h) is
        // the opr whose purpose matches this check -- confirm.
        if (placeholders[i]->is_host_value_shape_input() ||
            input()[i]
                    ->owner_opr()
                    ->same_type<opr::MarkNoBroadcastElemwise>()) {
            m_input_broadcastable[i] = false;
        } else {
            m_input_broadcastable[i] = true;
        }
    }
    if (inputs.size() == 1) {
        m_input_broadcastable[0] = false;
    } else {
        // look for the case where exactly one (non host-value-shape) input is
        // not a statically-known scalar; the scan stops as soon as a second
        // such input is seen
        Maybe<size_t> non_scalar;
        for (size_t i = 0; i < input().size(); ++i) {
            if (placeholders[i]->is_host_value_shape_input())
                continue;
            if (!(cg::is_const_var_shape(input(i)) &&
                  input(i)->shape().is_scalar())) {
                if (non_scalar.valid()) {
                    non_scalar.invalidate();
                    break;
                }
                non_scalar = i;
            }
        }
        if (non_scalar.valid()) {
            // exactly one input is non-scalar
            m_input_broadcastable[non_scalar.val()] = false;
        }
    }
    add_output(None)->dtype(m_internal_graph->output()->dtype());
    // NOTE(review): the outer template argument was lost in this copy of the
    // file (only a trailing '>' survived, so it was itself a template-id);
    // ScalarHash<void*> hashes the output var pointer -- confirm.
    add_equivalence_component<ScalarHash<void*>>(internal_graph->output());
    for (size_t i = 0, it = m_compiler->get_nr_workspace_outputs(this);
         i < it; ++i) {
        cg::add_workspace_output(this);
    }

    // check that the output of internal_graph depends on all placeholders
    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
    std::vector<bool> used(nr_placeholders, false);
    // check if there is reduce or dimshuffle opr
    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
        if (opr->same_type<opr::Reduce>()) {
            m_feature_bits |= JITFeatureBits::REDUCE;
        }
        if (opr->same_type<opr::Dimshuffle>()) {
            m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
        }
        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
            mgb_assert(ph->input_id() < nr_placeholders,
                       "bad placeholders %s in JITExecutor %s", ph->cname(),
                       cname());
            used[ph->input_id()] = true;
        }
    }}.add(internal_graph->output());

    for (size_t i = 0; i < nr_placeholders; ++i) {
        mgb_assert(used[i],
                   "placeholder %s is not depended on the output of %s",
                   internal_graph_ptr()->placeholders()[i]->cname(), cname());
    }

    if (has_dimshuffle()) {
        prepare_dimshuffle();
    }
}

//! Constrain all inputs to contiguous layouts when the compiler sets
//! NEED_INPUT_CONTIG, and to monotone layouts otherwise.
void JITExecutor::add_input_layout_constraint() {
    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_CONTIG)) {
        for (auto i : input()) {
            i->add_layout_constraint_contiguous();
        }
    } else {
        for (auto i : input()) {
            i->add_layout_constraint_monotone();
        }
    }
}

//! Forward to the base implementation and mark the cached args as stale so
//! that they are recomputed by update_args().
void JITExecutor::init_output_mem_plan(bool dynamic) {
    Super::init_output_mem_plan(dynamic);
    m_args.need_update = true;
}

void JITExecutor::mem_plan_fwd_in2out_writable() {
    //! currently mem fwd only support elemwise fusion
    if (m_feature_bits != JITFeatureBits::NONE)
        return;
    mixin_mem_plan_fwd_in2out_writable(*this);
}

//! Insert a JITExecutor into the graph owning the internal graph output and
//! return its first output var.
SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph,
                            const VarNodeArray& inputs,
                            const OperatorNodeConfig& config) {
    return internal_graph->output()
            ->owner_graph()
            ->insert_opr(std::make_unique<JITExecutor>(internal_graph, inputs,
                                                       config))
            ->output(0);
}

//! Register static inference: output shape mirrors the internal graph's
//! shape_infer var; output value mirrors value_infer when one is present.
void JITExecutor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(
            output(0),
            ShapeInferDesc::make_identity(m_internal_graph->shape_infer()));
    m_compiler->init_workspace_size_infer(this);
    if (m_internal_graph->value_infer()) {
        mgr.register_value_infer(
                output(0),
                ValueInferDesc::make_identity(m_internal_graph->value_infer()));
    }
}

//! (Re)compile when there is no executable yet or the args are stale, then
//! run the executable.
void JITExecutor::scn_do_execute() {
    if (m_executable == nullptr || m_args.need_update) {
        m_executable = m_compiler->compile(this);
    }
    m_executable->execute(this);
}

//! change the inputs which depend on dimshuffle opr, make sure dimshuffles
//! can be ignored
void JITExecutor::do_dimshuffle() {
    // Apply \p pattern to \p ily: output axis idx takes shape/stride of input
    // axis pattern[idx]; a negative entry inserts a new axis of extent 1.
    // The pattern is taken by const reference to avoid a copy per call.
    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
                                            const std::vector<int>& pattern) {
        TensorLayout oly{ily.dtype};
        oly.ndim = pattern.size();

        bool input_used[TensorLayout::MAX_NDIM] = {0};
        for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
            auto i = pattern[idx];
            if (i < 0) {
                oly.shape[idx] = 1;
                oly.stride[idx] = 1;
            } else {
                input_used[i] = true;
                oly.shape[idx] = ily.shape[i];
                oly.stride[idx] = ily.stride[i];
            }
        }

        // only axes of extent 1 may be dropped by the pattern
        for (size_t i = 0; i < ily.ndim; ++i) {
            // %zu: i is a size_t (the previous %zd expects a signed type)
            mgb_assert(input_used[i] || ily.shape[i] == 1,
                       "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zu",
                       static_cast<const TensorShape&>(ily).to_string().c_str(),
                       i);
        }
        return oly;
    };

    for (auto&& i : m_internal_graph->placeholders()) {
        auto&& input = m_args.inputs[i->input_id()];
        auto&& iter = m_jitph2dimshuffle.find(i);
        // placeholders without a recorded dimshuffle keep their layout
        if (iter == m_jitph2dimshuffle.end())
            continue;
        auto&& param = iter->second;
        mgb_assert(input.layout.ndim == param.second,
                   "input ndim mismatch for Dimshuffle: "
                   "expect=%u "
                   "actual=%zu",
                   param.second, input.layout.ndim);
        auto dimshuffled_layout =
                get_dimshuffled_layout(input.layout, param.first);
        input.layout = dimshuffled_layout;
    }
}

//! Rebuild m_args (input/output layouts, hash and version) from the current
//! state of the operator and clear the need_update flag.
void JITExecutor::update_args() {
    m_args.outputs.clear();
    for (auto out : output()) {
        m_args.outputs.push_back({out, out->layout(), -1});
    }
    m_args.inputs.resize(input().size());

    auto is_host_value_shape_input = [this](size_t idx) {
        return m_internal_graph->placeholders()
                .at(idx)
                ->is_host_value_shape_input();
    };

    for (size_t i = 0; i < input().size(); i++) {
        auto&& dst_data = m_args.inputs[i];
        dst_data.from = input(i);
        dst_data.idx = i;
        if (is_host_value_shape_input(i)) {
            // the input's statically inferred *value* is interpreted as a
            // shape; dtype is reset and all strides are zeroed
            auto&& mgr = owner_graph()->static_infer_manager();
            auto&& shpval_inp_val = &mgr.infer_value(input(i));
            cg::copy_tensor_value_to_shape(dst_data.layout, *shpval_inp_val);
            dst_data.layout.dtype = {};
            // distinct index name: the previous inner `i` shadowed the outer
            // loop variable
            for (size_t j = 0; j < dst_data.layout.ndim; ++j) {
                dst_data.layout.stride[j] = 0;
            }
        } else {
            dst_data.layout = input(i)->layout();
        }
    }

    //! dimshuffle opr need to change the input.
    if (has_dimshuffle()) {
        do_dimshuffle();
    }

    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
        // collective collapse datum layout, try to reduce the output ndim
        opr::Elemwise::TensorLayoutPtrArray inp_layouts;
        inp_layouts.reserve(m_args.inputs.size());
        for (size_t i = 0; i < m_args.inputs.size(); i++) {
            if (!is_host_value_shape_input(i)) {
                inp_layouts.push_back(&m_args.inputs[i].layout);
            }
        }
        opr::Elemwise::broadcast_collective_collapse(
                inp_layouts, &m_args.outputs[0].layout);
    }

    // compute and update hash
    XXHash hstate;

    // update layout info
    auto prop = m_compiler->property();
    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        mgb_assert(prop.contain_flag(CPFlag::BIND_NDIM),
                   "BIND_NDIM must be set if bind_shape is set");
        // hashed sequence: #inputs, then per input its ndim and, when
        // BIND_SHAPE is set, every extent
        std::vector<size_t> buf;
        buf.reserve(1024);
        buf.push_back(m_args.inputs.size());
        for (auto&& i : m_args.inputs) {
            buf.push_back(i.layout.ndim);
            if (prop.contain_flag(CPFlag::BIND_SHAPE)) {
                for (size_t j = 0; j < i.layout.ndim; ++j) {
                    buf.push_back(i.layout[j]);
                }
            }
        }
        hstate.update(buf.data(), sizeof(buf[0]) * buf.size());
    }
    m_args.hash = hstate.digest();

    // update version number
    static std::atomic_uint_fast64_t global_version;
    m_args.version = global_version.fetch_add(1);

    m_args.need_update = false;
}

//! Walk the internal graph from its output towards the placeholders and, for
//! every placeholder reached below at least one Dimshuffle, record in
//! m_jitph2dimshuffle the single dimshuffle equivalent to the chain on the
//! current traversal path.
void JITExecutor::prepare_dimshuffle() {
    std::unordered_set<cg::OperatorNodeBase*> visited;
    // explicit DFS: stack[k] is an opr on the current path and idx[k] is the
    // index of the next input of stack[k] to visit
    std::vector<cg::OperatorNodeBase*> stack;
    std::vector<size_t> idx;  // input index
    using Param = DimshuffleParam;
    // dimshuffle_stack.back() is the merged param of all Dimshuffle oprs
    // currently on the path
    std::vector<Param> dimshuffle_stack;

    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
        if (dimshuffle_stack.empty()) {
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            param.first.insert(param.first.end(), p.pattern,
                               p.pattern + p.pattern_len);
            param.second = p.ndim;
        } else {
            // merge(p, src) -> param such that
            // dimshuffle(dimshuffle(x, p), src) is equivalent to
            // dimshuffle(x, param)
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
            mgb_assert(p.pattern_len == src.second);
            param.first.resize(src.first.size());
            for (size_t i = 0; i < src.first.size(); ++i) {
                if (src.first[i] == -1) {
                    param.first[i] = -1;
                } else {
                    param.first[i] = p.pattern[src.first[i]];
                }
            }
            param.second = p.ndim;
        }
    };

    auto push_back = [&](cg::OperatorNodeBase* op) {
        mgb_assert(!op->same_type<JITPlaceholder>());
        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
            merge_dimshuffle(o->param());
        }
        stack.push_back(op);
        idx.push_back(0);
    };

    auto pop_back = [&]() {
        auto&& op = stack.back();
        if (op->same_type<opr::Dimshuffle>()) {
            dimshuffle_stack.pop_back();
        }
        stack.pop_back();
        idx.pop_back();
    };

    push_back(m_internal_graph->output()->owner_opr());

    while (!stack.empty()) {
        if (idx.back() < stack.back()->input().size()) {
            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
            if (visited.insert(cur_opr).second) {
                if (auto jitph = cur_opr->try_cast_final<JITPlaceholder>()) {
                    if (!dimshuffle_stack.empty()) {
                        mgb_assert(
                                m_jitph2dimshuffle
                                        .emplace(jitph,
                                                 dimshuffle_stack.back())
                                        .second,
                                "already visited JITPlaceholder %s",
                                jitph->cname());
                    }
                    ++idx.back();
                } else {
                    push_back(cur_opr);
                }
            } else {
                // already handled through another path
                ++idx.back();
            }
        } else {
            // all inputs of the top opr are done; resume its parent
            pop_back();
            if (!stack.empty())
                ++idx.back();
        }
    }
}

//! Return the cached args, refreshing them first when they are stale.
const JITExecutor::Args& JITExecutor::args() const {
    if (m_args.need_update) {
        // refreshing the cache is logically const for callers
        const_cast<JITExecutor*>(this)->update_args();
    }
    return m_args;
}

//! Compare two up-to-date Args: hashes must match; equal versions short-cut
//! to true; otherwise input/output counts and (depending on the compiler's
//! BIND_NDIM / BIND_SHAPE flags) ndim or full shapes are compared.
bool JITExecutor::Args::operator==(const Args& rhs) const {
    auto&& lhs = *this;
    mgb_assert(!lhs.need_update && !rhs.need_update);
    if (lhs.hash != rhs.hash) {
        return false;
    }
    if (lhs.version == rhs.version) {
        return true;
    }
    if (lhs.outputs.size() != rhs.outputs.size())
        return false;
    if (lhs.inputs.size() != rhs.inputs.size())
        return false;

    auto prop = owner->m_compiler->property();

    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        bool (*chk_layout)(const TensorLayout&, const TensorLayout&);
        if (prop.contain_flag(CPFlag::BIND_SHAPE)) {
            chk_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) {
                return lhs.eq_shape(rhs);
            };
        } else {
            chk_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) {
                return lhs.ndim == rhs.ndim;
            };
        }
        for (size_t i = 0; i < lhs.inputs.size(); i++) {
            if (!chk_layout(lhs.inputs[i].layout, rhs.inputs[i].layout))
                return false;
        }
        for (size_t i = 0; i < lhs.outputs.size(); i++) {
            if (!chk_layout(lhs.outputs[i].layout, rhs.outputs[i].layout))
                return false;
        }
    }

    // elect a common version so next check can be fast
    lhs.version = rhs.version = std::min(lhs.version, rhs.version);
    return true;
}

//! Inputs bound to host-value-shape placeholders are depended on by host
//! value; all other inputs by device value.
JITExecutor::NodeProp* JITExecutor::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    using DepType = NodeProp::DepType;
    SmallVector<DepType> dt(input().size());
    auto&& placeholders = internal_graph().placeholders();
    for (size_t i = 0; i < dt.size(); ++i) {
        dt[i] = placeholders[i]->is_host_value_shape_input()
                        ? DepType::HOST_VALUE
                        : DepType::DEV_VALUE;
    }
    ret->reset_dep_type(input(), dt);
    return ret;
}

//! Elemwise-deduced shape over all inputs that are not host-value-shape
//! inputs.
megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
    megdnn::TensorShapeArray inp_shps;
    megdnn::TensorShape brdcast_shp;
    auto placeholders = m_internal_graph->placeholders();
    for (auto ph : placeholders) {
        if (!ph->is_host_value_shape_input()) {
            inp_shps.push_back(input(ph->input_id())->shape());
        }
    }
    megdnn::Elemwise::deduce_shape(inp_shps, brdcast_shp);
    return brdcast_shp;
}

#if MGB_ENABLE_GRAD

namespace {
//! Helper that rewrites the operator graph ending at a destination var by
//! substituting vars; oprs whose inputs changed are shallow-copied onto the
//! substituted inputs.
class InternalGraphRewriter {
    ThinHashMap<VarNode*, VarNode*> m_var_map;  //!< src var -> replacement
    VarNode* m_dest_var;
    VarNodeArray m_new_inp;  //!< scratch buffer for auto_replace_outputs()

    //! replacement of \p var if one is recorded, else \p var itself
    VarNode* get_var(VarNode* var) {
        auto&& iter = m_var_map.find(var);
        if (iter != m_var_map.end()) {
            return iter->second;
        }
        return var;
    }

public:
    InternalGraphRewriter(VarNode* dest_var) : m_dest_var{dest_var} {}

    //! Clear the replacement map, run \p cb over the oprs that the
    //! destination var depends on, then update the destination var to its
    //! replacement (if any).
    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
        m_var_map.clear();
        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
        m_dest_var = get_var(m_dest_var);
    }

    VarNode* dest_var() { return m_dest_var; }

    void replace_var(VarNode* src, VarNode* dst) {
        // Note: do not perform var replacing recursively
        // when we extract used placeholders from internal graph, we don't
        // consider placeholder replacement pair (a to b), (b to c) as a
        // var replacing chain (a to b to c) but as a injective function
        // from (a, b) to (b, c)
        // in other cases, each var node would be passed as \p src or
        // \p dst at most once
        m_var_map[src] = dst;
    }

    //! If any input of \p opr has a replacement, shallow-copy \p opr onto the
    //! replaced inputs and record its output(0) as replaced by the copy's.
    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
        // in JIT internal graph, output size of opr is always 1
        mgb_assert(opr->usable_output().size() == 1);
        m_new_inp.clear();
        bool need_replace = false;
        for (auto&& i : opr->input()) {
            auto inp = get_var(i);
            m_new_inp.push_back(inp);
            need_replace |= (inp != i);
        }
        if (need_replace) {
            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
            replace_var(opr->output(0), new_op->output(0));
        }
    }
};
}  // anonymous namespace

//! Gradient of JITExecutor w.r.t. input \c wrt_idx.
//!
//! The gradient is first derived symbolically inside the internal graph and
//! then either expanded back into the original graph or wrapped into a new
//! JITExecutor, depending on the cases handled below.
MGB_IMPL_OPR_GRAD(JITExecutor) {
    // grad_inputs = forward inputs, then forward output, then output grad;
    // the two extra entries match the input ids of output_ph and og_ph below
    VarNodeArray grad_inputs;
    for (auto input : opr.input())
        grad_inputs.push_back(input);
    mgb_assert(out_grad[0]);
    grad_inputs.push_back(opr.output(0));
    grad_inputs.push_back(out_grad[0]);

    auto fwd_igraph_ptr = opr.internal_graph_ptr();
    auto output_ph = JITPlaceholder::make(
            fwd_igraph_ptr->output(), fwd_igraph_ptr->placeholders().size());
    auto og_ph = JITPlaceholder::make(
            out_grad[0], fwd_igraph_ptr->placeholders().size() + 1);
    auto loss = opr::VirtualLoss::make({fwd_igraph_ptr->output()}, {og_ph});
    auto gx = cg::grad(loss, fwd_igraph_ptr->placeholders()[wrt_idx]->output(0),
                       false, false);
    if (!gx.node()) {
        return nullptr;
    }
    if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }

    // early return if grad expression is single node
    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
            return grad_inputs[i];
        }
    }
    if (gx.node() == og_ph.node()) {
        return out_grad[0];
    }
    if (gx.node() == fwd_igraph_ptr->output()) {
        return opr.output(0);
    }
    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(
                gx.node()->owner_opr())) {
        HostTensorND hval{grad_inputs[0]->comp_node()};
        hval.copy_from(imm->value()).sync();
        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
    }

    // replace output var in internal graph with output placeholder, so
    // we could forward opr.output (computed by the forward JITExecutor) into
    // the placeholder to avoid redundant computation
    InternalGraphRewriter rewriter{gx.node()};
    rewriter.iter([&rewriter, &fwd_igraph_ptr,
                   &output_ph](cg::OperatorNodeBase* opr) {
        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
            rewriter.replace_var(opr->output(0), output_ph.node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    });

    // callback that substitutes each placeholder by the corresponding entry
    // of \p grad_inputs and re-creates immutable tensors on the comp node of
    // grad_inputs[0]
    auto expand_into_origin_graph = [&rewriter](
                                            cg::OperatorNodeBase* opr,
                                            const VarNodeArray& grad_inputs) {
        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
            rewriter.replace_var(opr->output(0),
                                 grad_inputs.at(ph->input_id()));
            return;
        }
        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
            HostTensorND hval{grad_inputs[0]->comp_node()};
            hval.copy_from(imm->value()).sync();
            rewriter.replace_var(
                    opr->output(0),
                    opr::ImmutableTensor::make(*opr->owner_graph(), hval)
                            .node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };

    if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
        // expand the gradient graph into the original graph to handle bcast
        // oprs
        using namespace std::placeholders;
        rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                std::cref(grad_inputs)));
        return rewriter.dest_var();
    } else {
        VarNodeArray new_grad_inputs;
        PlaceholderArray placeholders;
        bool all_inp_const = true;
        // gx may not depend on all JITPlaceholders, so we need to extract the
        // used placeholders and build a new internal graph
        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
                       &placeholders,
                       &all_inp_const](cg::OperatorNodeBase* opr) {
            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
                // new placeholders are numbered densely in visiting order
                auto new_ph = JITPlaceholder::make(new_grad_inputs.back(),
                                                   placeholders.size())
                                      .node()
                                      ->owner_opr();
                placeholders.push_back(
                        new_ph->try_cast_final<JITPlaceholder>());
                mgb_assert(placeholders.back());
                rewriter.replace_var(opr->output(0), new_ph->output(0));
                if (!cg::is_const_var_value(new_grad_inputs.back())) {
                    all_inp_const = false;
                }
                return;
            }
            rewriter.auto_replace_outputs(opr);
        });

        if (all_inp_const) {
            // if all_inp_const, expand grad graph into origin graph by replace
            // placeholders with const inputs, so it could benefit from static
            // infer and const folding mechanism
            using namespace std::placeholders;
            rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                    std::cref(new_grad_inputs)));
            return rewriter.dest_var();
        }

        gx = rewriter.dest_var();

        auto shape_infer = fwd_igraph_ptr->shape_infer();
        if (opr.has_dimshuffle()) {
            auto&& iter = opr.dimshuffle_params().find(
                    fwd_igraph_ptr->placeholders()[wrt_idx]);
            if (iter != opr.dimshuffle_params().end()) {
                auto&& pattern = iter->second.first;
                auto&& ndim = iter->second.second;
                // build the inverse pattern: back[j] is the output axis that
                // input axis j was moved to, or -1 if it was dropped
                std::vector<int> back(ndim, -1);
                for (size_t i = 0; i < pattern.size(); i++) {
                    // outdim[i] is indim[j]
                    auto j = pattern[i];
                    if (j >= 0) {
                        mgb_assert(back[j] == -1,
                                   "taking grad for Dimshuffle with duplicated "
                                   "input axis unsupported");
                        back[j] = i;
                    }
                }
                shape_infer = opr::Dimshuffle::make(shape_infer, back,
                                                    pattern.size())
                                      .node();
            }
        }
        auto grad_ig = std::make_shared<InternalGraph>(
                gx.node(), shape_infer, nullptr, std::move(placeholders));
        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);

        if (opr.input_broadcastable()[wrt_idx]) {
            // a broadcastable input may be smaller than the output, so sum
            // the gradient down to the shape of the input
            grad_jit = opr::reduce_sum(
                    grad_jit, opr::GetVarShape::make(opr.input(wrt_idx)));
        }
        return grad_jit.node();
    }
}
#endif  // MGB_ENABLE_GRAD

#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}