提交 6fe6df28 编写于 作者: M Megvii Engine Team

fix(opr/basic_arith): make Elemwise really support empty IO

... and add tests, should've done this the first time...

GitOrigin-RevId: 5d73fc4c7bf1fefe79c2cae653068938c0c8cb40
上级 38a95744
......@@ -174,11 +174,12 @@ void ElemwiseForward::deduce_shape(const TensorShapeArray& src,
if (cur_idx >= 0 && dst_idx >= 0) {
size_t v0 = dst.shape[dst_idx], v1 = cur.shape[cur_idx];
if (v0 != v1) {
if (v0 != 1 && v1 != 1)
if (v0 > 1 && v1 > 1)
err();
}
int final_idx = std::max(cur_idx, dst_idx);
dst.shape[final_idx] = std::max(v0, v1);
dst.shape[final_idx] =
(v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0;
} else {
if (dst_idx < 0) {
dst.shape[cur_idx] = cur.shape[cur_idx];
......
......@@ -132,6 +132,7 @@ Elemwise::Elemwise(
Super{inputs.at(0)->owner_graph(), config, mode_trait.name, inputs}
{
init_megdnn_opr(*this, param);
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
if (mode_trait.commutable) {
mgb_assert(inputs.size() == 2);
add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED);
......@@ -371,8 +372,14 @@ void Elemwise::scn_do_execute() {
mgb_assert(megdnn_inp.capacity() >= inp.size(),
"heap allocation in elemwise exec");
megdnn_inp.resize(inp.size());
for (size_t i = 0; i < inp.size(); ++ i)
for (size_t i = 0; i < inp.size(); ++ i) {
if (inp[i]->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty());
return;
}
megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn());
}
mgb_assert(!output(0)->dev_tensor().empty());
megdnn_opr()->param() = param();
call_megdnn_opr_exec(
......@@ -747,6 +754,15 @@ void Elemwise::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
Elemwise::NodeProp* Elemwise::do_make_node_prop() const {
    // Mark every input as VALUE_ALLOW_EMPTY so empty (zero-size) tensors are
    // accepted as inputs; scn_do_execute() short-circuits in that case
    // instead of invoking the megdnn kernel.
    auto prop = Super::do_make_node_prop();
    for (auto&& in_var : input()) {
        prop->add_dep_type_existing_var(
                in_var, NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
/* =========================== TypeCvt =========================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TypeCvt);
......
......@@ -163,6 +163,7 @@ MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // {
void record_execute_deps(ExecDependencyArray& deps) override;
void add_input_layout_constraint() override;
NodeProp* do_make_node_prop() const override;
};
namespace intl {
......
......@@ -649,6 +649,17 @@ namespace {
> TernaryTraitTypes;
TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes);
//! Predicate formatter comparing two shapes; the result message shows both
//! shapes and the relation that holds between them.
::testing::AssertionResult assert_shape_equal(const TensorShape& v0,
                                              const TensorShape& v1) {
    const bool same = v0.eq_shape(v1);
    auto result = same ? ::testing::AssertionSuccess()
                       : ::testing::AssertionFailure();
    return result << v0.to_string() << (same ? " == " : " != ")
                  << v1.to_string();
}
#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1))
} // anonymous namespace
template<typename Trait, typename dtype>
......@@ -950,4 +961,58 @@ TEST(TestLayoutUtil, CollectiveCollapse) {
check(cc_res5, std_res5);
}
TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
    // A unary elemwise op (RELU) on an empty tensor must execute without
    // error and produce an empty output carrying the input's shape.
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto inp_host = gen({3, 0, 1, 3});
    auto inp = opr::Host2DeviceCopy::make(*graph, inp_host);
    auto out = opr::Elemwise::make(
            {inp}, opr::Elemwise::Param(opr::Elemwise::Param::Mode::RELU));
    HostTensorND out_host;
    auto func = graph->compile({make_callback_copy(out, out_host)});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(out_host.empty());
    ASSERT_TRUE(out_host.shape().is_empty());
    ASSERT_SHAPE_EQ(out_host.shape(), TensorShape({3, 0, 1, 3}));
}
TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
    // Binary elemwise (ADD) with empty operands: broadcast rules are still
    // enforced, and any zero-length axis broadcasts to an empty output.
    HostTensorGenerator<> gen;
    auto host_a = gen({0, 8, 1, 7});
    auto host_b = gen({0, 8, 1, 7});
    auto graph = ComputingGraph::make();
    auto a = opr::Host2DeviceCopy::make(*graph, host_a);
    auto b = opr::Host2DeviceCopy::make(*graph, host_b);
    auto sum = a + b;
    HostTensorND host_sum;
    auto func = graph->compile({make_callback_copy(sum, host_sum)});

    // incompatible non-unit/non-zero axes must still be rejected
    host_b->resize({0, 9, 1, 7});
    ASSERT_ANY_THROW(func->execute().wait());

    // a zero-length axis broadcasts against a unit axis, giving length 0
    host_b->resize({1, 8, 0, 7});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), TensorShape({0, 8, 0, 7}));

    // a zero-length axis also wins against a longer (non-unit) axis
    host_b->resize({2, 8, 1, 7});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), TensorShape({0, 8, 1, 7}));

    // broadcasting a scalar onto an empty tensor keeps the empty shape
    sum = a + a.make_scalar(1.f);
    func = graph->compile({make_callback_copy(sum, host_sum)});
    ASSERT_NO_THROW(func->execute().wait());
    ASSERT_TRUE(host_sum.empty());
    ASSERT_TRUE(host_sum.shape().is_empty());
    ASSERT_SHAPE_EQ(host_sum.shape(), TensorShape({0, 8, 1, 7}));
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册