Commit da620ca1 authored by Megvii Engine Team

perf(imperative): specialize batchnorm implementation

GitOrigin-RevId: 83a82590441b9ea4078e5df3117f788652e96745
Parent 5ebc9d50
@@ -10,6 +10,8 @@
*/
#include "megbrain/opr/dnn/batch_norm.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
@@ -138,7 +140,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
    auto&& i0 = inputs[0];
    auto&& i1 = inputs[1];
-    // [running_mean, running_var,] save_mean, save_var
+    // [running_mean, running_var,] save_mean, save_variance
    for (size_t i = 0; i < nr_out - 2; ++i) {
        out_shapes[i] = {i1.layout, i1.comp_node};
    }
@@ -148,10 +150,122 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    return {out_shapes, out_shapes[nr_out - 1].layout.ndim != 0};
}
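
// Specialized physical-tensor implementation for BatchNorm: the megdnn BN
// kernel is invoked directly instead of going through a lowered graph, with
// separate paths for inference, training with running statistics (5 inputs),
// and training without them.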
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<BatchNorm>();
    auto&& comp_node = inputs[0]->comp_node();
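
    // Gather raw megdnn tensor views of every input; these are handed
    // straight to the megdnn BN kernel below.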
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dnn_tensor();
    }

    DnnOprCaller<megdnn::BN> dnn_opr(comp_node);
    dnn_opr.op->param() = op_def.param();

    TensorLayout src_layout = inputs[0]->layout();
    TensorLayout scale_layout = inputs[1]->layout();
    bool empty_input = src_layout.is_empty();
    size_t nr_inp = inputs.size();

    DeviceTensorND ws, reserve;
    size_t sz = 0, rsz = 0;
    TensorLayout w_layout({sz}, dtype::Byte());
    TensorLayout r_layout({rsz}, dtype::Byte());
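
    // Workspace and reserve sizes are queried only for non-empty inputs;
    // an empty input skips kernel execution entirely.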
    if (!empty_input) {
        sz = dnn_opr.op->get_workspace_in_bytes(
                src_layout, src_layout, src_layout, src_layout, src_layout,
                src_layout, src_layout, src_layout, src_layout);
        rsz = dnn_opr.op->get_reserve_in_bytes(src_layout);
        w_layout = TensorLayout({sz}, dtype::Byte());
        r_layout = TensorLayout({rsz}, dtype::Byte());
    }
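
    // Allocate the workspace from a raw blob and the reserve buffer via the
    // blob manager (alloc_workspace_with_defrag may defragment device memory
    // to satisfy the request).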
    auto wk = Blob::make(comp_node, sz);
    auto ptr = wk->storage().get();
    megdnn::Workspace dnn_wk(ptr, sz);
    reserve = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, r_layout);

    // Allocate outputs: y matches the input layout; save_mean and
    // save_variance match the per-channel scale layout.
    DeviceTensorND y =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, src_layout);
    DeviceTensorND save_mean =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
    DeviceTensorND save_variance =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, scale_layout);
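
    // Inference: the running statistics (inputs[3], inputs[4]) are consumed
    // as-is and passed through unchanged as the first two outputs.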
    if (op_def.fwd_mode == ::megdnn::param::BN::FwdMode::INFERENCE) {
        if (!empty_input)
            dnn_opr.op->exec(
                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
                    inp_tensornds[3], inp_tensornds[4], save_mean.as_megdnn(),
                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
                    dnn_wk);
        return {inputs[3], inputs[4], Tensor::make(reserve), Tensor::make(y)};
    } else {
        DeviceTensorND mean, variance;
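
        // Training with running statistics: copy running_mean/running_var
        // into freshly allocated buffers, since the kernel updates these
        // buffers in place; the updated copies are returned as outputs.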
        if (nr_inp == 5) {
            mean = BlobManager::inst()->alloc_workspace_with_defrag(
                    comp_node, scale_layout);
            variance = BlobManager::inst()->alloc_workspace_with_defrag(
                    comp_node, scale_layout);

            megdnn::RefPtr src_ptr1(
                    inp_tensornds[3].get_ref_ptr().get_ptr(), inputs[3]->offset());
            megdnn::RefPtr dst_ptr1(
                    mean.storage().get_ref_ptr(), mean.storage().offset(), false);
            comp_node.peer_copy_to_ref(
                    comp_node, dst_ptr1, src_ptr1, scale_layout.span().high_byte);

            megdnn::RefPtr src_ptr2(
                    inp_tensornds[4].get_ref_ptr().get_ptr(), inputs[4]->offset());
            megdnn::RefPtr dst_ptr2(
                    variance.storage().get_ref_ptr(), variance.storage().offset(),
                    false);
            comp_node.peer_copy_to_ref(
                    comp_node, dst_ptr2, src_ptr2, scale_layout.span().high_byte);

            if (!empty_input)
                dnn_opr.op->exec(
                        inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
                        mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
                        save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
                        dnn_wk);

            return {Tensor::make(mean), Tensor::make(variance),
                    Tensor::make(save_mean), Tensor::make(save_variance),
                    Tensor::make(reserve), Tensor::make(y)};
        }
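
        // Training without running statistics: pass zero-sized mean/variance
        // placeholders and return only the batch statistics.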
        TensorLayout m_layout({0}, scale_layout.dtype);
        mean = BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);
        variance =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, m_layout);

        if (!empty_input) {
            dnn_opr.op->exec(
                    inp_tensornds[0], inp_tensornds[1], inp_tensornds[2],
                    mean.as_megdnn(), variance.as_megdnn(), save_mean.as_megdnn(),
                    save_variance.as_megdnn(), reserve.as_megdnn(), y.as_megdnn(),
                    dnn_wk);
        }
        return {Tensor::make(save_mean), Tensor::make(save_variance),
                Tensor::make(reserve), Tensor::make(y)};
    }
}
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
} // namespace bn