Commit 73d25779 authored by Megvii Engine Team

fix(ops): check index layout for IndexingOneHot

GitOrigin-RevId: 80b42a8de923847560d7d338eca1da824dddff8d
Parent 10198650
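The new infer_output_attrs_fallible below validates the index input before registering the op trait: the index must be a contiguous Int32 tensor whose shape equals src's shape with the one-hot axis removed, and the output keeps src's shape with that axis collapsed to 1. A minimal standalone sketch of that shape relationship (plain C++, not the MegEngine API; the example shapes and axis are arbitrary assumptions):

// Illustrative only: mirrors the shape rule enforced by the new check.
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> src = {4, 10, 3};  // hypothetical src shape (N, C, W)
    int axis = 1;                          // one-hot over axis 1

    // Expected index shape: src with `axis` removed -> (4, 3)
    std::vector<size_t> index(src);
    index.erase(index.begin() + axis);

    // Output shape: src with shape[axis] set to 1 -> (4, 1, 3)
    std::vector<size_t> dst(src);
    dst[axis] = 1;

    auto print = [](const char* name, const std::vector<size_t>& s) {
        std::printf("%s: (", name);
        for (size_t i = 0; i < s.size(); ++i)
            std::printf(i + 1 == s.size() ? "%zu" : "%zu, ", s[i]);
        std::printf(")\n");
    };
    print("src", src);
    print("index", index);  // must also be contiguous Int32 per the new assert
    print("dst", dst);
    return 0;
}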
/**
* \file imperative/src/impl/ops/indexing.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"
#include "megbrain/opr/indexing.h"
namespace mgb {
namespace imperative {
namespace {
namespace indexing_one_hot {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
    auto comp_node = input_descs[0].comp_node;
    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
    if (!src.ndim) {
        // src shape unknown yet: return an unknown-shape descriptor, not validated
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    mgb_assert(
            op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src",
            op.axis);
    // output keeps src's shape except the one-hot axis, which collapses to 1
    TensorLayout dst = src;
    dst.shape[op.axis] = 1;
    dst.init_contiguous_stride();
    if (!index.ndim) {
        // index shape unknown yet: output layout is known but cannot be validated
        return {{{dst, comp_node}}, false};
    }
    mgb_assert(index.is_contiguous(), "index should be all contiguous");
    mgb_assert(
            index.eq_shape(src.remove_axis(op.axis)),
            "index shape doesn't match src");
    return {{{dst, comp_node}}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
} // namespace indexing_one_hot
} // anonymous namespace
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -302,20 +302,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal)
} // namespace group_local
} // namespace
namespace {
namespace indexing_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
} // namespace indexing_one_hot
} // namespace
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......