Commit 12178011 authored by Megvii Engine Team

perf(mge): add opdef for broadcast

GitOrigin-RevId: 92f0af29eb000b3e37f059e83fda52d26f21b383
Parent fccb2510
@@ -413,7 +413,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         megdnn_throw_if(
                 cur_shape != 1 && cur_stride != 0, tensor_reshape_error,
                 megdnn_mangle(ssprintf(
-                        "brodcast on dim with shape not equal to 1: "
+                        "broadcast on dim with shape not equal to 1: "
                         "src_shape=%s dst_shape=%s",
                         to_string().c_str(), tshape.to_string().c_str())));
         result.shape[target_idx] = tshape.shape[target_idx];
@@ -47,7 +47,9 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
             grad_fn = reduce_sum_grad_fn
         else:
             grad_fn = default_grad_fn
-    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+    elif isinstance(op, Broadcast) or (
+        isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
+    ):
         grad_fn = elemwise_add_grad_fn
     else:
         grad_fn = default_grad_fn
@@ -212,5 +214,4 @@ _oprAttr_grad_fn = {
     Reshape.name: reshape_grad_fn,
     Subtensor.name: subtensor_grad_fn,
     IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
-    Broadcast.name: elemwise_add_grad_fn,
 }
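Reusing elemwise_add_grad_fn for Broadcast is sound because broadcasting shares its backward rule with the implicit broadcast in elementwise addition: the output gradient is reduce-summed over the broadcast axes back to the input shape. A minimal NumPy sketch of that rule (illustration only, not the actual grad_fn):

    import numpy as np

    def broadcast_grad(dy, src_shape):
        # Sum over the axes prepended by the broadcast, then over the
        # axes that were expanded from size 1, restoring src_shape.
        ndim_diff = dy.ndim - len(src_shape)
        if ndim_diff:
            dy = dy.sum(axis=tuple(range(ndim_diff)))
        axes = tuple(i for i, s in enumerate(src_shape)
                     if s == 1 and dy.shape[i] != 1)
        return dy.sum(axis=axes, keepdims=True)

    x = np.ones((2, 1, 3))
    dy = np.ones((4, 2, 5, 3))  # gradient w.r.t. broadcast_to(x, (4, 2, 5, 3))
    assert broadcast_grad(dy, x.shape).shape == x.shape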
@@ -59,29 +59,7 @@ def _transpose(data, axes):
 def _broadcast(inp, shape):
-    def valid_broadcast(src, tar):
-        def failed():
-            raise ValueError(
-                "the input shape {} can not be broadcasted to target shape {}".format(
-                    src, tar
-                )
-            )
-
-        if isinstance(src, (TensorBase, TensorWrapperBase)):
-            src = src.numpy()
-        if isinstance(tar, (TensorBase, TensorWrapperBase)):
-            tar = tar.numpy()
-        if len(src) > len(tar):
-            failed()
-        for i in range(min(len(src), len(tar))):
-            if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
-                failed()
-
     shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
-    valid_broadcast(inp.shape, shape)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
@@ -21,6 +21,7 @@
 #include "megbrain/imperative/ops/nms.h"
 #include "megbrain/imperative/ops/elemwise.h"
 #include "megbrain/imperative/ops/batch_norm.h"
+#include "megbrain/imperative/ops/broadcast.h"

 namespace py = pybind11;
@@ -206,4 +207,7 @@ void init_ops(py::module m) {
     V(INFERENCE);
 #undef V

+    py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
+        .def(py::init<>());
 }
@@ -262,13 +262,13 @@ def test_broadcast():
     opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

     x = F.ones((2, 1, 3))
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (2, 3, 4))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (4, 1, 3))

-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         F.broadcast_to(x, (1, 3))
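For contrast, a target shape that only expands size-1 axes still succeeds; a hedged example (not part of this diff):

    y = F.broadcast_to(x, (2, 3, 3))    # the middle size-1 axis expands to 3
    assert y.numpy().shape == (2, 3, 3)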
/**
* \file imperative/src/impl/ops/broadcast.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/broadcast.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
node_->cast_final_safe<opr::Broadcast>();
return Broadcast::make();
}
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
}
bool valid_broadcast(const TensorShape& src_shape,
const TensorShape& tar_shape) {
size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
if (src_ndim > tar_ndim) {
return false;
}
size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
for (size_t i = 0; i < min_ndim; ++i) {
if (src_shape[src_ndim - i - 1] != 1 &&
src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
return false;
}
}
return true;
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp = inputs[1];
TensorLayout out_layout = src.layout;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{out_layout, src.comp_node}};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for(size_t i=0; i<target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
}
mgb_assert(valid_broadcast(src.layout, out_layout),
"the input shape %s can not be broadcasted to target shape %s",
src.layout.TensorShape::to_string().c_str(),
out_layout.TensorShape::to_string().c_str());
return {{out_layout, src.comp_node}};
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
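A rough Python sketch of what infer_output_attrs_fallible computes (plain tuples stand in for layouts; names are illustrative, not the imperative runtime API): if the target shape tensor's value is not statically known, the output is reported with unknown shape and resolved at runtime; otherwise the target shape is validated with the right-aligned broadcast rule and becomes the output shape.

    def valid_broadcast(src, tar):
        # Right-aligned NumPy-style rule: every source dim must equal the
        # corresponding target dim or be 1; src may not have more dims.
        if len(src) > len(tar):
            return False
        for i in range(len(src)):
            if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
                return False
        return True

    def infer_broadcast_shape(src_shape, target_shape_value):
        # target_shape_value is None when the shape tensor's value is not
        # statically known -- the case where the C++ code above reports
        # ndim=0 and defers the real shape to runtime.
        if target_shape_value is None:
            return None
        assert valid_broadcast(src_shape, target_shape_value), (
            "the input shape {} can not be broadcasted to "
            "target shape {}".format(src_shape, target_shape_value))
        return tuple(target_shape_value)

    assert infer_broadcast_shape((2, 1, 3), (2, 3, 3)) == (2, 3, 3)  # 1 -> 3 ok
    assert infer_broadcast_shape((2, 1, 3), None) is None            # unknown yet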
/**
* \file imperative/src/include/megbrain/imperative/ops/broadcast.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

class Broadcast : public OpDefImplBase<Broadcast> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    Broadcast() = default;

    size_t hash() const override {
        // Broadcast carries no parameters, so the type info pointer
        // alone identifies the op.
        return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
    }

    bool is_same_st(const Hashable& rhs) const override {
        // all Broadcast instances are interchangeable
        return true;
    }
};

} // namespace mgb::imperative
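One design note on the class above: Broadcast has no attributes, so any two instances denote the same op. hash() therefore keys on the dynamic type info alone, and is_same_st() can unconditionally return true, since the framework apparently compares dynamic types before calling it, as the NMSKeep cleanup below suggests. A small Python analogy of this identity-by-type pattern (the class here is hypothetical, not the real binding):

    class BroadcastLike:
        """Hypothetical stand-in: a parameter-free op definition."""

        def __hash__(self):
            # Key on the type alone, mirroring hash() returning the
            # dyn_typeinfo() pointer in the C++ class above.
            return hash(type(self))

        def __eq__(self, other):
            # Mirrors is_same_st(): once types match, instances are equal.
            return type(self) is type(other)

    assert BroadcastLike() == BroadcastLike()
    assert hash(BroadcastLike()) == hash(BroadcastLike())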
@@ -32,8 +32,7 @@ public:
     bool is_same_st(const Hashable& rhs_) const override {
         auto&& rhs = static_cast<const NMSKeep&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-               && rhs.iou_thresh == iou_thresh
+        return rhs.iou_thresh == iou_thresh
                && rhs.max_output == max_output;
     }