#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"

namespace mgb {
namespace imperative {

namespace {
namespace indexing_one_hot {

/*!
 * \brief Infer the output descriptor of IndexingOneHot from (src, index)
 *        without executing the kernel.
 *
 * The output layout is src's layout with the selected axis set to 1 and
 * contiguous strides. The returned flag is false when src's or index's
 * shape is not yet known (ndim == 0): with unknown src only the dtype is
 * reported; with unknown index the output layout is reported but cannot
 * yet be validated against index.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();

    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");

    auto comp_node = input_descs[0].comp_node;
    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;

    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

    // src shape unknown: only the dtype can be propagated.
    if (!src.ndim) {
        return {{{{{}, src.dtype}, comp_node}}, false};
    }

    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    mgb_assert(
            -static_cast<int>(src.ndim) <= op.axis &&
                    op.axis < static_cast<int>(src.ndim),
            "axis %d not exists in src", op.axis);

    // Map a negative axis into [0, src.ndim).
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(src.ndim);
    }

    TensorLayout dst = src;
    dst.shape[real_axis] = 1;
    dst.init_contiguous_stride();

    // index shape unknown: dst is computable, but the index/src shape
    // agreement check below must wait.
    if (!index.ndim) {
        return {{{dst, comp_node}}, false};
    }

    mgb_assert(index.is_contiguous(), "index should be all contiguous");
    // index must have src's shape with the indexed axis removed.
    mgb_assert(
            index.eq_shape(src.remove_axis(real_axis)), "index shape doesn't match src");
    return {{{dst, comp_node}}, true};
}

/*!
 * \brief Build the graph-mode opr::IndexingOneHot node for (src, index).
 *
 * A negative axis is normalized with op.ndim, the rank stored on the op
 * definition (NOTE(review): assumes the frontend always fills op.ndim when
 * axis < 0 — confirm).
 */
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(inputs.size() == 2);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}

/*!
 * \brief Eagerly execute IndexingOneHot on concrete tensors.
 *
 * Normalizes and range-checks the axis against the runtime rank of src,
 * lets the megdnn operator deduce the output layout, allocates the output
 * on src's comp node and runs the kernel (with workspace).
 *
 * \return a single-element vector holding the freshly allocated output.
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& layout = inp->layout();
    auto&& index_layout = index->layout();
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    mgb_assert(
            0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
            "Dimension out of range (expected to be in range of [%d, %d], but got %d)",
            0, static_cast<int>(layout.ndim) - 1, op.axis);
    DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node(), real_axis);
    auto tlayout = dnn_op.deduce_layout(layout, index_layout);
    auto out = Tensor::make(tlayout, inp->comp_node());
    dnn_op.exec_with_ws(inp, index, out);
    return {out};
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace indexing_one_hot

namespace indexing_set_one_hot {

/*!
 * \brief Infer the output descriptor of IndexingSetOneHot from
 *        (src, index, sub).
 *
 * The output has src's shape and dtype. The returned flag is false only
 * while src's shape is unknown (ndim == 0), in which case only the dtype
 * is reported. index and sub shapes are not validated here.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
    auto comp_node = input_descs[0].comp_node;
    auto&& src = input_descs[0].layout;
    auto&& index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
    if (!src.ndim) {
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
    return {{{{src, src.dtype}, comp_node}}, true};
}

/*!
 * \brief Build the graph-mode opr::IndexingSetOneHot node for
 *        (src, index, sub).
 *
 * Uses the checked cast_final_safe downcast, like every other entry point
 * in this file. A negative axis is normalized with op.ndim (NOTE(review):
 * assumes op.ndim is populated whenever axis < 0 — confirm).
 */
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    mgb_assert(inputs.size() == 3);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], real_axis, config);
}

/*!
 * \brief Eagerly execute IndexingSetOneHot on concrete tensors.
 *
 * Allocates a new output with src's layout, copies src into it, then runs
 * the megdnn kernel on that copy (out, index, sub). The kernel therefore
 * writes into `out`, never into `inp`.
 *
 * \return a single-element vector holding the updated copy.
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& sub = inputs[2];
    TensorLayout layout = inp->layout();
    mgb_assert(layout.is_contiguous());
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    // Same bounds check as indexing_one_hot: reject an out-of-range axis
    // before it reaches the kernel.
    mgb_assert(
            0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
            "Dimension out of range (expected to be in range of [%d, %d], but got %d)",
            0, static_cast<int>(layout.ndim) - 1, op.axis);
    DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node(), real_axis);
    TensorPtr out = Tensor::make(layout, inp->comp_node());
    out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
    dnn_op.exec_with_ws(out, index, sub);
    return {out};
}

/*!
 * \brief Require all three inputs (src, index, sub) to be contiguous.
 *
 * The vector is sized by inputs.size() and slots 0..2 are filled, so the
 * arity is asserted first to keep those writes in bounds.
 */
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    mgb_assert(inputs.size() == 3, "IndexingSetOneHot expects three inputs");
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = layout_checker[1] = layout_checker[2] =
            [](const TensorLayout& layout) { return layout.is_contiguous(); };
    return layout_checker;
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .get_input_layout_constraint(get_input_layout_constraint)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace indexing_set_one_hot

}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}