Commit 46cad4d3 authored by Megvii Engine Team

feat(functional/ops): add _assert_equal

GitOrigin-RevId: b7ce4158b7087886e7a9aef5c89b682cae26c646
Parent 585aa561
......@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from . import metric, vision
from . import metric, utils, vision
from .elemwise import *
from .math import *
from .nn import *
......
......@@ -11,6 +11,7 @@ from typing import Iterable, Union
import numpy as np
from ..tensor import Tensor
from .elemwise import abs, maximum, minimum
from .math import topk as _topk
from .tensor import broadcast_to, transpose
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core.ops.builtin import AssertEqual
from ..tensor import Tensor
from .elemwise import abs, maximum, minimum
def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
):
r"""
Asserts two tensors equal and returns expected value (first input).
It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
If we want to verify the correctness of model, just ``assert`` its states and outputs.
While sometimes we need to verify the correctness at different backends for *dumped* model
(or in :class:`~jit.trace` context), and no python code could be executed in that case.
Thus we have to use :func:`~functional.utils._assert_equal` instead.
:param expect: expected tensor value
:param actual: tensor to check value
:param maxerr: max allowed error; error is defined as the minimal of absolute and relative error
:param verbose: whether to print maxerr to stdout during opr exec
:return: expected tensor
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor([1, 2, 3], np.float32)
y = tensor([1, 2, 3], np.float32)
print(F.utils._assert_equal(x, y, maxerr=0).numpy())
Outputs:
.. testoutput::
[1. 2. 3.]
"""
err = (
abs(expect - actual)
/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
).max()
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception
return result
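For reference, here is a minimal NumPy sketch of the error metric computed above. It is illustrative only (the helper name `_ref_error` and the sample values are not part of the commit): the reported error is `|expect - actual|` divided by `max(min(|expect|, |actual|), 1)`, which reduces to the absolute error for small-magnitude values and to the relative error (w.r.t. the smaller-magnitude input) otherwise.

```python
import numpy as np

def _ref_error(expect: np.ndarray, actual: np.ndarray) -> float:
    # |expect - actual| scaled by max(min(|expect|, |actual|), 1):
    # absolute error when magnitudes are <= 1, relative error otherwise
    diff = np.abs(expect - actual)
    scale = np.maximum(np.minimum(np.abs(expect), np.abs(actual)), 1.0)
    return float((diff / scale).max())

expect = np.array([1.0, 100.0, 0.001], dtype=np.float32)
actual = np.array([1.00005, 100.004, 0.00105], dtype=np.float32)
print(_ref_error(expect, actual))  # ~5e-5, below the default maxerr of 1e-4
```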
......@@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import (
unset_compiled,
unset_tracing,
)
from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
from ..core._imperative_rt.ops import (
AssertEqual,
CollectiveComm,
RemoteRecv,
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
......@@ -110,7 +115,7 @@ class TensorInfo:
self.data_reader = None
_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
class trace:
......
......@@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
def test_where():
......@@ -746,3 +747,18 @@ def test_ones(val):
shp = tensor(val)
np_shp = np.array(val)
np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
def test_assert_equal():
shape = (2, 3, 4, 5)
x = F.ones(shape, dtype=np.float32)
y = F.zeros(shape, dtype=np.float32) + 1.00001
z = F.utils._assert_equal(x, y)
def test_assert_not_equal():
shape = (2, 3, 4, 5)
x = F.ones(shape, dtype=np.float32)
y = F.zeros(shape, dtype=np.float32) + 1.1
with pytest.raises(RuntimeError):
z = F.utils._assert_equal(x, y)
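Since the docstring's main motivation is use inside a trace context (note that `AssertEqual` is also added to `_io_op_types` in `trace` above), here is a hedged usage sketch; the function body, shapes, and the `symbolic=True` setting are assumptions for illustration, not part of the commit.

```python
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.jit import trace

@trace(symbolic=True)
def fwd(x, ref):
    y = x * 2
    # the check is recorded as a graph operator and returns `ref` on success
    return F.utils._assert_equal(ref, y, maxerr=1e-4)

x = tensor(np.ones((2, 3), dtype=np.float32))
ref = tensor(np.full((2, 3), 2.0, dtype=np.float32))
print(fwd(x, ref).numpy())
```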
......@@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity)
namespace { namespace assert_equal {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const AssertEqual&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final<AssertEqual>();
if (inputs.size() == 2) {
return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
} else {
// workaround for MiniGraph, which only allows one opr in the graph
mgb_assert(inputs.size() == 3);
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
}
}
OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node)
.fallback();
}}
}} // assert_equal
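The three-input branch above corresponds to the imperative Python wrapper added in this commit, which precomputes the error tensor and passes it as an extra input so that only the single `AssertEqual` opr has to live in the proxy MiniGraph. A condensed sketch of that call path (the name `_assert_equal_sketch` is hypothetical; the logic mirrors `functional/utils.py` above):

```python
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import AssertEqual
from megengine.functional.elemwise import abs, maximum, minimum

def _assert_equal_sketch(expect: Tensor, actual: Tensor, maxerr: float = 1e-4) -> Tensor:
    # the error is computed eagerly with ordinary elemwise ops ...
    scale = maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
    err = (abs(expect - actual) / scale).max()
    # ... and AssertEqual is applied to (expect, actual, err): the inputs.size() == 3 case
    return apply(AssertEqual(maxerr=maxerr, verbose=False), expect, actual, err)[0]
```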
namespace { namespace uniform_rng {
auto apply_on_var_node(
......
......@@ -445,6 +445,12 @@ public:
size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();}
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
if (!ProxyGraph::tm_async_error) {
std::swap(async_exc, tm_async_error);
}
}
std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);}
SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);}
......@@ -457,7 +463,6 @@ public:
size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);}
size_t clear_device_memory() override {mgb_assert(0);}
void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);}
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);}
};
std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;
......@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
}
}
thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;
} // namespace imperative
} // namespace mgb
......
......@@ -24,6 +24,9 @@ namespace imperative {
class ProxyGraph : public NonCopyableObj {
public:
static ProxyGraph* get_default_graph();
static std::unique_ptr<MegBrainError> get_async_error() {
return std::move(tm_async_error);
}
/********************** Physical Tensor API **********************/
......@@ -98,6 +101,8 @@ private:
std::unique_ptr<ExecEnv> m_env;
std::unique_ptr<StaticInferManager> m_static_infer_manager;
std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
static thread_local std::unique_ptr<MegBrainError> tm_async_error;
};
} // namespace imperative
......
......@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def,
}
}
exec(def, inputs, outputs);
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
}
return outputs;
}
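With the error recorded through `record_async_error` and rethrown here, a failed check surfaces on the Python side once the result is synchronized (hence the `_sync()` call in the wrapper). A small user-level illustration of catching that failure, with made-up values:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.ones((4,), dtype=np.float32))
y = tensor(np.full((4,), 1.1, dtype=np.float32))  # error ~0.1, far above maxerr
try:
    F.utils._assert_equal(x, y, maxerr=1e-4)
except RuntimeError as exc:
    print("assertion failed:", exc)
```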
......