/**
 * \file src/core/impl/imperative/tensor_sanity_check.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/tensor_sanity_check.h"

#include "./op_trait.h"

namespace mgb {

namespace imperative {

/*!
 * \brief compute a checksum over the bytes spanned by a tensor
 *
 * The byte range covered by the tensor's layout span is presented to the
 * megdnn operator as a one-dimensional contiguous tensor of dtype Byte. The
 * workspace storage is looked up by (calling thread id, comp node) under
 * m_workspace_mtx and grown on demand with ensure_size().
 *
 * \param ptr tensor whose device storage is checksummed
 * \return result of the megdnn operator; for a tensor whose layout has zero
 *      elements, a function-local static ChecksumResult is returned and the
 *      operator is not invoked
 */
TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
    auto&& dt = ptr->dev_tensor();
    if (!dt.layout().total_nr_elems()) {
        static ChecksumResult empty_checksum;
        return empty_checksum;
    }

    // view the whole span of the layout as a flat byte tensor
    auto span = dt.layout().span();
    megdnn::TensorND tensor;
    tensor.raw_ptr = dt.raw_ptr() + span.low_byte;
    tensor.layout.init_contiguous_stride({span.dist_byte()});
    tensor.layout.dtype = dtype::Byte();

    auto comp_node = ptr->comp_node();

    // only the map lookup/insertion is guarded by the mutex; the storage
    // itself is then used outside of the lock
    DeviceTensorStorage* workspace;
    {
        MGB_LOCK_GUARD(m_workspace_mtx);
        workspace = &m_workspace[std::this_thread::get_id()].storage[comp_node];
    }

    comp_node.activate();
    auto opr = opr::intl::get_megdnn_global_opr<megdnn::Checksum>(comp_node);
    auto workspace_reqsize = opr->get_workspace_in_bytes(tensor.layout);
    workspace->comp_node(comp_node).ensure_size(workspace_reqsize);

    // leave the megdnn workspace default-constructed when nothing is required
    megdnn::Workspace mwk;
    if (workspace_reqsize)
        mwk = {workspace->ptr(), workspace_reqsize};

    return opr->exec(tensor, mwk);
}

/*!
 * \brief state behind TensorSanityCheck: the installed hooks and the checksum
 *      recorded for every tensor seen so far
 */
class TensorSanityCheckImpl {
public:
    //! type of the per-op callable that TensorSanityCheck::enable() wraps
    using ApplyHook = decltype(OpTrait::apply_on_physical_tensor);

    //! (trait, callable that was installed before enable()) for every hooked
    //! trait; consumed by TensorSanityCheck::disable() to restore the traits
    std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyHook>>> hook_list;

    //! checksum recorded the first time each tensor was passed to check();
    //! the TensorPtr keys are held until TensorSanityCheck::disable() clears
    //! the map
    // TODO: may increase device memory overhead
    std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult>
            tensor2chksum;

    TensorSanityCheckImpl() { m_calc = std::make_unique<TensorChecksumCalc>(); }

    bool check(TensorPtr p);

private:
    std::unique_ptr<TensorChecksumCalc> m_calc;
};

/*!
 * \brief compare the current checksum of a tensor against the recorded one
 *
 * \param p tensor to verify
 * \return true if \p p has not been seen before (its checksum is recorded) or
 *      if its checksum equals the recorded one; false otherwise. The recorded
 *      checksum is never updated once stored.
 */
bool TensorSanityCheckImpl::check(TensorPtr p) {
    auto it = tensor2chksum.find(p);
    auto chksum = m_calc->calc(p);
    if (it == tensor2chksum.end()) {
        tensor2chksum.emplace(p, chksum);
        return true;
    }
    return it->second == chksum;
}

/*!
 * \brief wrap apply_on_physical_tensor of every op trait with checksum checks
 *
 * Each wrapper verifies the inputs, calls the original callable, passes every
 * output through the checker, and verifies the inputs again. A mismatch on an
 * input raises TensorChecksumCalc::Error.
 *
 * The wrappers capture `this` and a raw pointer to the saved original
 * callable, which is owned by hook_list; both must stay alive until
 * disable() is called.
 */
void TensorSanityCheck::enable() {
    // Already enabled: wrapping again would stack a second wrapper on top of
    // the first one, and disable() (which restores in insertion order) would
    // then leave the first wrapper installed with a dangling backup pointer.
    if (!m_checker->hook_list.empty()) {
        return;
    }

    CompNode::sync_all();
    OpTrait::for_each_trait([this](OpTrait& trait) {
        auto backup = std::make_unique<TensorSanityCheckImpl::ApplyHook>(
                std::move(trait.apply_on_physical_tensor));
        trait.apply_on_physical_tensor =
                [this, backup = backup.get()](
                        const OpDef& def,
                        const SmallVector<TensorPtr>& inputs) {
                    for (auto&& i : inputs) {
                        if (!m_checker->check(i)) {
                            mgb_throw(TensorChecksumCalc::Error,
                                      "tensor modified before exec %s",
                                      print_op(def).c_str());
                        }
                    }
                    auto output = (*backup)(def, inputs);
                    // check() records the checksum of tensors it has not
                    // seen yet, so this also registers fresh outputs
                    for (auto&& i : output) {
                        mgb_assert(m_checker->check(i));
                    }
                    for (auto&& i : inputs) {
                        if (!m_checker->check(i)) {
                            mgb_throw(TensorChecksumCalc::Error,
                                      "tensor modified after exec %s",
                                      print_op(def).c_str());
                        }
                    }
                    return output;
                };
        m_checker->hook_list.emplace_back(&trait, std::move(backup));
    });
}

/*!
 * \brief restore the callables saved by enable() and drop all recorded
 *      checksums; does nothing besides clearing if no hook is installed
 */
void TensorSanityCheck::disable() {
    for (auto&& hook : m_checker->hook_list) {
        std::get<0>(hook)->apply_on_physical_tensor =
                std::move(*std::get<1>(hook));
    }
    m_checker->tensor2chksum.clear();
    m_checker->hook_list.clear();
}

TensorSanityCheck::TensorSanityCheck() {
    m_checker = std::make_unique<TensorSanityCheckImpl>();
}

//! Does not call disable(): hooks installed by enable() capture `this` and
//! stay installed, so disable() has to be called before destruction.
TensorSanityCheck::~TensorSanityCheck() {}

/*!
 * \brief name of an op for use in error messages
 * \return "OprAttr:" followed by the `type` member when \p def is an OprAttr,
 *      otherwise the name stored in the dynamic type info of \p def
 */
std::string TensorSanityCheck::print_op(const OpDef& def) {
    auto* opr_attr = def.try_cast_final<const OprAttr>();
    if (opr_attr) {
        return std::string("OprAttr:") + opr_attr->type;
    }
    return def.dyn_typeinfo()->name;
}

}  // namespace imperative

}  // namespace mgb