提交 b7c9361f 编写于 作者: M Megvii Engine Team

perf(mge/functional): add infer_output_attrs_fallible for some ops

GitOrigin-RevId: 33ae4b18e9038469170cde4ecc63427247d40f4b
上级 a4327c4d
......@@ -61,6 +61,44 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& ds = static_cast<const Dimshuffle&>(def);
mgb_assert(
ds.pattern.size() <= TensorShape::MAX_NDIM,
"Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
TensorShape out_shape;
if (src.layout.ndim == 0) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
mgb_assert(
src.layout.ndim == pattern_ndim,
"input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
src.layout.ndim);
size_t idx = 0;
bool input_used[TensorLayout::MAX_NDIM] = {0};
for (auto i : ds.pattern) {
if (i < 0) {
out_shape[idx] = 1;
} else {
input_used[i] = true;
out_shape[idx] = src.layout.shape[i];
}
++idx;
}
for (size_t i = 0; i < pattern_ndim; ++i) {
mgb_assert(
input_used[i] || src.layout.shape[i] == 1,
"non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
src.layout.megdnn::TensorShape::to_string().c_str(), i);
}
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
......@@ -110,6 +148,7 @@ OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace dimshuffle
} // namespace
......@@ -127,6 +166,22 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::AxisAddRemove::make(inputs[0], param, config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<AddAxis>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto olayout = src.layout;
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
for (auto&& i : op_def.axis) {
olayout.add_axis_cont_inplace(i);
}
return {{{olayout, src.comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
......@@ -145,6 +200,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
OP_TRAIT_REG(AddAxis, AddAxis)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace add_axis
} // namespace
......@@ -188,9 +244,37 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<RemoveAxis>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto olayout = src.layout;
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
for (auto&& i : op_def.axis) {
if (olayout.ndim == 1) {
mgb_assert(
olayout.shape[0] == 1 && i == 0,
"can not remove axis %u from tensor of shape=%s", i,
olayout.megdnn::TensorShape::to_string().c_str());
} else {
mgb_assert(
i < olayout.ndim && olayout.shape[i] == 1,
"can not remove axis %u from tensor of shape=%s", i,
olayout.megdnn::TensorShape::to_string().c_str());
olayout.remove_axis_inplace(i);
}
}
return {{{olayout, src.comp_node}}, true};
}
OP_TRAIT_REG(RemoveAxis, RemoveAxis)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace remove_axis
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册