Commit b1ab3646 authored by Megvii Engine Team

feat(imperative): add tensor sanity check

GitOrigin-RevId: 27243978e38416e4bea79ee875534e44ca83217a
Parent 5aa19f3d
from ..core._imperative_rt import TensorSanityCheckImpl
from ..core._imperative_rt.imperative import sync


class TensorSanityCheck:
    r"""An object that checks whether the input tensors of each operator have changed before and after the operation.

    Examples:

    .. testcode::

        from megengine import tensor
        from megengine.utils.tensor_sanity_check import TensorSanityCheck

        with TensorSanityCheck() as checker:
            a = tensor([1, 2])
            b = tensor([3, 4])
            c = a + b
            print(c.numpy())

    .. testoutput::

        [4 6]
    """

    def __init__(self):
        self.impl = TensorSanityCheckImpl()

    def __enter__(self):
        sync()
        self.impl.enable()
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        sync()
        self.impl.disable()
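Conceptually, the check is a before/after comparison of every input tensor around each operator application. The snippet below is a minimal, framework-free Python sketch of that idea; `checked_apply` is a hypothetical name, and the real implementation compares checksums of device memory per op trait rather than keeping full host copies.

import numpy as np


def checked_apply(op, *inputs):
    # Hypothetical stand-in for the real hook: snapshot each input before the
    # op runs and verify it is unchanged afterwards.
    before = [np.copy(x) for x in inputs]
    outputs = op(*inputs)
    for snapshot, x in zip(before, inputs):
        if not np.array_equal(snapshot, x):
            raise RuntimeError(f"tensor modified during exec {op.__name__}")
    return outputs


checked_apply(np.add, np.array([1, 2]), np.array([3, 4]))  # passes silently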
@@ -23,6 +23,7 @@
#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h"
#if MGB_ENABLE_OPR_MM
@@ -225,6 +226,19 @@ void init_utils(py::module m) {
            },
            py::arg("path") = std::optional<std::string>());

    using mgb::imperative::TensorSanityCheck;
    py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
            .def(py::init<>())
            .def("enable",
                 [](TensorSanityCheck& checker) -> TensorSanityCheck& {
                     checker.enable();
                     return checker;
                 })
            .def("disable",
                 [](TensorSanityCheck& checker) { checker.disable(); });
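For reference, the Python wrapper shown earlier drives this binding roughly as follows; this is a sketch assembled from the wrapper's own imports and its `__enter__`/`__exit__` bodies, using the same module paths.

from megengine.core._imperative_rt import TensorSanityCheckImpl
from megengine.core._imperative_rt.imperative import sync

impl = TensorSanityCheckImpl()
sync()           # flush pending work before installing the hooks
impl.enable()
# ... run the operators to be checked ...
sync()
impl.disable()   # restore the original op implementations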
#if MGB_ENABLE_OPR_MM
    m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
          py::arg("port") = 0);
@@ -110,9 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    return out;
}

SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> out;
    for (size_t i = 0; i < 2; ++i) {
        out.push_back({TensorLayout(), inputs[0]->comp_node()});
    }
    return out;
}

OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs(infer_output_attrs)
        .fallback();
} // namespace
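The empty layouts above are deliberate: CondTake produces two outputs (the selected values and their indices) whose shapes depend on the mask contents, so only the comp node can be reported before execution. A NumPy illustration of the same data-dependent-shape behaviour (not the MegEngine API):

import numpy as np

x = np.array([1, 2, 3, 4])
mask = x > 2
values, indices = x[mask], np.nonzero(mask)[0]  # output shapes depend on the data
print(values, indices)                          # [3 4] [2 3]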
/**
 * \file src/core/impl/imperative/tensor_sanity_check.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 *
 */
#include "megbrain/imperative/tensor_sanity_check.h"
#include "./op_trait.h"
namespace mgb {
namespace imperative {
TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
auto&& dt = ptr->dev_tensor();
if (!dt.layout().total_nr_elems()) {
static ChecksumResult empty_checksum;
return empty_checksum;
}
auto span = dt.layout().span();
megdnn::TensorND tensor;
tensor.raw_ptr = dt.raw_ptr() + span.low_byte;
tensor.layout.init_contiguous_stride({span.dist_byte()});
tensor.layout.dtype = dtype::Byte();
DeviceTensorStorage* workspace;
{
MGB_LOCK_GUARD(m_workspace_mtx);
workspace = &m_workspace[std::this_thread::get_id()]
.storage[ptr->comp_node()];
}
auto comp_node = ptr->comp_node();
comp_node.activate();
auto opr = opr::intl::get_megdnn_global_opr<megdnn::Checksum>(comp_node);
auto workspace_reqsize = opr->get_workspace_in_bytes(tensor.layout);
workspace->comp_node(ptr->comp_node()).ensure_size(workspace_reqsize);
megdnn::Workspace mwk;
if (workspace_reqsize)
mwk = {workspace->ptr(), workspace_reqsize};
return opr->exec(tensor, mwk);
}
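`calc` treats the tensor's whole memory span as one contiguous byte tensor and feeds it to the megdnn Checksum operator, reusing a per-thread, per-device workspace. In spirit it is a checksum over raw bytes; a rough host-side analogue, where CRC32 is only an assumed stand-in for the megdnn kernel:

import zlib

import numpy as np


def tensor_checksum(x: np.ndarray) -> int:
    # View the array's underlying memory as a flat byte buffer, independent of
    # dtype or shape, and checksum it -- analogous to dtype::Byte() above.
    return zlib.crc32(np.ascontiguousarray(x).view(np.uint8).tobytes())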
class TensorSanityCheckImpl {
public:
    std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyOnPhysicalTensor>>>
            hook_list;
    std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult>
            tensor2chksum;  // TODO: may increase device memory overhead

    TensorSanityCheckImpl() {
        m_calc = std::make_unique<TensorChecksumCalc>();
    }
    bool check(TensorPtr p);

private:
    std::unique_ptr<TensorChecksumCalc> m_calc;
};

bool TensorSanityCheckImpl::check(TensorPtr p) {
    auto&& it = tensor2chksum.find(p);
    auto&& chksum = m_calc->calc(p);
    if (it == tensor2chksum.end()) {
        tensor2chksum[p] = chksum;
        return true;
    }
    return it->second == chksum;
}
void TensorSanityCheck::enable() {
    CompNode::sync_all();
    OpTrait::for_each_trait([this](OpTrait& trait) {
        auto backup = std::make_unique<ApplyOnPhysicalTensor>(
                std::move(trait.apply_on_physical_tensor));
        trait.apply_on_physical_tensor = [this, backup = backup.get()](
                const OpDef& def, const SmallVector<TensorPtr>& inputs) {
            for (auto&& i : inputs) {
                if (!m_checker->check(i)) {
                    mgb_throw(TensorChecksumCalc::Error,
                              "tensor modified before exec %s",
                              print_op(def).c_str());
                }
            }
            auto output = (*backup)(def, inputs);
            for (auto&& i : output) {
                mgb_assert(m_checker->check(i));
            }
            for (auto&& i : inputs) {
                if (!m_checker->check(i)) {
                    mgb_throw(TensorChecksumCalc::Error,
                              "tensor modified after exec %s",
                              print_op(def).c_str());
                }
            }
            return output;
        };
        m_checker->hook_list.push_back({&trait, std::move(backup)});
    });
}
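`enable()` rewrites every registered op trait: it saves the original `apply_on_physical_tensor`, installs a wrapper that verifies the inputs before the call, registers or verifies the outputs, and verifies the inputs again after the call; `disable()` below puts the saved functions back. A compact Python sketch of that hook-and-restore pattern follows; the trait table and function names are hypothetical, not MegEngine APIs.

def install_hooks(trait_table, check):
    # trait_table: mapping from op name to its apply function (hypothetical).
    hooks = []
    for name, original in list(trait_table.items()):
        def wrapper(*inputs, _orig=original, _name=name):
            for t in inputs:
                if not check(t):
                    raise RuntimeError(f"tensor modified before exec {_name}")
            outputs = _orig(*inputs)
            for t in inputs:
                if not check(t):
                    raise RuntimeError(f"tensor modified after exec {_name}")
            return outputs
        trait_table[name] = wrapper
        hooks.append((name, original))
    return hooks


def remove_hooks(trait_table, hooks):
    # Restore the original apply functions, mirroring disable().
    for name, original in hooks:
        trait_table[name] = original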
void TensorSanityCheck::disable() {
    for (auto&& hook : m_checker->hook_list) {
        std::get<0>(hook)->apply_on_physical_tensor =
                std::move(*std::get<1>(hook));
    }
    m_checker->tensor2chksum.clear();
    m_checker->hook_list.clear();
}
TensorSanityCheck::TensorSanityCheck() {
    m_checker = std::make_unique<TensorSanityCheckImpl>();
}

TensorSanityCheck::~TensorSanityCheck() {
}

std::string TensorSanityCheck::print_op(const OpDef& def) {
    auto* opr_attr = def.try_cast_final<const OprAttr>();
    if (opr_attr) {
        return std::string("OprAttr:") + opr_attr->type;
    }
    return def.dyn_typeinfo()->name;
}

} // namespace imperative
} // namespace mgb
\ No newline at end of file
/**
 * \file src/core/include/megbrain/imperative/tensor_sanity_check.h
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 *
 */

#pragma once
#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/plugin/var_sanity_check.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs/general.h"
namespace mgb {
namespace imperative {
class TensorChecksumCalc {
public:
    using ChecksumResult = megdnn::opr_result::Checksum;
    using Error = VarSanityCheckError;

    struct WorkspaceCache {
        //! var comp node to workspace
        CompNode::UnorderedMap<DeviceTensorStorage> storage;
    };
    ThinHashMap<std::thread::id, WorkspaceCache> m_workspace;
    std::mutex m_workspace_mtx;

    ChecksumResult calc(TensorPtr ptr);
    TensorChecksumCalc() {}
};

class TensorSanityCheckImpl;

class TensorSanityCheck {
public:
    TensorSanityCheck();
    ~TensorSanityCheck();
    void enable();
    void disable();
    std::string print_op(const OpDef& def);

private:
    std::unique_ptr<TensorSanityCheckImpl> m_checker;
};
} // namespace imperative
} // namespace mgb
\ No newline at end of file