diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index cfdcc3da2dffa0bd660b8b5f0c7274270a4e63d4..0fbd15f764b0a307eec3f4206607d76709b24134 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -180,7 +180,7 @@ public: virtual void exec( _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, TensorLayout& dst); + MGE_WIN_DECLSPEC_FUC void deduce_layout(const TensorLayout& src, TensorLayout& dst); virtual size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& dst) = 0; diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 653337f5ff12f476c4d83c89b44518f8a8b8659c..b37a085e4261136e59db58934761f6f8172e2782 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -384,22 +384,18 @@ def _reduce(mode): data = self if axis is None: assert not keepdims, "can not set axis=None and keepdims=True" - result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) + (result,) = apply(builtin.Reduce(mode=mode), data) elif isinstance(axis, collections.abc.Iterable): axis = _normalize_axis(self.ndim, axis, reverse=True) for ai in axis: - op = builtin.Reduce(mode=mode, axis=ai) + op = builtin.Reduce(mode=mode, axis=ai, keepdim=keepdims) (data,) = apply(op, data) - if not keepdims: - data = squeeze_cpp(data, ai) result = data else: # builtin.Reduce already accept negtive axis - op = builtin.Reduce(mode=mode, axis=axis) + op = builtin.Reduce(mode=mode, axis=axis, keepdim=keepdims) (result,) = apply(op, data) - if not keepdims: - result = squeeze_cpp(result, axis) return result return f diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 45a7b03f9a0c96185549fb3bf94efbc01cce5099..d59437d489b33a181c9ee7a1e51e9385b2f3e3b6 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -271,24 +271,34 @@ std::optional reduce_grad_rule( if (reduce.mode != Reduce::Mode::SUM) { return {}; } - if (inputs.size() != 1) { + auto axis = reduce.axis; + if (inputs.size() != 1 || axis == INT_MAX) { return {}; } std::array input_shapes; if (inputs_require_grad[0]) { input_shapes[0] = get_shape(inputs[0]); } + if (axis < 0) { + axis = (*inputs[0].shape()).ndim + axis; + } auto maker = CustomGradMaker(backward, inputs.size()); + auto keepdim = reduce.keepdim || axis == INT_MAX; maker.output_size(1).output_captured(0, false); - maker.backward([shapes = std::move(input_shapes)](Span grads) { - mgb_assert(grads.size() == 1); - ValueRef grad = grads[0]; - SmallVector ret(1); - if (grad && shapes[0]) { - ret[0] = broadcast_to(grad, shapes[0]); - } - return ret; - }); + maker.backward( + [shapes = std::move(input_shapes), axis, keepdim](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + if (!keepdim) { + auto&& grad_op = AddAxis::make(std::vector({axis})); + grad = imperative::apply(*grad_op, grad)[0]; + } + SmallVector ret(1); + if (grad && shapes[0]) { + ret[0] = broadcast_to(grad, shapes[0]); + } + return ret; + }); maker.finalize(); return imperative::apply(ApplyOp(op), inputs); } diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp index f28a163ec8d58cd1a963681f4e2b86a83dab8bbc..d8f605dcd0adb3f98cb0dd005664f5d543397305 100644 --- a/imperative/src/impl/ops/batch_norm.cpp +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -36,8 +36,8 @@ EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0]; }; - auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0)); - auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM)); + auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0), true); + auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM), true); auto sub = Elemwise::make(Elemwise::Mode::SUB); auto mul = Elemwise::make(Elemwise::Mode::MUL); auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV); diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index 9be904175ae4090272fb016f94b728b1784fd270..5000fea6c8979cfd1b7f0bf34db670fac98cc071 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -9,10 +9,16 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megbrain/graph/symbol_var.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" +#include "megdnn/dtype.h" +#include "../blob_manager_impl.h" #include "../dnn_op_helper.h" #include "../op_trait.h" @@ -22,18 +28,41 @@ namespace { namespace reduce { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& reduce = static_cast(def); - OperatorNodeConfig config{reduce.make_name()}; + auto comp_node = inputs[0]->comp_node(); + OperatorNodeConfig config{reduce.make_name(), comp_node, inputs[0]->dtype()}; + if (inputs.size() > 1) { return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); - } else { - return opr::Reduce::make( - inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); } + + using Param = megdnn::param::Reduce; + auto param = reduce.param(); + if (param.axis < 0) { + param.axis = inputs[0]->shape().ndim + param.axis; + } + + SymbolVar target_shape = (cg::VarNode*)nullptr; + if (param.axis == INT_MAX) { + DTypeScalar vi{1}; + // auto graph = ComputingGraph::make(); + auto graph = inputs[0]->owner_graph(); + target_shape = opr::ImmutableTensor::make(*graph, vi, config); + } + auto res = opr::Reduce::make(inputs[0], param, target_shape, config); + if (!reduce.keepdim && param.axis != INT_MAX) { + using Desc = opr::AxisAddRemove::AxisDesc; + std::vector remove_param; + remove_param.push_back(Desc::make_remove(param.axis)); + OperatorNodeConfig remove_config{ + def.make_name(), comp_node, inputs[0]->dtype()}; + return opr::AxisAddRemove::make(res, remove_param, remove_config); + } + return res; } std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { auto* node = &node_->cast_final_safe(); - return Reduce::make(node->param()); + return Reduce::make(node->param(), true); } // TODO: using this for apply_on_physical_tensor @@ -57,21 +86,159 @@ SmallVector apply_on_physical_tensor( return {Tensor::make( inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())}; } - return proxy_graph_detail::apply_on_physical_tensor( - def, inputs, output_descs, validated); + + auto size = inputs.size(); + if (size > 1) { + return proxy_graph_detail::apply_on_physical_tensor( + def, inputs, output_descs, validated); + } + + auto comp_node = inputs[0]->comp_node(); + using TensorND = megdnn::TensorND; + auto&& op_def = def.cast_final_safe(); + SmallVector inp_tensornds; + inp_tensornds.reserve(inputs.size()); + auto src = inputs[0]->layout(); + + DnnOprCaller dnn_op(comp_node); + dnn_op.op->param() = op_def.param(); + auto axis = op_def.param().axis; + auto keepdim = op_def.keepdim; + + if (axis < 0) { + axis = inputs[0]->layout().ndim + axis; + } + + dnn_op.op->param().axis = axis == INT_MAX ? 0 : axis; + + if (axis == INT_MAX) { + src.shape[0] = src.total_nr_elems(); + src.ndim = 1; + src.init_contiguous_stride(); + } + TensorLayout layout{src.dtype}; + dnn_op.op->deduce_layout(src, layout); + + if (inputs[0]->layout().is_empty()) { + inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); + + auto mode = op_def.param().mode; + DnnOprCaller fill_op(comp_node); + + if (!keepdim && src.ndim > 1) { + layout.remove_axis_inplace(axis); + layout.init_contiguous_stride(); + } + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout); + std::string err_msg; + switch (mode) { + case Reduce::Mode::SUM: + if (!out.empty()) { + fill_op.op->param() = 0; + fill_op.op->exec(out.as_megdnn(), {}); + } + break; + case Reduce::Mode::PRODUCT: + if (!out.empty()) { + fill_op.op->param() = 1; + fill_op.op->exec(out.as_megdnn(), {}); + } + break; + case Reduce::Mode::MEAN: + err_msg = "mean"; + break; + case Reduce::Mode::MIN: + err_msg = "min"; + break; + case Reduce::Mode::MAX: + err_msg = "max"; + break; + case Reduce::Mode::SUM_SQR: + err_msg = "sum_sqr"; + break; + default: + mgb_throw(MegBrainError, "bad reduce mode"); + } + if (!err_msg.empty()) { + mgb_throw( + MegBrainError, "empty input is not allowed for reduce mode: %s", + err_msg.c_str()); + } + return {Tensor::make(out)}; + } + + auto dnn_ten = inputs[0]->dnn_tensor(); + dnn_ten.layout = src; + inp_tensornds.push_back(dnn_ten); + + megdnn::Workspace dnn_wk; + + auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout); + if (wk_size != 0) { + auto wk = Blob::make(comp_node, wk_size); + dnn_wk.raw_ptr = wk->storage().get(); + dnn_wk.size = wk_size; + } + + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout); + + dnn_op.op->exec(inp_tensornds[0], out.as_megdnn(), dnn_wk); + + if (!keepdim && src.ndim > 1) { + auto out_layout = out.layout(); + out_layout.remove_axis_inplace(axis); + out_layout.init_contiguous_stride(); + out.resize(out_layout); + } + + return {Tensor::make(out)}; } std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { - auto [output_descs, validated] = - proxy_graph_detail::infer_output_attrs_fallible(def, inputs); - if (inputs.size() == 2 && !output_descs[0].layout.ndim) { + auto&& op_def = def.cast_final_safe(); + auto axis = op_def.param().axis; + auto keepdim = op_def.keepdim; + + size_t size = inputs.size(); + SmallVector dests(size); + + if (size > 1) { + auto [output_descs, validated] = + proxy_graph_detail::infer_output_attrs_fallible(def, inputs); if (!inputs[1].value.empty()) { cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value); output_descs[0].layout.init_contiguous_stride(); } + return {output_descs, validated}; + } + + if (axis < 0) { + axis = inputs[0].layout.ndim + axis; + } + + if (axis == INT_MAX || inputs[0].layout.ndim == 1) { + TensorLayout layout{inputs[0].layout.dtype}; + layout.shape[0] = 1; + layout.ndim = 1; + dests[0].layout = layout; + dests[0].comp_node = inputs[0].comp_node; + } else { + for (size_t i = 0; i < size; ++i) { + dests[i].comp_node = inputs[i].comp_node; + dests[i].layout = inputs[i].layout; + if (not keepdim && dests[i].layout.ndim > 1) { + dests[i].layout.remove_axis_inplace(axis); + } else { + dests[i].layout.shape[axis] = 1; + } + dests[i].layout.init_contiguous_stride(); + } } - return {output_descs, validated}; + + return {dests, true}; } SmallVector get_input_layout_constraint( diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index 544ee51290ad84221898fc0be3d11522bbb19070..89ee1f46cd3a92b4b40ad98e16f99370887c9e8e 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -92,7 +92,13 @@ ValueRefList remove_axis_rule( ValueRefList reduce_rule( const Reduce& reduce, Span inputs, Span inputs_mask, const Type& scalar_type) { + bool keepdim = reduce.keepdim; + auto axis = reduce.axis; if (inputs.size() == 1) { + if (axis == INT_MAX || (inputs[0].shape()->ndim == 1 && keepdim == false)) { + // CompNode device = *inputs[0].device(); + return {scalar_type.make(imperative::apply(reduce, inputs)[0])}; + } return imperative::apply(reduce, inputs); } mgb_assert(inputs.size() == 2); diff --git a/imperative/src/include/megbrain/imperative/subgraph_detail.h b/imperative/src/include/megbrain/imperative/subgraph_detail.h index ccebce1fddf43457fd9fd1a322828c4773e66cae..5c3224c53e0b9e70eabb8bdbe6921d566428d04f 100644 --- a/imperative/src/include/megbrain/imperative/subgraph_detail.h +++ b/imperative/src/include/megbrain/imperative/subgraph_detail.h @@ -43,6 +43,7 @@ EncodedSubgraph make_backward_graph_from_forward( const EncodedSubgraph& forward, const SmallVector& inputs, const SmallVector& input_requires_grad, const SmallVector& output_has_grad); + EncodedSubgraph make_from_computing_graph( const VarNodeArray& inputs, const VarNodeArray& outputs); diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index c1cccb38becfb15219f731477046e039b35219f1..23311367dbd1ca6ab3526482298d6685a929ac24 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -26,7 +26,11 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { }]; } -def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; +def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{ + let extraArguments = (ins + MgbBoolAttr:$keepdim + ); +} def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { let inputs = (ins AnyType:$inputs);