diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp
index d8f605dcd0adb3f98cb0dd005664f5d543397305..dc9cdf708f1242ec6ef7e680e98888aedaed1a9d 100644
--- a/imperative/src/impl/ops/batch_norm.cpp
+++ b/imperative/src/impl/ops/batch_norm.cpp
@@ -10,6 +10,8 @@
  */
 
 #include "megbrain/opr/dnn/batch_norm.h"
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
 #include "../op_trait.h"
 #include "megbrain/imperative/graph_builder.h"
 #include "megbrain/imperative/ops/autogen.h"
@@ -138,7 +140,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
-    // [running_mean, running_var,] save_mean, save_var
+    // [running_mean, running_var,] save_mean, save_variance
     for (size_t i = 0; i < nr_out - 2; ++i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
@@ -148,10 +150,122 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {out_shapes, out_shapes[nr_out - 1].layout.ndim != 0};
 }
 
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op_def = def.cast_final_safe<BatchNorm>();
+    auto&& comp_node = inputs[0]->comp_node();
+
+    using TensorND = megdnn::TensorND;
+
+    SmallVector<TensorND> inp_tensornds(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        inp_tensornds[i] = inputs[i]->dnn_tensor();
+    }
+
+    DnnOprCaller<megdnn::BN> dnn_opr(comp_node);
+    dnn_opr.op->param() = op_def.param();
+
+    TensorLayout src_layout = inputs[0]->layout();
+    TensorLayout scale_layout = inputs[1]->layout();
+    bool empty_input = src_layout.is_empty();
+    size_t nr_inp = inputs.size();
+
+    DeviceTensorND reserve;
+    size_t sz = 0, rsz = 0;
+
+    TensorLayout r_layout({rsz}, dtype::Byte());
+
+    if (!empty_input) {
+        sz = dnn_opr.op->get_workspace_in_bytes(
+                src_layout, src_layout, src_layout, src_layout, src_layout,
+                src_layout, src_layout, src_layout, src_layout);
+        rsz = dnn_opr.op->get_reserve_in_bytes(src_layout);
+
+        r_layout = TensorLayout({rsz}, dtype::Byte());
+    }
+    auto wk = Blob::make(comp_node, sz);
+    auto ptr = wk->storage().get();
+    megdnn::Workspace dnn_wk(ptr, sz);
+    reserve = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, r_layout);
+
+    // alloc memory for the outputs
+    DeviceTensorND y =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, src_layout);
+
+    DeviceTensorND save_mean =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
+    DeviceTensorND save_variance =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
+
+    if (op_def.fwd_mode == ::megdnn::param::BN::FwdMode::INFERENCE) {
+        if (!empty_input)
+            dnn_opr.op->exec(
+                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                    inp_tensornds[3], inp_tensornds[4], save_mean.as_megdnn(),
+                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                    dnn_wk);
+        // inference leaves the running statistics untouched
+        return {inputs[3], inputs[4], Tensor::make(reserve), Tensor::make(y)};
+    } else {
+        DeviceTensorND mean, variance;
+        if (nr_inp == 5) {
+            mean = BlobManager::inst()->alloc_workspace_with_defrag(
+                    comp_node, scale_layout);
+            variance = BlobManager::inst()->alloc_workspace_with_defrag(
+                    comp_node, scale_layout);
+
+            // copy running_mean/running_var: the kernel updates them in place
+            megdnn::RefPtr src_ptr1(
+                    inp_tensornds[3].get_ref_ptr().get_ptr(), inputs[3]->offset());
+            megdnn::RefPtr dst_ptr1(
+                    mean.storage().get_ref_ptr(), mean.storage().offset(), false);
+            comp_node.peer_copy_to_ref(
+                    comp_node, dst_ptr1, src_ptr1, scale_layout.span().high_byte);
+
+            megdnn::RefPtr src_ptr2(
+                    inp_tensornds[4].get_ref_ptr().get_ptr(), inputs[4]->offset());
+            megdnn::RefPtr dst_ptr2(
+                    variance.storage().get_ref_ptr(), variance.storage().offset(),
+                    false);
+            comp_node.peer_copy_to_ref(
+                    comp_node, dst_ptr2, src_ptr2, scale_layout.span().high_byte);
+
+            if (!empty_input)
+                dnn_opr.op->exec(
+                        inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                        mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
+                        save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                        dnn_wk);
+
+            return {Tensor::make(mean), Tensor::make(variance),
+                    Tensor::make(save_mean), Tensor::make(save_variance),
+                    Tensor::make(reserve), Tensor::make(y)};
+        }
+
+        // no running statistics: pass empty mean/variance placeholders
+        TensorLayout m_layout({0}, scale_layout.dtype);
+        mean = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);
+        variance =
+                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);
+
+        if (!empty_input) {
+            dnn_opr.op->exec(
+                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
+                    mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
+                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
+                    dnn_wk);
+        }
+
+        return {Tensor::make(save_mean), Tensor::make(save_variance),
+                Tensor::make(reserve), Tensor::make(y)};
+    }
+}
+
 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
         .fallback();
 } // namespace bn
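
For reviewers who want to exercise the new eager-mode kernel end to end, a minimal smoke-test sketch (not part of the patch) could look like the following. It assumes a MegEngine build that includes this change; `megengine.functional.batch_norm` is the public entry point whose imperative execution reaches the `apply_on_physical_tensor` added above.

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    C = 3
    inp = Tensor(np.random.randn(2, C, 4, 4).astype("float32"))
    weight = Tensor(np.ones((1, C, 1, 1), dtype="float32"))
    bias = Tensor(np.zeros((1, C, 1, 1), dtype="float32"))
    running_mean = Tensor(np.zeros((1, C, 1, 1), dtype="float32"))
    running_var = Tensor(np.ones((1, C, 1, 1), dtype="float32"))

    # training=True with running statistics exercises the nr_inp == 5
    # TRAINING branch (the statistics are copied, updated and returned);
    # training=False takes the FwdMode::INFERENCE branch instead, which
    # returns the running statistics unchanged.
    out = F.batch_norm(
        inp, running_mean, running_var, weight, bias,
        training=True, momentum=0.9, eps=1e-5, inplace=False)
    print(out.shape)  # (2, 3, 4, 4)

    # omitting the running statistics under training=True hits the
    # three-input branch that passes empty mean/variance placeholders
    out2 = F.batch_norm(inp, weight=weight, bias=bias, training=True)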