#include "megbrain/opr/custom_opnode.h"

// Graph operator node that wraps a user-defined custom op (m_op) together with
// its parameters (m_param). Almost every query (arity, names, dtype/format/
// shape/device inference, execution) is forwarded to m_op. The implementation
// is compiled only when MGB_CUSTOM_OP is enabled.
//
// NOTE(review): in this copy of the file, template argument lists appear to be
// missing in several places (e.g. `SmallVector input_comp_nodes`,
// `std::shared_ptr> inputs`, `std::numeric_limits::min()`,
// `std::is_empty::value`), so it will not compile as-is. Restore them from
// version control before building.
#if MGB_CUSTOM_OP

namespace mgb {
namespace opr {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);

// Infers the output comp nodes and caches the result.
// Gathers the comp node of every input, asks the custom op for the output
// devices via m_op->infer_output_device(…, m_param), and requires every output
// to live on the same comp node as output 0. It then assigns those comp nodes
// to the outputs and stores output 0's comp node in m_comp_node, which
// do_execute and on_output_comp_node_stream_changed use later.
// custom::to_custom / custom::to_builtin presumably convert between MegBrain
// builtin types and the custom-op API types.
// NOTE(review): output_comp_nodes[0] is read unconditionally, and the loop
// indexes output_comp_nodes[i] for i < output_num(). This assumes
// infer_output_device returns at least max(1, output_num()) entries; neither
// condition is checked here.
void CustomOpNode::infer_output_comp_node(void) {
    SmallVector input_comp_nodes(input_num());
    for (size_t i = 0; i < input_num(); ++i) {
        input_comp_nodes[i] = input(i)->comp_node();
    }

    SmallVector output_comp_nodes = custom::to_builtin(m_op->infer_output_device(
            custom::to_custom(input_comp_nodes), m_param));

    for (size_t i = 0; i < output_num(); ++i) {
        // Multi-device custom ops are rejected: every output must match output 0.
        mgb_assert(
                output_comp_nodes[i] == output_comp_nodes[0],
                "only single comp node operator is supported");
        output(i)->comp_node(output_comp_nodes[i]);
    }

    m_comp_node = output_comp_nodes[0];
}

// Infers output dtypes from the input dtypes via m_op->infer_output_dtype and
// writes them onto the output vars. Indexes the returned array with
// i < output_num() without a size check.
void CustomOpNode::infer_output_dtype(void) {
    SmallVector input_dtypes(input_num());
    for (size_t i = 0; i < input_num(); ++i) {
        input_dtypes[i] = input(i)->dtype();
    }

    SmallVector output_dtypes = custom::to_builtin(m_op->infer_output_dtype(
            custom::to_custom(input_dtypes), m_param));
    for (size_t i = 0; i < output_num(); ++i) {
        output(i)->dtype(output_dtypes[i]);
    }
}

// Infers output tensor formats from the input formats via
// m_op->infer_output_format and writes them onto the output vars.
void CustomOpNode::infer_output_format(void) {
    SmallVector input_formats(input_num());
    for (size_t i = 0; i < input_num(); ++i) {
        input_formats[i] = input(i)->format();
    }

    SmallVector output_formats = custom::to_builtin(m_op->infer_output_format(
            custom::to_custom(input_formats), m_param));
    for (size_t i = 0; i < output_num(); ++i) {
        output(i)->format(output_formats[i]);
    }
}

// Infers output shapes from the inputs' current shapes via
// m_op->infer_output_shape and assigns them directly to the output vars.
// Static shape inference in the graph goes through infer_desc instead; see
// init_output_static_infer_desc.
void CustomOpNode::infer_output_shape(void) {
    SmallVector input_shapes(input_num());
    for (size_t i = 0; i < input_num(); ++i) {
        input_shapes[i] = input(i)->shape();
    }

    SmallVector output_shapes = custom::to_builtin(m_op->infer_output_shape(
            custom::to_custom(input_shapes), m_param));
    for (size_t i = 0; i < output_num(); ++i) {
        output(i)->shape(output_shapes[i]);
    }
}

// Pure shape-inference helper: computes the shapes of all outputs for the
// given input shapes. output_shapes is overwritten (assigned) with the result
// of m_op->infer_output_shape.
void CustomOpNode::infer_output_shape(
        const TensorShapeArray& input_shapes, TensorShapeArray& output_shapes) {
    output_shapes = custom::to_builtin(m_op->infer_output_shape(
            custom::to_custom(input_shapes), m_param));
}

// Static shape-inference callback. init_output_static_infer_desc registers it
// once per output var, binding that output's index as out_idx.
// Each call recomputes the shapes of all outputs from the input shapes in
// input_vals and returns the one at out_idx (read with .at()). It always
// reports success (returns true).
bool CustomOpNode::infer_desc(
        size_t out_idx, TensorShape& output_shape, const StaticInferInpVal& input_vals) {
    TensorShapeArray input_shapes(input_vals.val.size());
    TensorShapeArray output_shapes(output_num());

    for (size_t i = 0; i < input_shapes.size(); ++i) {
        input_shapes[i] = input_vals.val[i].shape();
    }

    infer_output_shape(input_shapes, output_shapes);
    output_shape = output_shapes.at(out_idx);
    return true;
}

// Output dtype initialization delegates to the custom op's dtype inference.
void CustomOpNode::init_output_dtype() {
    infer_output_dtype();
}

// Output format initialization delegates to the custom op's format inference.
void CustomOpNode::init_output_format() {
    infer_output_format();
}

// Output comp node initialization delegates to the custom op's device
// inference, which also sets m_comp_node.
void CustomOpNode::init_output_comp_node() {
    infer_output_comp_node();
}

// Executes the custom kernel on m_comp_node.
// The runner lambda collects the device tensors of all inputs and outputs into
// shared vectors. It signals a graph event before the launch, activates
// m_comp_node, calls custom::dispatch_custom_op with the op, its params and
// the tensors, and signals a graph event after the launch. The runner captures
// `this` and is handed to env.dispatch_on_comp_node for m_comp_node.
void CustomOpNode::do_execute(ExecEnv& env) {
    auto runner = [this]() {
        std::shared_ptr> inputs = std::make_shared>();
        std::shared_ptr> outputs = std::make_shared>();
        for (size_t i = 0; i < input_num(); i++) {
            inputs->emplace_back(input(i)->dev_tensor());
        }
        for (size_t i = 0; i < output_num(); i++) {
            outputs->emplace_back(output(i)->dev_tensor());
        }

        this->owner_graph()->event().signal_inplace(
                this, m_comp_node);
        m_comp_node.activate();
        custom::dispatch_custom_op(m_op, m_param, inputs, outputs);
        this->owner_graph()->event().signal_inplace(
                this, m_comp_node);
    };
    env.dispatch_on_comp_node(m_comp_node, runner);
}

// Registers a static shape-inference descriptor for every output var.
// Each output depends on the SHAPE of every input. Because of the `if (true)`
// below, the VALUE-dependency branch is currently unreachable. With no inputs
// the dependency list is empty and the source type is CONSTANT; otherwise it
// is DEP. The inference function is infer_desc bound to the output index.
// m_out_shape is resized here but is not read anywhere else in this file.
void CustomOpNode::init_output_static_infer_desc() {
    using namespace std::placeholders;
    using namespace cg::static_infer;

    m_out_shape.resize(output_num());
    auto&& mgr = owner_graph()->static_infer_manager();

    DepVal dep;
    // [TODO] design an interface that lets the user choose between SHAPE and
    // VALUE dependencies on the inputs
    if (true) {
        for (auto input_var : input())
            dep.push_back({input_var, DepType::SHAPE});
    } else {
        for (auto input_var : input())
            dep.push_back({input_var, DepType::VALUE});
    }

    for (size_t i = 0; i < output_num(); ++i) {
        mgr.register_shape_infer(
                output(i),
                {dep.empty() ? SourceType::CONSTANT : SourceType::DEP, dep,
                 std::bind(&CustomOpNode::infer_desc, this, i, _1, _2)});
    }
}

// Initializes the memory plan for the outputs whose storage kind matches the
// request: static-storage outputs when `dynamic` is false, the others when it
// is true. Outputs flagged NO_SYS_MEM_ALLOC are skipped.
void CustomOpNode::init_output_mem_plan(bool dynamic) {
    for (auto output_var : output()) {
        if (cg::is_static_var_storage(output_var) == !dynamic &&
            !output_var->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
            output_var->init_mem_plan();
    }
}

// Intentionally empty: no dynamic-memory-allocation implication chain is
// registered for this operator.
void CustomOpNode::init_rt_force_dynamic_mem_alloc_imply_chain() {}

// Requires every input var to have a contiguous layout.
void CustomOpNode::add_input_layout_constraint() {
    for (auto&& input_var : input()) {
        input_var->add_layout_constraint_contiguous();
    }
}

// Intentionally empty: no readonly input-to-output memory forwarding is set up.
void CustomOpNode::mem_plan_fwd_in2out_readonly() {}

// Intentionally empty: no writable input-to-output memory forwarding is set up.
void CustomOpNode::mem_plan_fwd_in2out_writable() {}

// No per-operator event callbacks: returns a value-initialized callback object.
cg::OperatorNodeBase::OprEventCallback CustomOpNode::get_opr_event_callback() {
    return {};
}

// Pins every output back onto m_comp_node. An output whose comp node differs
// from m_comp_node is tolerated (asserted) only if it carries VOLATILE_CONTENT,
// and is then reset to m_comp_node.
void CustomOpNode::on_output_comp_node_stream_changed() {
    for (auto output_var : output()) {
        if (output_var->comp_node() != m_comp_node) {
            mgb_assert(output_var->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            output_var->comp_node(m_comp_node);
        }
    }
}

// Uses the base-class node properties unchanged.
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
    return OperatorNodeBase::do_make_node_prop();
}

// For a single-output operator whose output is flagged
// PERSISTENT_DEVICE_VALUE, sets the node priority to the minimum representable
// value and returns true. Otherwise the priority is left alone and false is
// returned.
bool CustomOpNode::update_priority() const {
    if (output_num() == 1 &&
        output()[0]->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
        node_prop().attribute().priority = std::numeric_limits::min();
        return true;
    }
    return false;
}

// Builds the node in the owner graph of inputs[0], named after op->op_type().
// It adds one input per element of `inputs` (asserting the count matches
// m_op->input_num()) and one output per m_op output, named from output_info.
// It then registers equivalence components derived from the op's runtime id
// and the parameters' key/value strings.
// NOTE(review): inputs[0] is read in the initializer list before the size
// assertion runs, so an empty `inputs` is out of bounds.
CustomOpNode::CustomOpNode(
        const std::shared_ptr& op, VarNodeArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config)
        : OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs),
          m_op(op),
          m_param(param) {
    mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
    for (size_t i = 0; i < input_num(); ++i)
        add_input({inputs[i]});

    for (size_t i = 0; i < output_num(); ++i)
        add_output(output_info(i).name());

    if (!std::is_empty::value) {
        // Serialize (runtime id, param key, param value, ...) into one string.
        // Pad it with spaces to a multiple of sizeof(unsigned long). Register
        // each unsigned-long-sized chunk as a separate equivalence component.
        using step = unsigned long;
        size_t STEP_SIZE = sizeof(step);

        std::string hash_str = std::to_string(op->runtime_id());
        for (auto&& val : param.raw()) {
            hash_str += val.first;
            hash_str += val.second.str();
        }
        if (hash_str.size() % STEP_SIZE != 0)
            hash_str += std::string(STEP_SIZE - (hash_str.size() % STEP_SIZE), ' ');

        // NOTE(review): each component receives a pointer into `hash_str`, a
        // local that is destroyed when this constructor returns. If the
        // component type keeps the pointer rather than copying the pointee
        // (not visible here), later hash/equality checks would read freed
        // memory — verify. Reading `char` storage through an `unsigned long*`
        // is also a strict-aliasing concern.
        for (size_t pos = 0; pos < hash_str.size(); pos += STEP_SIZE)
            add_equivalence_component>(
                    reinterpret_cast(hash_str.c_str() + pos));
    }
}

// Creates a CustomOpNode for `op`/`param` and inserts it into the owner graph
// of inputs[0] (so `inputs` must be non-empty). Returns the output vars of the
// operator returned by insert_opr.
VarNodeArray CustomOpNode::make(
        const std::shared_ptr& op, VarNodeArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config) {
    auto&& outputs = inputs[0]
                             ->owner_graph()
                             ->insert_opr(std::make_unique(
                                     op, inputs, param, config))
                             ->output();
    return outputs;
}

// SymbolVar overload of make(): unwraps the inputs to VarNodes, inserts the
// operator into the owner graph of inputs[0] (so `inputs` must be non-empty),
// and wraps each output var back into a SymbolVar.
SymbolVarArray CustomOpNode::make(
        const std::shared_ptr& op, SymbolVarArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config) {
    VarNodeArray input_vars(inputs.size());
    for (size_t i = 0; i < input_vars.size(); ++i)
        input_vars[i] = inputs[i].node();

    auto&& outputs = inputs[0]
                             .node()
                             ->owner_graph()
                             ->insert_opr(std::make_unique(
                                     op, input_vars, param, config))
                             ->output();

    SymbolVarArray ret(outputs.size());
    for (size_t i = 0; i < ret.size(); ++i)
        ret[i] = outputs[i];
    return ret;
}

// Runtime id of the wrapped custom op.
custom::RunTimeId CustomOpNode::runtime_id() const {
    return m_op->runtime_id();
}

// Tag of the wrapped op's parameter info (m_op->param_info().tag()).
uint32_t CustomOpNode::param_tag(void) const {
    return m_op->param_info().tag();
}

// Mutable reference to the parameters stored in this node.
custom::Param& CustomOpNode::param(void) {
    return m_param;
}

// Returns a copy of the parameters stored in this node.
custom::Param CustomOpNode::param(void) const {
    return m_param;
}

// The accessors below share their names with the corresponding CustomOpImpl
// functions and simply forward to m_op.
std::string CustomOpNode::op_type(void) const {
    return m_op->op_type();
}

std::string CustomOpNode::op_desc(void) const {
    return m_op->op_desc();
}

size_t CustomOpNode::input_num(void) const {
    return m_op->input_num();
}

size_t CustomOpNode::output_num(void) const {
    return m_op->output_num();
}

custom::ArgInfo CustomOpNode::input_info(size_t idx) const {
    return m_op->input_info(idx);
}

custom::ArgInfo CustomOpNode::output_info(size_t idx) const {
    return m_op->output_info(idx);
}

}  // namespace opr
}  // namespace mgb

#endif