提交 9de1ea6a 编写于 作者: M Megvii Engine Team

perf(imperative): add apply_on_physical_tensor for Elemwise

GitOrigin-RevId: 27087d90e431d0fbdb0439827f8bf6088781f6a5
上级 469d0808
......@@ -65,10 +65,30 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
// Evaluates an Elemwise op directly on the device tensors backing `inputs`.
//
// def:    op definition, cast to Elemwise to read the elemwise mode.
// inputs: operand tensors; their count must equal the arity of that mode
//         (enforced by mgb_assert below).
// Returns a one-element vector holding the tensor built from the result.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto&& elemwise_def = def.cast_final_safe<Elemwise>();
    const size_t nr_inputs = inputs.size();

    // The number of operands has to agree with the arity of the mode.
    auto mode_trait = megdnn::Elemwise::ModeTrait::from_mode(elemwise_def.mode);
    mgb_assert(nr_inputs == mode_trait.arity,
               "%s expects %u inputs; got %zu actually", mode_trait.name,
               mode_trait.arity, nr_inputs);

    // Collect the device tensor behind every input, in operand order.
    SmallVector<DeviceTensorND> operands(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        operands[idx] = inputs[idx]->dev_tensor();
    }

    // The megdnn operator is created on the comp node of the first input;
    // the arity check above runs before inputs[0] is touched.
    auto&& elemwise_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(
            inputs[0]->comp_node());

    // perform() writes into `result`, which starts out default-constructed.
    DeviceTensorND result;
    opr::Elemwise::perform(elemwise_def.mode, result, operands, elemwise_opr);
    return {Tensor::make(result)};
}
// Attaches the Elemwise callbacks to the op-trait registration chain,
// including apply_on_physical_tensor so the op can be evaluated directly on
// tensors.
// NOTE(review): OP_TRAIT_REG and .fallback() are defined outside this view;
// presumably fallback() fills in defaults for traits not set here, which
// would require it to stay last in the chain — confirm against the macro.
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // anonymous namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册