#include "megbrain/opr/dnn/instance_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

// NOTE(review): the explicit template arguments in this file
// (static_cast<const InstanceNorm&>, cast_final_safe<InstanceNorm>,
// SmallVector<LogicalTensorDesc>, SmallVector<TensorPtr>,
// DnnOprHelper<megdnn::GroupNorm>, DnnOprCaller<megdnn::GroupNorm>) were
// missing from the reviewed text and have been reconstructed from how the
// values are used below -- confirm against the upstream sources.

namespace mgb::imperative {
namespace instance_norm {

/*!
 * \brief Build the graph operator for an InstanceNorm op definition.
 *
 * \param def op definition; treated as an InstanceNorm
 * \param inputs either {data} when the op is not affine, or
 *      {data, weight, bias} when it is
 * \return owner operator of the first output var of opr::InstanceNorm::make
 *
 * Asserts that the number of inputs agrees with the affine flag of the param.
 */
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const InstanceNorm&>(def);
    size_t nr_inp = inputs.size();
    auto p = op.param();
    mgb_assert(
            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
            "num of inputs of instancenorm should be 1 or 3 but you give %zu",
            inputs.size());

    OperatorNodeConfig config{op.make_name()};
    if (nr_inp == 3) {
        return opr::InstanceNorm::make(inputs[0], inputs[1], inputs[2], p, config)[0]
                .node()
                ->owner_opr();
    } else {
        return opr::InstanceNorm::make(inputs[0], p, config)[0].node()->owner_opr();
    }
}

/*!
 * \brief Infer the descriptors of the three outputs (output, mean, rstd).
 *
 * \param def op definition; must be an InstanceNorm
 * \param inputs logical descriptors of the inputs (1 or 3, see
 *      apply_on_var_node)
 * \return the three output descriptors and a flag that is false when the
 *      input layout is still unknown (ndim == 0), true otherwise
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& instance_norm = def.cast_final_safe<InstanceNorm>();
    size_t nr_inp = inputs.size();
    auto affine = instance_norm.affine;
    mgb_assert(
            (nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
            "num of inputs of instancenorm should be 1 or 3 but you give %zu",
            inputs.size());

    auto&& inp = inputs[0];
    auto& inp_cn = inp.comp_node;

    // Input shape not known yet: only dtypes can be reported. The output keeps
    // the input dtype, mean and rstd are reported as Float32.
    if (inp.layout.ndim == 0) {
        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
                false};
    }

    // The group count is set to the extent of axis 1 of the input before the
    // layouts are deduced (axis 1 is the channel axis under the NCHW format
    // that apply_on_physical_tensor requires).
    size_t C = inputs[0].layout.shape[1];
    auto p = instance_norm.param();
    p.group = C;

    DnnOprHelper<megdnn::GroupNorm> dnn_opr(p);
    auto&& [oup_layout, mean_layout, rstd_layout] =
            dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{});
    return {{{oup_layout, inp_cn, {}},
             {mean_layout, inp_cn, {}},
             {rstd_layout, inp_cn, {}}},
            true};
}

/*!
 * \brief Execute InstanceNorm eagerly on physical tensors.
 *
 * \param def op definition; must be an InstanceNorm
 * \param inputs either {data} when the op is not affine, or
 *      {data, weight, bias} when it is
 * \param output_descs not read by this implementation
 * \param validated not read by this implementation
 * \return {output, mean, rstd}, all allocated on the comp node of inputs[0]
 *
 * Asserts that the input count agrees with the affine flag and that the
 * param format is NCHW.
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<InstanceNorm>();
    size_t nr_inp = inputs.size();
    auto p = op_def.param();

    mgb_assert(
            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
            "num of inputs of instancenorm should be 1 or 3 but you give %zu",
            inputs.size());

    auto cn = inputs[0]->comp_node();
    using Format = megdnn::param::GroupNorm::Format;
    mgb_assert(p.format == Format::NCHW, "only support inputs in shape NCHW.");

    // One group per channel: the group count is set to the extent of axis 1.
    size_t C = inputs[0]->shape()[1];
    p.group = C;

    DnnOprCaller<megdnn::GroupNorm> caller(cn, p);

    auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>(
            inputs[0]->layout(), TensorLayout{}, TensorLayout{});

    auto out = Tensor::make(oup_layout, cn);
    auto mean = Tensor::make(mean_layout, cn);
    auto rstd = Tensor::make(rstd_layout, cn);

    if (p.affine) {
        caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
    } else {
        // Without affine parameters, default-constructed tensors are passed in
        // the weight and bias slots.
        megdnn::TensorND empty_dnn;
        caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd);
    }
    return {out, mean, rstd};
}

OP_TRAIT_REG(InstanceNorm, InstanceNorm)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace instance_norm
}  // namespace mgb::imperative