diff --git a/dnn/src/common/basic_types.cpp b/dnn/src/common/basic_types.cpp
index 7462441418fbb1465e3bdd5bb85f9fdb2fda518f..dce7b820c35f78fc3dbf2c8b95cf491a8b509449 100644
--- a/dnn/src/common/basic_types.cpp
+++ b/dnn/src/common/basic_types.cpp
@@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         megdnn_throw_if(
                 cur_shape != 1 && cur_stride != 0, tensor_reshape_error,
                 megdnn_mangle(ssprintf(
-                        "brodcast on dim with shape not equal to 1: "
+                        "broadcast on dim with shape not equal to 1: "
                         "src_shape=%s dst_shape=%s",
                         to_string().c_str(), tshape.to_string().c_str())));
         result.shape[target_idx] = tshape.shape[target_idx];
diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
index db9605ca62ed0ddb130b1c148e20a745cce1e8eb..356e21f72b21748b7e20b2975b023ac0aafc21c1 100644
--- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py
+++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
@@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
             grad_fn = reduce_sum_grad_fn
         else:
             grad_fn = default_grad_fn
-    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+    elif isinstance(op, Broadcast) or (
+        isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
+    ):
         grad_fn = elemwise_add_grad_fn
     else:
         grad_fn = default_grad_fn
@@ -212,5 +214,4 @@ _oprAttr_grad_fn = {
     Reshape.name: reshape_grad_fn,
     Subtensor.name: subtensor_grad_fn,
     IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
-    Broadcast.name: elemwise_add_grad_fn,
 }
diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py
index 6b41c4c8e494d5db2116ab429729d29ee0e63f61..411358062e51c4e447a9c2fcdc3a06935a65fd70 100644
--- a/imperative/python/megengine/core/tensor/tensor_wrapper.py
+++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py
@@ -59,29 +59,7 @@ def _transpose(data, axes):
 
 
 def _broadcast(inp, shape):
-    def valid_broadcast(src, tar):
-        def failed():
-            raise ValueError(
-                "the input shape {} can not be broadcasted to target shape {}".format(
-                    src, tar
-                )
-            )
-
-        if isinstance(src, (TensorBase, TensorWrapperBase)):
-            src = src.numpy()
-
-        if isinstance(tar, (TensorBase, TensorWrapperBase)):
-            tar = tar.numpy()
-
-        if len(src) > len(tar):
-            failed()
-
-        for i in range(min(len(src), len(tar))):
-            if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
-                failed()
-
     shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
-    valid_broadcast(inp.shape, shape)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
 
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index e6afd85c8949d13e1c07b513a89d8fe28b9b404a..f48c330fd0c34a1b2a72dd1893ce911ff45b597f 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -21,6 +21,7 @@
 #include "megbrain/imperative/ops/nms.h"
 #include "megbrain/imperative/ops/elemwise.h"
 #include "megbrain/imperative/ops/batch_norm.h"
+#include "megbrain/imperative/ops/broadcast.h"
 
 namespace py = pybind11;
 
@@ -206,4 +207,7 @@ void init_ops(py::module m) {
     V(INFERENCE);
 #undef V
 
+    py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
+        .def(py::init<>());
+
 }
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index 1e8fbc55840450f19a15a74e7a145e24cb00cee1..da7b8d19a3036266db89261bcfeb33b5a56df65a 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -262,13 +262,13 @@ def test_broadcast():
     opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
 
     x = F.ones((2, 1, 3))
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (2, 3, 4))
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (4, 1, 3))
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (1, 3))
 
 
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..23e0137a0466c414ba7ab99de76a18db99fb397b
--- /dev/null
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -0,0 +1,95 @@
+/**
+ * \file imperative/src/impl/ops/broadcast.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/broadcast.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    node_->cast_final_safe<opr::Broadcast>();
+    return Broadcast::make();
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    def.cast_final_safe<Broadcast>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
+    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
+}
+
+bool valid_broadcast(const TensorShape& src_shape,
+                     const TensorShape& tar_shape) {
+    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
+    if (src_ndim > tar_ndim) {
+        return false;
+    }
+    size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
+    for (size_t i = 0; i < min_ndim; ++i) {
+        if (src_shape[src_ndim - i - 1] != 1 &&
+            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    def.cast_final_safe<Broadcast>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto&& tshp = inputs[1];
+
+    TensorLayout out_layout = src.layout;
+    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
+        out_layout.ndim = 0;
+        return {{out_layout, src.comp_node}};
+    }
+    mgb_assert(
+        tshp.layout.ndim == 1,
+        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+        tshp.layout.ndim);
+
+    size_t target_ndim = tshp.layout.shape[0];
+    out_layout.ndim = target_ndim;
+    auto* ptr = tshp.value.ptr<dt_int32>();
+    for (size_t i = 0; i < target_ndim; ++i) {
+        out_layout.shape[i] = ptr[i];
+    }
+    mgb_assert(valid_broadcast(src.layout, out_layout),
+               "the input shape %s can not be broadcasted to target shape %s",
+               src.layout.to_string().c_str(),
+               out_layout.to_string().c_str());
+    return {{out_layout, src.comp_node}};
+}
+
+OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/include/megbrain/imperative/ops/broadcast.h b/imperative/src/include/megbrain/imperative/ops/broadcast.h
new file mode 100644
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/ops/broadcast.h
@@ -0,0 +1,35 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/broadcast.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/op_def.h"
+
+#include "megbrain/opr/tensor_manip.h"
+
+namespace mgb::imperative {
+
+class Broadcast : public OpDefImplBase<Broadcast> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+public:
+    Broadcast() = default;
+
+    size_t hash() const override {
+        return reinterpret_cast<size_t>(dyn_typeinfo());
+    }
+
+    bool is_same_st(const Hashable& rhs) const override {
+        return true;
+    }
+
+};
+
+} // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/ops/nms.h b/imperative/src/include/megbrain/imperative/ops/nms.h
index ad1bd96bec474373fd0bae92d5ac6aea2e0e214a..ed66cd8e62275cc71f9bc0c1ae25eab7b814153e 100644
--- a/imperative/src/include/megbrain/imperative/ops/nms.h
+++ b/imperative/src/include/megbrain/imperative/ops/nms.h
@@ -32,8 +32,7 @@ public:
 
     bool is_same_st(const Hashable& rhs_) const override {
         auto&& rhs = static_cast<const NMSKeep&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.iou_thresh == iou_thresh
+        return rhs.iou_thresh == iou_thresh
             && rhs.max_output == max_output;
     }
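
--
Note on the behavior change: the Python-side valid_broadcast check in
tensor_wrapper.py is replaced by the mgb_assert(valid_broadcast(...)) inside
the new C++ infer_output_attrs_fallible, so an illegal target shape now
surfaces in Python as a RuntimeError from MegBrain rather than the old
ValueError, which is what the updated tests assert. Gradient handling is
unchanged: Broadcast is routed to elemwise_add_grad_fn through the new
isinstance check instead of the removed _oprAttr_grad_fn entry. A minimal
sketch of the user-visible effect, assuming only the F.ones/F.broadcast_to
API that appears in the tests above:

    import megengine.functional as F

    x = F.ones((2, 1, 3))

    # Valid: shapes are matched from the right, and the size-1 axis expands.
    y = F.broadcast_to(x, (2, 4, 3))
    print(y.shape)  # (2, 4, 3)

    # Invalid: the trailing dim 3 cannot become 4, so the shape check in
    # infer_output_attrs_fallible rejects it.
    try:
        F.broadcast_to(x, (2, 3, 4))
    except RuntimeError:  # was ValueError before this change
        print("broadcast rejected")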