提交 51397bfc 编写于 作者: M Megvii Engine Team

feat(mgb): supports value infer and empty input tensor in ElemwiseMultiType

GitOrigin-RevId: 05577a8bc8e214dcd7d7fc138ef952fc881c7a88
上级 247e2f59
......@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode
def _elemwise_multi_type(*args, mode, **kwargs):
    """Apply a ``builtin.ElemwiseMultiType`` op and return its single output.

    :param args: input tensors; normalized via ``convert_inputs`` before apply.
    :param mode: ElemwiseMultiType mode name, e.g. ``"lt"``, ``"eq"``, ``"isnan"``.
    :param kwargs: extra op attributes forwarded to the builtin (e.g. ``dtype``).
    :return: the op's only output tensor.
    """
    op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    args = convert_inputs(*args)
    (result,) = apply(op, *args)
    return result
......@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC):
__hash__ = None # due to __eq__ diviates from python convention
__lt__ = lambda self, value: _elemwise_multi_type(
self, value, mode="lt", dtype="Bool"
self, value, mode="lt", dtype="bool"
)
__le__ = lambda self, value: _elemwise_multi_type(
self, value, mode="leq", dtype="Bool"
self, value, mode="leq", dtype="bool"
)
__gt__ = lambda self, value: _elemwise_multi_type(
value, self, mode="lt", dtype="Bool"
value, self, mode="lt", dtype="bool"
)
__ge__ = lambda self, value: _elemwise_multi_type(
value, self, mode="leq", dtype="Bool"
value, self, mode="leq", dtype="bool"
)
__eq__ = lambda self, value: _elemwise_multi_type(
self, value, mode="eq", dtype="Bool"
self, value, mode="eq", dtype="bool"
)
__ne__ = lambda self, value: _elemwise_multi_type(
self, value, mode="neq", dtype="Bool"
self, value, mode="neq", dtype="bool"
)
__neg__ = _unary_elwise(_ElwMod.NEGATE)
......
......@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
>>> F.isnan(x).numpy()
array([False, True, False])
"""
return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
def isinf(inp: Tensor) -> Tensor:
......@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
>>> F.isinf(x).numpy()
array([False, True, False])
"""
return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
def sign(inp: Tensor):
......
......@@ -118,7 +118,7 @@ PyObject* py_apply(
tensors[i] = tw->m_tensor->data();
} else if (
DTypePromoteCfg::convert_input_enabled &&
op->same_type<Elemwise>()) {
(op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
tensors[i] = convert_pyinput_to_tensor(i);
} else {
PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
......
......@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
return ret;
}
// Dtype-promotion rule for ElemwiseMultiType: only the comparison-style
// modes (EQ / NEQ / LT / LEQ) need their inputs promoted to a common dtype;
// every other mode is forwarded to the default apply unchanged.
ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) {
    auto&& multi_type_op = op.cast_final_safe<ElemwiseMultiType>();
    bool needs_cast = false;
    switch (multi_type_op.mode) {
        case ElemwiseMultiType::Mode::EQ:
        case ElemwiseMultiType::Mode::NEQ:
        case ElemwiseMultiType::Mode::LT:
        case ElemwiseMultiType::Mode::LEQ:
            needs_cast = true;
            break;
        default:
            break;
    }
    if (!needs_cast) {
        return imperative::apply(op, inputs);
    }

    // collect the input dtypes and compute the common promoted dtype
    const size_t nr_inputs = inputs.size();
    SmallVector<DType> in_dtypes(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        in_dtypes[idx] = *(inputs[idx].dtype());
    }
    mgb::DType promoted = get_promoted_dtype(in_dtypes);

    // cast each non-quantized input whose dtype differs from the promoted one
    // (only when input conversion is globally enabled); others pass through
    ValueRefList casted(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        bool do_cast = DTypePromoteCfg::convert_input_enabled &&
                       !is_quantized_dtype(in_dtypes[idx]) &&
                       in_dtypes[idx] != promoted;
        if (do_cast) {
            casted[idx] = imperative::apply(
                    ApplyOp(*TypeCvt::make(promoted)), inputs[idx])[0];
            in_dtypes[idx] = promoted;
        } else {
            casted[idx] = inputs[idx];
        }
    }
    return imperative::apply(op, casted);
}
ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& elem_op = op.cast_final_safe<Elemwise>();
......@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
......
......@@ -16,52 +16,6 @@
using namespace mgb;
using namespace opr;
namespace {
//! global operator instance for static inference
template <class Opr>
class StaticInferOpr {
    // lazily-created megdnn operator shared by all callers; access is
    // serialized through m_mtx via the Lock guard below
    intl::UniqPtrWithCN<Opr> m_opr;
    MGB_MUTEX m_mtx;

public:
    //! RAII guard that holds m_mtx while the caller uses the operator
    class Lock {
        friend class StaticInferOpr;
        StaticInferOpr* m_owner;

        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
#if !__DEPLOY_ON_XP_SP2__
            // mutex is compiled out for XP SP2 deployments
            m_owner->m_mtx.lock();
#endif
        }

    public:
        // move transfers ownership of the held lock; the source is neutered
        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }

        ~Lock() {
#if !__DEPLOY_ON_XP_SP2__
            if (m_owner)
                m_owner->m_mtx.unlock();
#endif
        }

        Lock& operator=(const Lock&) = delete;
        Lock& operator=(Lock&&) = delete;

        //! access the operator owned by the locked StaticInferOpr
        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
    };

    //! lock and acquire the operator
    Lock lock() {
        Lock ret{this};
        if (!m_opr) {
            // created on first use on the default CPU comp node
            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
        }
        return ret;
    }
};
} // anonymous namespace
/* ========================= BatchedDTypePromotion ========================= */
intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
: m_orig_vars{vars} {
......
#include "megbrain/opr/nn_int.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "megbrain/opr/utility.h"
#include "megdnn/oprs/general.h"
using namespace mgb;
......@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType(
for (auto i : inputs) {
add_input({i});
}
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVar ElemwiseMultiType::make(
......@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() {
void ElemwiseMultiType::scn_do_execute() {
megdnn::TensorNDArray inp_arr(input().size());
for (size_t i = 0; i < input().size(); ++i) {
if (input()[i]->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty());
return;
}
inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
}
mgb_assert(!output(0)->dev_tensor().empty());
megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn());
}
......@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() {
#endif
}
ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const {
    // extend the base-class node property: every input may be an empty
    // tensor, so its value dependency is tagged VALUE_ALLOW_EMPTY
    auto prop = Super::do_make_node_prop();
    using DT = NodeProp::DepType;
    for (auto&& var : input()) {
        prop->add_dep_type_existing_var(var, DT::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
void ElemwiseMultiType::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    // shared, mutex-guarded megdnn operator reused across all value-infer calls
    static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr;

    using namespace cg::static_infer;

    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        // gather the statically inferred values of all inputs
        SmallVector<DeviceTensorND> inp_vals(inp.val.size());
        for (size_t i = 0; i < inp_vals.size(); ++i)
            inp_vals[i] = inp.val[i].value();

        // determine output dtype: modes that require an explicit output dtype
        // take it from the operator config; otherwise the mode trait deduces
        // it via check_out starting from an invalid dtype
        DType out_dt;
        auto trait = ModeTrait::from_mode(param().mode);
        if (trait.need_specify_out_dtype) {
            auto dtype = config().output_dtype();
            mgb_assert(dtype.valid());
            out_dt = dtype;
        } else {
            DType dtype;
            trait.check_out(dtype, false);
            out_dt = dtype;
        }
        // evaluate the op on the shared operator while holding its lock
        auto sopr = static_infer_opr.lock();
        perform(param().mode, out_dt, dest, inp_vals, sopr());
        return true;
    };

    // the output value depends on the value of every input
    DepVal deps(input().size());
    for (size_t i = 0; i < input().size(); ++i)
        deps[i] = {input(i), DepType::VALUE};
    owner_graph()->static_infer_manager().register_value_infer(
            output(0), {SourceType::DEP, deps, infer_value});
}
TensorShape ElemwiseMultiType::get_output_var_shape(
        Mode mode, const TensorShapeArray& input_shapes) {
    // input count must match the arity declared by the mode's trait
    auto&& trait = ModeTrait::from_mode(mode);
    mgb_assert(input_shapes.size() == trait.arity);
    // output shape is deduced by megdnn::Elemwise from the input shapes
    TensorShape out_shape;
    megdnn::Elemwise::deduce_shape(input_shapes, out_shape);
    return out_shape;
}
void ElemwiseMultiType::call_megdnn_opr_exec(
        CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
        megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) {
    // All Elemwise operations on QuantizedS32/QuantizedS8 are not related to
    // scale. MegDNN does not support computing Elemwise for
    // QuantizedS32/QuantizedS8, we translate the data type to Int32/Int8 before
    // passing to MegDNN. (``caller`` is unused on this path.)
    bool quantized_inp =
            inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED;
    if (!quantized_inp) {
        opr->exec(inp, out);
        return;
    }

    // pick the plain integer dtype matching the quantized storage type
    DType compute_dtype;
    switch (inp[0].layout.dtype.enumv()) {
        case DTypeEnum::QuantizedS32:
            compute_dtype = dtype::Int32();
            break;
        case DTypeEnum::QuantizedS8:
            compute_dtype = dtype::Int8();
            break;
        default:
            mgb_throw(
                    MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s",
                    inp[0].layout.dtype.name(), int(opr->param().mode),
                    comp_node.to_string().c_str());
    }

    // reinterpret the layouts (same storage) with the compute dtype and run
    megdnn::TensorNDArray reinterpreted_inp(inp);
    for (auto&& t : reinterpreted_inp) {
        t.layout.dtype = compute_dtype;
    }
    megdnn::TensorND reinterpreted_out = out;
    reinterpreted_out.layout.dtype = compute_dtype;
    opr->exec(reinterpreted_inp, reinterpreted_out);
}
//! Directly evaluate an ElemwiseMultiType mode on device tensors (used by
//! static value inference). ``opr`` is a cached operator created on demand
//! and reused across calls; it must remain bound to a single comp node.
void ElemwiseMultiType::perform(
        Mode mode, DType out_dt, DeviceTensorND& dest,
        const SmallVector<DeviceTensorND>& inputs,
        intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) {
    megdnn::TensorNDArray dnn_inputs(inputs.size());
    TensorShapeArray inp_shapes(inputs.size());
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node();
        } else {
            // all inputs must live on the same comp node
            mgb_assert(t.comp_node() == out_cn);
        }
        if (t.shape().is_empty()) {
            // empty input: the destination must already be empty; skip exec
            mgb_assert(dest.empty());
            return;
        }
        inp_shapes[i] = t.shape();
    }
    // create the cached operator on first use; afterwards it must stay on
    // the comp node it was created with
    if (!opr) {
        opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn);
    } else {
        mgb_assert(out_cn == opr.comp_node());
    }
    out_cn.activate();
    for (size_t i = 0; i < inputs.size(); ++i)
        dnn_inputs[i] = inputs[i].as_megdnn();
    // allocate the destination with the deduced output shape, then execute
    dest.comp_node(out_cn).dtype(out_dt).resize(get_output_var_shape(mode, inp_shapes));
    opr->param() = {mode};
    call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -26,6 +26,14 @@ public:
const VarNodeArrayView& inputs, Param param,
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
Mode mode, const TensorShapeArray& input_shapes);
MGE_WIN_DECLSPEC_FUC static void perform(
Mode mode, DType out_dt, DeviceTensorND& dest,
const SmallVector<DeviceTensorND>& inputs,
intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr);
private:
using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait;
......@@ -40,6 +48,14 @@ private:
void record_execute_deps(ExecDependencyArray& deps) override;
void add_input_layout_constraint() override;
NodeProp* do_make_node_prop() const override;
void init_output_static_infer_desc() override;
static void call_megdnn_opr_exec(
CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller);
};
//! deprecated; TODO: remove in megbrain 8
......
......@@ -509,6 +509,49 @@ public:
bool is_const() const { return m_is_const; }
};
//! global operator instance for static inference
template <class Opr>
class StaticInferOpr {
    // shared megdnn operator, created lazily; guarded by m_mtx through Lock
    intl::UniqPtrWithCN<Opr> m_opr;
    MGB_MUTEX m_mtx;

public:
    //! RAII guard holding m_mtx for the duration of operator use
    class Lock {
        friend class StaticInferOpr;
        StaticInferOpr* m_owner;

        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
#if !__DEPLOY_ON_XP_SP2__
            // mutex is compiled out for XP SP2 deployments
            m_owner->m_mtx.lock();
#endif
        }

    public:
        // move hands the held lock to the new guard and neuters the source
        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }

        ~Lock() {
#if !__DEPLOY_ON_XP_SP2__
            if (m_owner)
                m_owner->m_mtx.unlock();
#endif
        }

        Lock& operator=(const Lock&) = delete;
        Lock& operator=(Lock&&) = delete;

        //! access the operator owned by the locked StaticInferOpr
        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
    };

    //! lock and acquire the operator
    Lock lock() {
        Lock ret{this};
        if (!m_opr) {
            // created on first use on the default CPU comp node
            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
        }
        return ret;
    }
};
} // namespace opr
} // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册