提交 7ba641fe 编写于 作者: M Megvii Engine Team

fix(mgb/core): fix cond op with empty shape

GitOrigin-RevId: 1953d4cd2150a6ae537501d346f4ff710fcfcef9
上级 6fe6df28
......@@ -699,7 +699,9 @@ CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs,
for (size_t i = 0; i < inputs.size(); ++i) {
add_input({inputs[i]});
add_output(ssprintf("fwd%zu", i))->dtype(inputs[i]->dtype());
add_output(ssprintf("fwd%zu", i))
->dtype(inputs[i]->dtype())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
add_input({ppv});
add_equivalence_component<PODHash<Param>>(&m_param);
......@@ -789,6 +791,10 @@ void CondExecMark::add_input_layout_constraint() {
CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
for (size_t i = 0; i < input().size() - 1; ++ i) {
ret->add_dep_type_existing_var(input(i),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return ret;
}
......@@ -859,7 +865,8 @@ CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs,
// 2. dynamic allocator would wait for all inputs to become ready (see
// VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready),
// which would cause infinite waiting for unselected inputs.
ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
MGB_MARK_USED_VAR(mask2str);
......@@ -1056,7 +1063,9 @@ void CondExecMerge::init_output_static_infer_desc() {
desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
auto nr_branch = m_branch_masks.size();
bool found = false, first = true;
for (size_t i = 0; i < nr_branch; ++i) {
auto&& shape = inp.val.at(nr_branch).shape();
for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) {
if (!inp.val[i].value().ptr<int>()[0])
continue;
auto&& cur = inp.val.at(nr_branch + i).value();
......@@ -1083,7 +1092,6 @@ void CondExecMerge::init_output_static_infer_desc() {
}
}
if (!found) {
auto&& shape = inp.val.at(nr_branch).shape();
if (dest.storage().raw_storage().use_count() > 1) {
// likely to be assigned from some input in previous
// runs; we create a new tensor to avoid modifying input
......@@ -1115,6 +1123,7 @@ void CondExecMerge::scn_do_execute() {
bool first = true;
auto&& forwarded = m_mem_forwarded;
std::vector<bool> is_shape_empty(nr_out, false);
for (size_t br = 0; br < m_branch_masks.size(); ++br) {
if (!m_branch_masks[br]->enabled()) {
continue;
......@@ -1125,6 +1134,10 @@ void CondExecMerge::scn_do_execute() {
for (size_t oidx = 0; oidx < nr_out; ++oidx) {
bool succ = output(oidx)->reset_dev_tensor_from_other_var(
inp(br, oidx));
if (inp(br, oidx)->shape().is_empty()) {
is_shape_empty[oidx] = true;
continue;
}
if (!is_exact_one()) {
if (forwarded.empty()) {
forwarded.resize(nr_out);
......@@ -1144,6 +1157,11 @@ void CondExecMerge::scn_do_execute() {
auto ovar = output(oidx);
auto&& src = inp(br, oidx)->dev_tensor().as_megdnn();
auto&& dest = ovar->dev_tensor().as_megdnn();
mgb_assert(src.layout.eq_shape(dest.layout),
"shape mismatch: %s vs %s in CondExecMerge",
src.layout.to_string().c_str(),
dest.layout.to_string().c_str());
if (is_shape_empty[oidx]) continue;
if (forwarded[oidx]) {
ovar->shape_alloc(ovar->shape());
auto&& own_dest = ovar->dev_tensor().as_megdnn();
......@@ -1200,6 +1218,10 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
// directly
ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
}
for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++ i) {
ret->add_dep_type_existing_var(input(i),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return ret;
}
......
......@@ -14,6 +14,7 @@
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/utils/timer.h"
......@@ -1285,6 +1286,80 @@ TEST(TestCondExec, MultiShape) {
check(host_d2);
}
TEST(TestCondExec, EmptyShape) {
HostTensorGenerator<> gen;
auto host_pred = gen({1});
host_pred->ptr<float>()[0] = 0;
static auto empty_in_empty_out = [](SymbolVar x) {
return x;
};
static auto empty_in_scalar_out = [](SymbolVar x) {
return opr::Concat::make({x, x.make_scalar(1.f)}, 0);
};
static auto scalar_in_empty_out = [](SymbolVar x) {
return opr::CondTake::make(x, x, {})[0]; // whether eq 0
};
{ // EXACT_ONE
auto graph = ComputingGraph::make();
auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
empty = opr::ImmutableTensor::make(*graph, *gen({0})),
scalar = pred.make_scalar(1.f),
y0 = empty_in_empty_out(make_one_cond(pred + 1, empty)),
y1 = empty_in_scalar_out(make_one_cond(pred, empty)),
y2 = scalar_in_empty_out(make_one_cond(pred - 1, scalar)),
z = merge_one_out({y0, y1, y2}, MergeMode::EXACT_ONE);
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
func->execute();
ASSERT_TRUE(host_z.layout().is_empty());
host_pred->ptr<float>()[0] = 1;
func->execute();
ASSERT_EQ(1.f, host_z.ptr<float>()[0]);
host_pred->ptr<float>()[0] = 2;
func->execute();
ASSERT_TRUE(host_z.layout().is_empty());
}
{ // SUM
auto graph = ComputingGraph::make();
host_pred->ptr<float>()[0] = 1;
auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
empty = opr::ImmutableTensor::make(*graph, *gen({0})),
scalar = pred.make_scalar(1.f),
y0 = empty_in_empty_out(make_one_cond(pred, empty)),
y1 = scalar_in_empty_out(make_one_cond(pred, scalar)),
z = merge_one_out({y0, y1}, MergeMode::SUM);
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
func->execute();
ASSERT_TRUE(host_z.layout().is_empty());
}
{ // TAKE GRAD
auto graph = ComputingGraph::make();
host_pred->ptr<float>()[0] = 0;
auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
x = pred.make_scalar(1.2f),
y0 = opr::CondTake::make(make_one_cond(pred + 1, x), pred, {})[0],
y1 = make_one_cond(pred, x.make_scalar(3.4f)),
z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE),
g = cg::grad(z, x);
HostTensorND host_z, host_g;
auto func = graph->compile({
make_callback_copy(z, host_z), make_callback_copy(g, host_g)});
func->execute();
ASSERT_EQ(1.2f, host_z.ptr<float>()[0]);
ASSERT_EQ(1.f, host_g.ptr<float>()[0]);
host_pred->ptr<float>()[0] = 1;
func->execute();
ASSERT_EQ(3.4f, host_z.ptr<float>()[0]);
ASSERT_EQ(0.f, host_g.ptr<float>()[0]);
}
}
#endif // MGB_ENABLE_COND_EXEC
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册