diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 78b2668f0204296d40ff95f5f02279bfecd485cd..cd8bf458374cb9e933297f441046755308b92429 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode
 
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
 
@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC):
     __hash__ = None  # due to __eq__ diviates from python convention
 
     __lt__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="lt", dtype="Bool"
+        self, value, mode="lt", dtype="bool"
     )
     __le__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="leq", dtype="Bool"
+        self, value, mode="leq", dtype="bool"
     )
     __gt__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="lt", dtype="Bool"
+        value, self, mode="lt", dtype="bool"
     )
     __ge__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="leq", dtype="Bool"
+        value, self, mode="leq", dtype="bool"
     )
     __eq__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="eq", dtype="Bool"
+        self, value, mode="eq", dtype="bool"
     )
     __ne__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="neq", dtype="Bool"
+        self, value, mode="neq", dtype="bool"
     )
 
     __neg__ = _unary_elwise(_ElwMod.NEGATE)
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 0fa09958f73190c69e10e24f28af81b0575340b0..978221ac48c180736cfcf87ea5c362e1aea7b5d2 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
     >>> F.isnan(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
 
 
 def isinf(inp: Tensor) -> Tensor:
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
     >>> F.isinf(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
 
 
 def sign(inp: Tensor):
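[Reviewer note, not part of the patch] With the lowercase "bool" dtype string,
tensor comparisons and isnan/isinf return plain boolean tensors, matching the
NumPy convention already used in the doctests above. A rough sketch of the
intended behavior, assuming the usual `import megengine as mge` alias (outputs
are expected values, not verified here):

    >>> x = mge.tensor([1.0, 2.0, 3.0])
    >>> (x > 2.0).numpy()
    array([False, False,  True])
    >>> (x > 2.0).dtype
    dtype('bool')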
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index ef67ad30a5cb7fe446eca37e415f993005489e05..3409b7d7d65eedb46b22404bd9aeba14dd6aff7f 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -118,7 +118,7 @@ PyObject* py_apply(
             tensors[i] = tw->m_tensor->data();
         } else if (
                 DTypePromoteCfg::convert_input_enabled &&
-                op->same_type<Elemwise>()) {
+                (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
             tensors[i] = convert_pyinput_to_tensor(i);
         } else {
             PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
index c4f4547a7d3f697b8df239266fa36ea648a75130..58de880a7364ff5785c6c6b87f3e026d2fe9f6d0 100644
--- a/imperative/src/impl/transformations/dtype_promote.cpp
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
     return ret;
 }
 
+ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& elem_op = op.cast_final_safe<ElemwiseMultiType>();
+    static std::unordered_set<ElemwiseMultiType::Mode> cast_case = {
+            ElemwiseMultiType::Mode::EQ,
+            ElemwiseMultiType::Mode::NEQ,
+            ElemwiseMultiType::Mode::LT,
+            ElemwiseMultiType::Mode::LEQ,
+    };
+
+    if (cast_case.find(elem_op.mode) == cast_case.end()) {
+        return imperative::apply(op, inputs);
+    }
+
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& elem_op = op.cast_final_safe<Elemwise>();
 
@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
+        register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
         register_dtype_promote_rule<Concat>(naive_promote_rule);
         register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
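[Reviewer note, not part of the patch] elemwise_multi_type_rule extends the
existing dtype-promotion machinery to the comparison modes (EQ/NEQ/LT/LEQ):
non-quantized operands are cast to the common promoted dtype before the kernel
runs, while quantized inputs pass through unchanged. A hypothetical example of
the intended effect, continuing the `mge` alias from the note above:

    >>> x = mge.tensor([1, 2], dtype="int32")
    >>> (x < 1.5).numpy()  # int32 operand promoted before comparing
    array([ True, False])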
diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index 93309027b7fa64236efa1c7a3436b390e87a00fd..8111b4e0585a049ce5e3f46822709e4f052c4e67 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -16,52 +16,6 @@
 using namespace mgb;
 using namespace opr;
 
-namespace {
-
-//! global operator instance for static inference
-template <class MegDNNOpr>
-class StaticInferOpr {
-    intl::UniqPtrWithCN<MegDNNOpr> m_opr;
-    MGB_MUTEX m_mtx;
-
-public:
-    class Lock {
-        friend class StaticInferOpr;
-        StaticInferOpr* m_owner;
-
-        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
-#if !__DEPLOY_ON_XP_SP2__
-            m_owner->m_mtx.lock();
-#endif
-        }
-
-    public:
-        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
-
-        ~Lock() {
-#if !__DEPLOY_ON_XP_SP2__
-            if (m_owner)
-                m_owner->m_mtx.unlock();
-#endif
-        }
-
-        Lock& operator=(const Lock&) = delete;
-        Lock& operator=(Lock&&) = delete;
-
-        intl::UniqPtrWithCN<MegDNNOpr>& operator()() { return m_owner->m_opr; }
-    };
-
-    //! lock and acquire the operator
-    Lock lock() {
-        Lock ret{this};
-        if (!m_opr) {
-            m_opr = intl::create_megdnn_opr<MegDNNOpr>(CompNode::default_cpu());
-        }
-        return ret;
-    }
-};
-}  // anonymous namespace
-
 /* ========================= BatchedDTypePromotion ========================= */
 
 intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
         : m_orig_vars{vars} {
diff --git a/src/opr/impl/nn_int.cpp b/src/opr/impl/nn_int.cpp
index e04aa49ddb01ae488c3991cf1c8005bc82a7ceb3..cbf84fe9aeac10e83043b1a24a6f51a419cf7ea1 100644
--- a/src/opr/impl/nn_int.cpp
+++ b/src/opr/impl/nn_int.cpp
@@ -1,6 +1,6 @@
 #include "megbrain/opr/nn_int.h"
 #include "./internal/megdnn_opr_wrapper.inl"
-
+#include "megbrain/opr/utility.h"
 #include "megdnn/oprs/general.h"
 
 using namespace mgb;
@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType(
     for (auto i : inputs) {
         add_input({i});
     }
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar ElemwiseMultiType::make(
@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() {
 void ElemwiseMultiType::scn_do_execute() {
     megdnn::TensorNDArray inp_arr(input().size());
     for (size_t i = 0; i < input().size(); ++i) {
+        if (input()[i]->dev_tensor().empty()) {
+            mgb_assert(output(0)->dev_tensor().empty());
+            return;
+        }
         inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
     }
+    mgb_assert(!output(0)->dev_tensor().empty());
     megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn());
 }
 
@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() {
 #endif
 }
 
+ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    for (auto& inp : input()) {
+        ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
+void ElemwiseMultiType::init_output_static_infer_desc() {
+    Super::init_output_static_infer_desc();
+    static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr;
+
+    using namespace cg::static_infer;
+
+    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
+        SmallVector<DeviceTensorND> inp_vals(inp.val.size());
+        for (size_t i = 0; i < inp_vals.size(); ++i)
+            inp_vals[i] = inp.val[i].value();
+
+        DType out_dt;
+        auto trait = ModeTrait::from_mode(param().mode);
+        if (trait.need_specify_out_dtype) {
+            auto dtype = config().output_dtype();
+            mgb_assert(dtype.valid());
+            out_dt = dtype;
+        } else {
+            DType dtype;
+            trait.check_out(dtype, false);
+            out_dt = dtype;
+        }
+        auto sopr = static_infer_opr.lock();
+        perform(param().mode, out_dt, dest, inp_vals, sopr());
+        return true;
+    };
+    DepVal deps(input().size());
+    for (size_t i = 0; i < input().size(); ++i)
+        deps[i] = {input(i), DepType::VALUE};
+    owner_graph()->static_infer_manager().register_value_infer(
+            output(0), {SourceType::DEP, deps, infer_value});
+}
+
+TensorShape ElemwiseMultiType::get_output_var_shape(
+        Mode mode, const TensorShapeArray& input_shapes) {
+    mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity);
+    TensorShape ret;
+    megdnn::Elemwise::deduce_shape(input_shapes, ret);
+    return ret;
+}
+
+void ElemwiseMultiType::call_megdnn_opr_exec(
+        CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+        megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) {
+    // All Elemwise operations on QuantizedS32/QuantizedS8 are not related to
+    // scale. MegDNN does not support computing Elemwise for
+    // QuantizedS32/QuantizedS8, we translate the data type to Int32/Int8 before
+    // passing to MegDNN.
+    if (inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) {
+        auto inp_dtype = inp[0].layout.dtype;
+        DType compute_dtype;
+        if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) {
+            compute_dtype = dtype::Int32();
+        } else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) {
+            compute_dtype = dtype::Int8();
+        } else {
+            mgb_throw(
+                    MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s",
+                    inp[0].layout.dtype.name(), int(opr->param().mode),
+                    comp_node.to_string().c_str());
+        }
+
+        megdnn::TensorNDArray run_inp(inp);
+        for (size_t i = 0; i < inp.size(); i++) {
+            run_inp[i].layout.dtype = compute_dtype;
+        }
+        megdnn::TensorND run_out = out;
+        run_out.layout.dtype = compute_dtype;
+        opr->exec(run_inp, run_out);
+        return;
+    }
+
+    opr->exec(inp, out);
+}
+
+void ElemwiseMultiType::perform(
+        Mode mode, DType out_dt, DeviceTensorND& dest,
+        const SmallVector<DeviceTensorND>& inputs,
+        intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) {
+    megdnn::TensorNDArray dnn_inputs(inputs.size());
+    TensorShapeArray inp_shapes(inputs.size());
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node();
+        } else {
+            mgb_assert(t.comp_node() == out_cn);
+        }
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
+        inp_shapes[i] = t.shape();
+    }
+    if (!opr) {
+        opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn);
+    } else {
+        mgb_assert(out_cn == opr.comp_node());
+    }
+    out_cn.activate();
+    for (size_t i = 0; i < inputs.size(); ++i)
+        dnn_inputs[i] = inputs[i].as_megdnn();
+    dest.comp_node(out_cn).dtype(out_dt).resize(
+            get_output_var_shape(mode, inp_shapes));
+    opr->param() = {mode};
+    call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
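[Reviewer note, not part of the patch] The ALLOW_EMPTY_SHAPE flag, the
VALUE_ALLOW_EMPTY dep type, and the early return in scn_do_execute together
let ElemwiseMultiType accept 0-size inputs instead of asserting on them. A
sketch of what should now work, assuming `import numpy as np` and
`from megengine import functional as F` in addition to the aliases above:

    >>> x = mge.tensor(np.empty((0, 3), dtype="float32"))
    >>> F.isnan(x).numpy().shape  # empty input yields an empty bool output
    (0, 3)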
diff --git a/src/opr/include/megbrain/opr/nn_int.h b/src/opr/include/megbrain/opr/nn_int.h
index 305e323fb155f9d95968c160831700aa18905aac..0bcbeb10da0e2c58645ce88023940654cdb7251b 100644
--- a/src/opr/include/megbrain/opr/nn_int.h
+++ b/src/opr/include/megbrain/opr/nn_int.h
@@ -26,6 +26,14 @@ public:
             const VarNodeArrayView& inputs, Param param,
             const OperatorNodeConfig& config = {});
 
+    MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
+            Mode mode, const TensorShapeArray& input_shapes);
+
+    MGE_WIN_DECLSPEC_FUC static void perform(
+            Mode mode, DType out_dt, DeviceTensorND& dest,
+            const SmallVector<DeviceTensorND>& inputs,
+            intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr);
+
 private:
     using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait;
 
@@ -40,6 +48,14 @@ private:
     void record_execute_deps(ExecDependencyArray& deps) override;
 
     void add_input_layout_constraint() override;
+
+    NodeProp* do_make_node_prop() const override;
+
+    void init_output_static_infer_desc() override;
+
+    static void call_megdnn_opr_exec(
+            CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+            megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller);
 };
 
 //! deprecated; TODO: remove in megbrain 8
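[Reviewer note, not part of the patch] get_output_var_shape defers to
megdnn::Elemwise::deduce_shape, so comparison outputs broadcast exactly like
ordinary elemwise ops, and the newly exported perform() lets the static value
inferencer evaluate them on constant inputs. Expected broadcasting behavior,
continuing the aliases above:

    >>> a = mge.tensor(np.zeros((2, 1), dtype="float32"))
    >>> b = mge.tensor(np.zeros((1, 3), dtype="float32"))
    >>> (a <= b).numpy().shape  # shapes broadcast as for any elemwise op
    (2, 3)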
diff --git a/src/opr/include/megbrain/opr/utility.h b/src/opr/include/megbrain/opr/utility.h
index 36c322dcf4fefddf66f07fd142d5a10b51e0136f..37e896ebd8bea9272b93ace92aa736850a326157 100644
--- a/src/opr/include/megbrain/opr/utility.h
+++ b/src/opr/include/megbrain/opr/utility.h
@@ -509,6 +509,49 @@ public:
     bool is_const() const { return m_is_const; }
 };
 
+//! global operator instance for static inference
+template <class MegDNNOpr>
+class StaticInferOpr {
+    intl::UniqPtrWithCN<MegDNNOpr> m_opr;
+    MGB_MUTEX m_mtx;
+
+public:
+    class Lock {
+        friend class StaticInferOpr;
+        StaticInferOpr* m_owner;
+
+        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
+#if !__DEPLOY_ON_XP_SP2__
+            m_owner->m_mtx.lock();
+#endif
+        }
+
+    public:
+        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
+
+        ~Lock() {
+#if !__DEPLOY_ON_XP_SP2__
+            if (m_owner)
+                m_owner->m_mtx.unlock();
+#endif
+        }
+
+        Lock& operator=(const Lock&) = delete;
+        Lock& operator=(Lock&&) = delete;
+
+        intl::UniqPtrWithCN<MegDNNOpr>& operator()() { return m_owner->m_opr; }
+    };
+
+    //! lock and acquire the operator
+    Lock lock() {
+        Lock ret{this};
+        if (!m_opr) {
+            m_opr = intl::create_megdnn_opr<MegDNNOpr>(CompNode::default_cpu());
+        }
+        return ret;
+    }
+};
+
 }  // namespace opr
 }  // namespace mgb