Commit 95eb6ae3 authored by Megvii Engine Team

feat(mgb/opr): let more ops support empty IO

GitOrigin-RevId: 84dddb4b23638b29950e438bba2af8b5fd5166fa
Parent 296a2885
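
In practice this means operators such as Reshape, Broadcast, Dimshuffle and AxisAddRemove can now take and produce tensors with a zero-sized dimension. A minimal usage sketch, mirroring the tests added at the end of this diff (the shapes here are illustrative only):

    HostTensorGenerator<> gen;
    auto host_x = gen({1, 2, 3});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Broadcast::make(x, TensorShape{0, 2, 3});  // zero-sized dim is now legal
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    // host_y ends up with shape {0, 2, 3} and no elements
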
......@@ -392,8 +392,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
TensorLayout result{dtype, format};
result.ndim = tshape.ndim;
for (size_t i = 0; i < tshape.ndim; i++) {
megdnn_throw_if(!tshape.shape[i], tensor_reshape_error,
megdnn_mangle("target shape is 0"));
result.shape[i] = tshape.shape[i];
result.stride[i] = (tshape.shape[i] == 1);
}
......@@ -409,8 +407,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
for (size_t i = 0; i < tshape.ndim; ++i) {
int target_idx = tshape.ndim - i - 1;
int cur_idx = ndim - i - 1;
megdnn_throw_if(!tshape.shape[target_idx], tensor_reshape_error,
megdnn_mangle("target shape is 0"));
size_t cur_shape = (cur_idx >= 0 ? shape[cur_idx] : 1),
cur_stride = (cur_idx >= 0 ? stride[cur_idx] : 0);
if (tshape.shape[target_idx] != cur_shape) {
......@@ -434,10 +430,16 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
bool TensorLayout::try_reshape(TensorLayout& result,
const TensorShape& tshp) const {
megdnn_assert(tshp.ndim);
bool is_empty_shape = false;
for (size_t i = 0; i < tshp.ndim; ++i) {
megdnn_throw_if(!tshp.shape[i], tensor_reshape_error,
megdnn_mangle(ssprintf("bad target tshp: %s",
tshp.to_string().c_str())));
if (!tshp.shape[i]) {
megdnn_throw_if(!format.is_default(), tensor_reshape_error,
megdnn_mangle(ssprintf("bad target tshp: %s",
tshp.to_string().c_str())));
is_empty_shape = true;
break;
}
}
megdnn_throw_if(
......@@ -454,6 +456,11 @@ bool TensorLayout::try_reshape(TensorLayout& result,
result.format = this->format;
result.TensorShape::operator=(tshp);
if (is_empty_shape) {
result.init_contiguous_stride();
return true;
}
size_t sdim = 0, prod = 1, cont_sdim = 0;
for (size_t i = 0; i < tshp.ndim; ++i) {
megdnn_assert(cont_sdim < cont.ndim);
......
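
The hunk above adds an early-exit path to TensorLayout::try_reshape: if any dimension of the target shape is 0 (permitted only for the default format), the reshape trivially succeeds and the result simply gets contiguous strides. A minimal sketch of the resulting behaviour, assuming the public megdnn TensorLayout API and a default-format layout:

    TensorLayout src({2, 0, 3}, dtype::Float32());  // empty source: 0 elements
    TensorLayout dst;
    bool ok = src.try_reshape(dst, {0, 6});          // target shape contains a 0
    // ok == true; dst has shape {0, 6} and contiguous strides {6, 1}
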
......@@ -237,7 +237,8 @@ void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode *inp, VarNode *tshp) {
add_input({inp, tshp});
add_output(None)->dtype(inp->dtype());
add_output(None)->dtype(inp->dtype())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
if (reshapebrdcast_output_shape_need_input_shape())
outshape_by_symvar_enable(1, 1);
else
......@@ -340,6 +341,14 @@ void ReshapeBrdcastHelper::init_output_static_infer_desc() {
infer_value});
}
ReshapeBrdcastHelper::NodeProp*
ReshapeBrdcastHelper::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
// f}}}
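
The two flags used above work as a pair: ALLOW_EMPTY_SHAPE on the output var allows the output to be allocated with zero elements, while VALUE_ALLOW_EMPTY on the input dependency lets the operator execute even when that input's computed value is an empty tensor. For reference, the same pattern written out for a hypothetical operator Foo (illustrative only, not part of this commit):

    // constructor: the output may legally end up with an empty shape
    add_output(None)->dtype(inp->dtype())
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

    // node properties: the executor may feed an empty value for input(0)
    Foo::NodeProp* Foo::do_make_node_prop() const {
        auto ret = Super::do_make_node_prop();
        ret->add_dep_type_existing_var(input(0),
                NodeProp::DepType::VALUE_ALLOW_EMPTY);
        return ret;
    }
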
/* f{{{ ======================= Reshape ======================= */
......@@ -394,7 +403,7 @@ Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
}
auto tot_nr_elem = src.total_nr_elems();
actual_tshape.shape[unspec] = 0;
mgb_throw_if(tot_nr_elem % rem_nr_elem, TensorReshapeError,
mgb_throw_if(!rem_nr_elem || tot_nr_elem % rem_nr_elem, TensorReshapeError,
"could not reshape: src=%s tshape=%s unspec_axis=%zd",
static_cast<const TensorShape&>(src).to_string().c_str(),
actual_tshape.to_string().c_str(),
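
The added !rem_nr_elem check above matters once empty tensors are allowed: rem_nr_elem is the product of the explicitly specified target dimensions, so it can now be 0, and the old tot_nr_elem % rem_nr_elem would divide by zero. A worked example of the two situations (values only, not code from this commit):

    // empty source, valid target: src = {0, 8} -> tot_nr_elem = 0
    //   tshape = {2, -1, 4}      -> rem_nr_elem = 2 * 4 = 8
    //   0 % 8 == 0, unspecified axis resolves to 0 / 8 = 0 -> dest shape {2, 0, 4}
    // zero in a specified target dim: tshape = {0, -1} -> rem_nr_elem = 0
    //   now throws TensorReshapeError instead of dividing by zero
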
......@@ -484,6 +493,17 @@ void AxisManipOprBase::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}
AxisManipOprBase::NodeProp* AxisManipOprBase::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
void AxisManipOprBase::axis_manip_init(VarNode* inp) {
add_input({inp});
add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
// f}}}
......@@ -504,8 +524,7 @@ Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
mgb_throw_if(i < -1 || i >= int(ndim), GraphError,
"bad Dimshuffle pattern");
}
add_input({inp});
add_output(None);
axis_manip_init(inp);
add_equivalence_component<PODHash<int>>(m_pattern.data(), m_pattern.size());
}
......@@ -587,8 +606,7 @@ AxisAddRemove::AxisAddRemove(
{
mgb_throw_if(desc.empty(), GraphError,
"desc for AxisAddRemove could not be empty");
add_input({inp});
add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
axis_manip_init(inp);
add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(), m_desc.size());
}
......@@ -631,13 +649,6 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
return layout;
}
AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AxisAddRemove) {
MGB_MARK_USED_VAR(wrt_idx);
......
......@@ -92,6 +92,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ReshapeBrdcastHelper,
void scn_do_execute() override final;
void add_input_layout_constraint() override final;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
protected:
using Super::Super;
......@@ -199,11 +200,14 @@ MGB_DEFINE_CLS_WITH_SUPER(AxisManipOprBase,
void mem_plan_fwd_in2out_readonly() override final;
void scn_do_execute() override final;
void init_output_static_infer_desc() override final;
NodeProp* do_make_node_prop() const override;
protected:
using Super::Super;
virtual TensorLayout axis_manip_get_output_layout(
const TensorLayout &inp_layout) const = 0;
void axis_manip_init(VarNode* inp);
};
}
......@@ -319,8 +323,6 @@ MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {
TensorLayout axis_manip_get_output_layout(
const TensorLayout &inp_layout) const override;
NodeProp* do_make_node_prop() const override;
};
namespace intl {
......
......@@ -17,6 +17,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/misc.h"
#include "megbrain/utils/arith_helper.h"
using namespace mgb;
......@@ -138,7 +139,7 @@ TEST(TestTensorManip, Reshape) {
auto &&dep_map = opr0_reshp.node()->owner_opr()->node_prop().dep_map();
using DT = cg::OperatorNodeBase::NodeProp::DepType;
ASSERT_EQ(2u, dep_map.size());
ASSERT_EQ(DT::DEV_VALUE, dep_map.at(op->input(0)));
ASSERT_EQ(DT::DEV_VALUE | DT::VALUE_ALLOW_EMPTY, dep_map.at(op->input(0)));
ASSERT_EQ(DT::HOST_VALUE, dep_map.at(op->input(1)));
}
......@@ -318,6 +319,39 @@ TEST(TestTensorManip, ReshapeInferShapeForDynamicInput) {
run({23, 12, 5});
}
TEST(TestTensorManip, ReshapeEmptyShape) {
HostTensorGenerator<> gen;
constexpr size_t x_length = 233;
auto host_x = gen({x_length}),
host_v = gen({2, 3, 3, 3});
for (size_t i = 0; i < x_length; ++ i) {
host_x->ptr<float>()[i] = 1.f;
}
constexpr auto INVALID_AXIS = opr::Reshape::Param::INVALID_AXIS;
for (auto unspec_axis: {INVALID_AXIS, 0, 1, 3}) {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
TensorShape tshape{2, 3, 3, 3};
auto zero_axis = unspec_axis;
if (unspec_axis == INVALID_AXIS) {
tshape[zero_axis = 2] = 0;
}
using CondTakeMode = opr::CondTake::Param::Mode;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
x_empty = opr::CondTake::make(x, x, {CondTakeMode::EQ, 0.f})[0],
v = opr::Host2DeviceCopy::make(*graph, host_v),
x_reshape = opr::Reshape::make(x_empty, tshape, {unspec_axis}),
y = opr::Concat::make({x_reshape, v}, zero_axis);
HostTensorND host_empty, host_y;
auto func = graph->compile({
make_callback_copy(x_reshape, host_empty),
make_callback_copy(y, host_y)});
func->execute().wait();
ASSERT_TRUE(host_empty.layout().is_empty());
MGB_ASSERT_TENSOR_EQ(*host_v, host_y);
}
}
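
A note on the construction above: host_x is filled with 1.f, so CondTake with mode EQ and value 0.f keeps nothing and x_empty is an empty tensor. Reshape then maps it to a target shape whose zero_axis is 0 (either given explicitly or resolved through the unspecified axis), and concatenating that empty result with v along zero_axis must reproduce host_v exactly.
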
TEST(TestTensorManip, ReshapeWithNegativeUnspec) {
HostTensorGenerator<> gen;
auto host_x = gen({4, 8});
......@@ -365,6 +399,26 @@ TEST(TestTensorManip, Broadcast) {
}
}
TEST(TestTensorManip, BroadcastEmptyShape) {
HostTensorGenerator<> gen;
for (auto&& arg:
{std::make_pair(TensorShape{1}, TensorShape{0}),
{{1, 2, 3}, {0, 2, 3}},
{{2, 3}, {1, 0, 2, 3}},
{{1, 0, 2, 3}, {4, 0, 2, 3}},
{{0, 1, 2, 3}, {3, 0, 4, 2, 3}}}) {
auto host_x = gen(arg.first);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Broadcast::make(x, arg.second);
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_TRUE(host_y.shape().eq_shape(arg.second));
}
}
TEST(TestTensorManip, Dimshuffle) {
HostTensorGenerator<> gen;
constexpr size_t S0 = 8, S1 = 3;
......@@ -395,6 +449,34 @@ TEST(TestTensorManip, Dimshuffle) {
}
}
TEST(TestTensorManip, DimshuffleEmptyShape) {
HostTensorGenerator<> gen;
for (auto&& arg:
{std::make_pair(
TensorShape{3, 0},
std::vector<int>{1, -1, 0, -1}),
{{3, 1, 0, 4}, {-1, 3, -1, 0, 2}},
{{2, 0, 3, 0}, {1, 0, 2, 3}}}) {
auto host_x = gen(arg.first);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Dimshuffle::make(x, arg.second);
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
auto&& y_shape = host_y.shape();
for(size_t idx = 0; idx < arg.second.size(); ++ idx) {
auto elem = arg.second[idx];
if (elem == -1) {
ASSERT_EQ(y_shape[idx], 1u);
} else {
ASSERT_EQ(arg.first[elem], y_shape[idx]);
}
}
}
}
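
For the first case above this works out as: pattern {1, -1, 0, -1} applied to input shape {3, 0} yields output dims {input[1], 1, input[0], 1} = {0, 1, 3, 1}; a -1 entry inserts a new axis of size 1 and any other entry copies the corresponding input dimension, which is exactly what the verification loop checks.
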
TEST(TestTensorManip, DimshuffleCombined) {
using Checker = AutoOprChecker<1, 1>;
constexpr int RED0 = 2, RED1 = 3;
......