/**
 * \file imperative/src/impl/ops/custom_opdef.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied.
 */

#include "megbrain/imperative/ops/custom_opdef.h"
#include "megbrain/opr/custom_opnode.h"
#include "megbrain/custom/data_adaptor.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpDef);

CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op)
        : m_op(op), m_param(op->param_info()) {}

CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp> &op,
                         const custom::Param &param)
        : m_op(op), m_param(param) {}

void CustomOpDef::param(const custom::Param &rhs) {
    m_param = rhs;
}

custom::Param &CustomOpDef::param(void) {
    return m_param;
}

custom::Param CustomOpDef::param(void) const {
    return m_param;
}

size_t CustomOpDef::input_num(void) const {
    return m_op->input_num();
}

size_t CustomOpDef::output_num(void) const {
    return m_op->output_num();
}

std::string CustomOpDef::name(void) const {
    return m_op->op_type();
}

custom::RunTimeId CustomOpDef::runtime_id(void) const {
    return m_op->runtime_id();
}

const std::shared_ptr<const custom::CustomOp> &CustomOpDef::impl(void) const {
    return m_op;
}

// Wrap the builtin DeviceTensorNDs into custom::Tensor views and forward the
// call to the user-defined compute function of the wrapped CustomOp.
void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs,
                          SmallVector<DeviceTensorND> *outputs) const {
    std::vector<custom::Tensor> custom_inputs =
            custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
    std::vector<custom::Tensor> custom_outputs =
            custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs);
    m_op->compute(custom_inputs, this->m_param, custom_outputs);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<TensorPtr> &inputs) const {
    SmallVector<LogicalTensorDesc> input_descs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].comp_node = inputs[i]->comp_node();
        input_descs[i].layout = inputs[i]->layout();
    }
    return this->infer_output_attrs(input_descs);
}

// Infer output device / dtype / format / shape through the inference functions
// registered on the CustomOp. Shape inference is only attempted when every
// input shape is known (ndim != 0); otherwise the returned flag is false and
// the output shapes are left empty.
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<LogicalTensorDesc> &inputs) const {
    SmallVector<CompNode> i_devices(inputs.size());
    SmallVector<TensorShape> i_shapes(inputs.size());
    SmallVector<megdnn::DType> i_dtypes(inputs.size());
    SmallVector<TensorFormat> i_formats(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
        i_devices[i] = inputs[i].comp_node;
        i_shapes[i] = inputs[i].layout;  // TensorLayout derives from TensorShape
        i_dtypes[i] = inputs[i].layout.dtype;
        i_formats[i] = inputs[i].layout.format;
    }

    bool success = true;
    for (auto &&i_shape: i_shapes) {
        if (i_shape.ndim == 0) {
            success = false;
        }
    }

    SmallVector<CompNode> o_devices;
    SmallVector<megdnn::DType> o_dtypes;
    SmallVector<TensorFormat> o_formats;
    SmallVector<TensorShape> o_shapes;

    o_devices = custom::to_builtin<CompNode, custom::Device>(m_op->infer_output_device(
            custom::to_custom<CompNode, custom::Device>(i_devices), this->m_param));
    o_dtypes = custom::to_builtin<megdnn::DType, custom::DType>(m_op->infer_output_dtype(
            custom::to_custom<megdnn::DType, custom::DType>(i_dtypes), this->m_param));
    o_formats = custom::to_builtin<TensorFormat, custom::Format>(m_op->infer_output_format(
            custom::to_custom<TensorFormat, custom::Format>(i_formats), this->m_param));
    if (success) {
        o_shapes = custom::to_builtin<TensorShape, custom::Shape>(m_op->infer_output_shape(
                custom::to_custom<TensorShape, custom::Shape>(i_shapes), this->m_param));
    } else {
        o_shapes = SmallVector<TensorShape>(this->output_num());
    }

    SmallVector<LogicalTensorDesc> outputs(this->output_num());
    for (size_t i = 0; i < this->output_num(); ++i) {
        outputs[i].comp_node = std::move(o_devices[i]);
        outputs[i].layout = TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]);
    }
    return std::tuple<SmallVector<LogicalTensorDesc>, bool>(outputs, success);
}

CustomOpDefFactory *CustomOpDefFactory::inst(void) {
    static CustomOpDefFactory factory;
    return &factory;
}

bool CustomOpDefFactory::is_custom_op(const OpDef &op) {
    return op.dyn_typeinfo() == CustomOpDef::typeinfo();
}

CustomOpDefFactory::CustomOpDefFactory() {
    ops = custom::CustomOpManager::inst();
}

std::vector<std::string> CustomOpDefFactory::op_list(void) const {
    return ops->op_name_list();
}

std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type) const {
    auto op = ops->find(op_type);
    return std::make_shared<CustomOpDef>(op);
}
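// Illustrative sketch (comment only, not compiled): how an imperative
// CustomOpDef is typically obtained through the factory above. The op name
// "my_custom_op" is a hypothetical placeholder for an op that has already
// been registered with custom::CustomOpManager.
//
//     auto *factory = CustomOpDefFactory::inst();
//     std::shared_ptr<OpDef> def = factory->create_opdef("my_custom_op");
//     auto &cdef = def->cast_final_safe<CustomOpDef>();
//     custom::Param param = cdef.param();  // defaults come from the op's param_info()
//     cdef.param(param);                   // write back a (possibly modified) param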
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id) const {
    auto op = ops->find(op_id);
    return std::make_shared<CustomOpDef>(op);
}

std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const std::string &op_type,
                                                        const custom::Param &param) const {
    auto op = ops->find(op_type);
    return std::make_shared<CustomOpDef>(op, param);
}

std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(const custom::RunTimeId &op_id,
                                                        const custom::Param &param) const {
    auto op = ops->find(op_id);
    return std::make_shared<CustomOpDef>(op, param);
}

namespace custom_opdef { // avoid name conflict

void apply_on_device_tensornd(const OpDef &def,
                              const SmallVector<DeviceTensorND> &inputs,
                              SmallVector<DeviceTensorND> *outputs) {
    // activate the comp node of every output before launching the computation
    for (auto &&output: (*outputs)) {
        auto cn = output.comp_node();
        cn.activate();
    }
    CompNode::sync_all();

    auto &&op = static_cast<const CustomOpDef&>(def);
    op.compute(inputs, outputs);
    // Per-comp-node sync does not work here (see the disabled loop below),
    // so a global sync is used instead.
    // for (auto &&output: (*outputs)) {
    //     auto cn = output.comp_node();
    //     cn.sync();
    // }
    CompNode::sync_all();
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef &def, const SmallVector<TensorPtr> &inputs) {
    auto &&op = static_cast<const CustomOpDef&>(def);
    auto [output_descs, success] = op.infer_output_attrs(inputs);
    mgb_assert(success, "failed to infer output attributes");
    SmallVector<TensorPtr> outputs(output_descs.size());

    // allocate output tensors according to the inferred descriptors
    for (size_t i = 0; i < output_descs.size(); ++i) {
        outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
    }

    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    SmallVector<DeviceTensorND> oup_tensornds(outputs.size());
    for (size_t i = 0; i < inputs.size(); ++i)
        inp_tensornds[i] = inputs[i]->dev_tensor();
    for (size_t i = 0; i < outputs.size(); ++i)
        oup_tensornds[i] = outputs[i]->dev_tensor();

    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
    return outputs;
}

VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
    SymbolVarArray input_syms;
    for (auto &input_var: inputs)
        input_syms.emplace_back(input_var);

    auto &&op = static_cast<const CustomOpDef&>(def);
    OperatorNodeConfig config;
    SymbolVarArray output_syms = opr::CustomOpNode::make(
            op.impl(), input_syms, op.param(), config);

    VarNodeArray outputs;
    for (auto &output_sym: output_syms)
        outputs.push_back(output_sym.node());
    return outputs;
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef &def, const SmallVector<LogicalTensorDesc> &inputs) {
    auto &&op = static_cast<const CustomOpDef&>(def);
    return op.infer_output_attrs(inputs);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef &def, const SmallVector<TensorPtr> &inputs_tensors,
        const SmallVector<MemoryDesc> &inputs_mems) {
    // custom ops do not provide memory descriptors; return empty lists
    return {{}, {}};
}

// hash = combine(runtime id of the wrapped op, hash of the serialized param)
size_t hash(const OpDef &def) {
    auto &&op = static_cast<const CustomOpDef&>(def);
    const custom::Param &param = op.param();
    size_t val = mgb::hash(op.runtime_id());
    std::string hash_str = "";
    for (auto &&kv: param.raw()) {
        hash_str += kv.first;
        hash_str += kv.second.str();
    }
    val = mgb::hash_pair_combine(val, mgb::hash(hash_str));
    return val;
}

bool is_same_st(const OpDef &lhs, const OpDef &rhs) {
    auto &&a = static_cast<const CustomOpDef&>(lhs),
         &&b = static_cast<const CustomOpDef&>(rhs);
    return a.param() == b.param() && a.runtime_id() == b.runtime_id();
}

std::vector<std::pair<const char*, std::string>> props(const OpDef &def) {
    mgb_assert(false, "the props function of CustomOpDef is not implemented yet");
    // can be implemented with the param schema, e.g.:
    // auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
    std::vector<std::pair<const char*, std::string>> props_;
    return props_;
}

std::string make_name(const OpDef &def) {
    auto &&op = static_cast<const CustomOpDef&>(def);
    return op.name();
}

} // custom_opdef

OP_TRAIT_REG(CustomOpDef, CustomOpDef)
        .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
        .apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
        .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
        .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
        .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
        .hash(imperative::custom_opdef::hash)
        .is_same_st(imperative::custom_opdef::is_same_st)
        .props(imperative::custom_opdef::props)
        .make_name(imperative::custom_opdef::make_name)
        .fallback();

} // imperative
} // mgb
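// Illustrative note (comment only, not compiled): the hash() and is_same_st()
// hooks registered above define when two CustomOpDef instances are treated as
// the same op: they must wrap the same registered CustomOp (equal runtime_id())
// and carry an equal custom::Param. With a hypothetical op name "my_custom_op"
// and some param, two defs created the same way are expected to hash equally
// and compare equal:
//
//     auto a = CustomOpDefFactory::inst()->create_opdef("my_custom_op", param);
//     auto b = CustomOpDefFactory::inst()->create_opdef("my_custom_op", param);
//     // imperative::custom_opdef::hash(*a) == imperative::custom_opdef::hash(*b)
//     // imperative::custom_opdef::is_same_st(*a, *b) == true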