提交 b74afde8 编写于 作者: M Megvii Engine Team

feat(mgb/opr): let reduce support empty IO

GitOrigin-RevId: 88b37123a8fa7f7dafbb1b0c506fb79f1e5a24c4
上级 1af350c6
...@@ -13,7 +13,7 @@ import pytest ...@@ -13,7 +13,7 @@ import pytest
from utils import opr_test from utils import opr_test
import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import jit, tensor
def common_test_reduce(opr, ref_opr): def common_test_reduce(opr, ref_opr):
...@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only): ...@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only):
if not sorted: if not sorted:
values = np_sort(values) values = np_sort(values)
np.testing.assert_equal(values, np_sort(data)[..., :k]) np.testing.assert_equal(values, np_sort(data)[..., :k])
@pytest.mark.parametrize("is_trace", [True, False])
def test_reduce_on_empty_tensor(is_trace):
dtypes = [np.float32, np.int32, np.bool]
inputs = [
(np.random.random((0,)), None),
(np.random.random((3, 0, 2)), 1),
(np.random.random((10, 10, 0, 10)), 0),
]
def run_test(fn, ref_fn, input, dtype, axis=None, symbolic=False):
if is_trace:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
out_ref = ref_fn(input.astype(dtype), axis=axis)
np.testing.assert_equal(out, out_ref)
for dtype in dtypes:
for inp, axis in inputs:
run_test(F.sum, np.sum, inp, dtype, axis, True)
run_test(F.sum, np.sum, inp, dtype, axis, False)
run_test(F.prod, np.prod, inp, dtype, axis, True)
run_test(F.prod, np.prod, inp, dtype, axis, False)
...@@ -84,6 +84,11 @@ public: ...@@ -84,6 +84,11 @@ public:
auto&& dev_tensor = tensor.dev_tensor(); auto&& dev_tensor = tensor.dev_tensor();
var->m_comp_node = dev_tensor.comp_node(); var->m_comp_node = dev_tensor.comp_node();
var->m_shape = dev_tensor.shape(); var->m_shape = dev_tensor.shape();
if (dev_tensor.empty()) {
auto layout = dev_tensor.layout();
layout.init_contiguous_stride();
dev_tensor.reset(dev_tensor.storage(), layout);
}
var->m_dev_tensor = dev_tensor; var->m_dev_tensor = dev_tensor;
var->m_mem_plan.reset_from_owner_var().chunk() var->m_mem_plan.reset_from_owner_var().chunk()
.mem_alloc_status.set_from_owner_var(); .mem_alloc_status.set_from_owner_var();
......
...@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) { ...@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) {
using Param = opr::CondTake::Param; using Param = opr::CondTake::Param;
auto x = opr::Host2DeviceCopy::make(*graph, host_x), auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::CondTake::make(x, x, {Param::Mode::GT})[0], y = opr::CondTake::make(x, x, {Param::Mode::GT})[0],
z = opr::reduce_sum(y, y.make_scalar(1)); z = opr::reduce_max(y, y.make_scalar(1));
HostTensorND host_z; HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)}); auto func = graph->compile({make_callback_copy(z, host_z)});
func->execute(); func->execute();
...@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) { ...@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) {
func->execute(); func->execute();
} catch (const MegBrainError& exc) { } catch (const MegBrainError& exc) {
std::string msg{exc.what()}; std::string msg{exc.what()};
ASSERT_TRUE(msg.find("empty output var") != ASSERT_TRUE(msg.find("empty input is not allowed") !=
std::string::npos) std::string::npos)
<< "bad message " << msg; << "bad message " << msg;
throw; throw;
...@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) { ...@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) {
y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0}); y = opr::Reduce::make(x, {opr::Reduce::Mode::MAX, 0});
HostTensorND host_y; HostTensorND host_y;
auto func = g->compile({make_callback_copy(y, host_y)}); auto func = g->compile({make_callback_copy(y, host_y)});
auto &&recv = x.node()->owner_graph()->var_receiver_in_current_comp_seq(x.node());
ASSERT_TRUE(!recv.is_empty_allowed());
if (inp_shp.is_empty()) { if (inp_shp.is_empty()) {
ASSERT_ANY_THROW(func->execute().wait()); ASSERT_ANY_THROW(func->execute().wait());
} else { } else {
......
...@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler { ...@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler {
m_apply_side_effect; m_apply_side_effect;
std::unique_ptr<megdnn::Elemwise> m_elemwise_trans_opr; std::unique_ptr<megdnn::Elemwise> m_elemwise_trans_opr;
std::unique_ptr<megdnn::TypeCvt> m_typecvt_opr; std::unique_ptr<megdnn::TypeCvt> m_typecvt_opr;
std::unique_ptr<megdnn::Fill> m_fill_opr;
DeviceTensorND m_side_affect_wkspc; DeviceTensorND m_side_affect_wkspc;
}; };
...@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute( ...@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute(
} }
mgb_assert(!m_kern_param.empty()); mgb_assert(!m_kern_param.empty());
// empty input
if (input.shape_valid() && input.empty()) {
auto mode = m_kern_param[0].kparam.mode;
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(dest.comp_node())->
create_operator<megdnn::Fill>();
}
std::string err_msg;
switch (mode) {
case Reduce::Mode::SUM:
if (!dest.empty()) {
m_fill_opr->param() = 0;
m_fill_opr->exec(dest.as_megdnn(), {});
}
break;
case Reduce::Mode::PRODUCT:
if (!dest.empty()) {
m_fill_opr->param() = 1;
m_fill_opr->exec(dest.as_megdnn(), {});
}
break;
case Reduce::Mode::MEAN:
err_msg = "mean"; break;
case Reduce::Mode::MIN:
err_msg = "min"; break;
case Reduce::Mode::MAX:
err_msg = "max"; break;
case Reduce::Mode::SUM_SQR:
err_msg = "sum_sqr"; break;
default:
mgb_throw(MegBrainError, "bad reduce mode");
}
if (!err_msg.empty()) {
mgb_throw(
MegBrainError,
"empty input is not allowed for reduce mode: %s",
err_msg.c_str());
}
return;
}
mgb_assert(input.layout().is_contiguous() && mgb_assert(input.layout().is_contiguous() &&
input.raw_ptr() == m_kern_param[0].input.raw_ptr && input.raw_ptr() == m_kern_param[0].input.raw_ptr &&
dest.raw_ptr() == m_kern_param.back().output.raw_ptr); dest.raw_ptr() == m_kern_param.back().output.raw_ptr);
...@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param &param, ...@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param &param,
mgb_throw(GraphError, "invalid param data_type: %d", mgb_throw(GraphError, "invalid param data_type: %d",
int(param.data_type)); int(param.data_type));
} }
add_output(None)->dtype(out_dtype); add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.dtype(out_dtype);
cg::add_workspace_output(this); cg::add_workspace_output(this);
add_equivalence_component<PODHash<Param>>(&m_param); add_equivalence_component<PODHash<Param>>(&m_param);
...@@ -1703,6 +1747,13 @@ void Reduce::perform( ...@@ -1703,6 +1747,13 @@ void Reduce::perform(
ksched.execute(opr.get(), *input_contig, dest); ksched.execute(opr.get(), *input_contig, dest);
} }
Reduce::NodeProp* Reduce::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
void Reduce::create_megdnn_opr() { void Reduce::create_megdnn_opr() {
set_megdnn_opr(intl::get_megdnn_handle(comp_node())-> set_megdnn_opr(intl::get_megdnn_handle(comp_node())->
create_operator<megdnn::Reduce>()); create_operator<megdnn::Reduce>());
......
...@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic< ...@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
void add_input_layout_constraint() override final; void add_input_layout_constraint() override final;
void scn_do_execute() override final; void scn_do_execute() override final;
void init_output_static_infer_desc() override final; void init_output_static_infer_desc() override final;
NodeProp* do_make_node_prop() const override;
void create_megdnn_opr() override; void create_megdnn_opr() override;
void record_execute_deps(ExecDependencyArray& deps) override; void record_execute_deps(ExecDependencyArray& deps) override;
......
...@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) { ...@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) {
run_test(F16, F16, ParamType::FLOAT_O16xC32); run_test(F16, F16, ParamType::FLOAT_O16xC32);
} }
TEST(TestBasicArithReduction, EmptyInput) {
using Param = opr::Reduce::Param;
using Mode = opr::Reduce::Mode;
auto check_allow_empty = [](const Param& param, const TensorShape& inpshp, double target_val) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen(inpshp);
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Reduce::make(x, param, {});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute().wait();
if (!host_y.shape().is_empty()) {
size_t size = host_y.layout().total_nr_elems();
#define cb(DType) \
if (host_y.layout().dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
auto ptr = host_y.ptr<ctype>(); \
ctype target = static_cast<ctype>(target_val); \
for (size_t i = 0; i < size; ++i) { \
ASSERT_TRUE(ptr[i] == target); \
} \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
} else {
ASSERT_TRUE(host_y.empty());
}
};
auto check_forbid_empty = [](const Param& param, const TensorShape& inpshp) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen(inpshp);
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Reduce::make(x, param, {});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
ASSERT_ANY_THROW(func->execute().wait());
};
check_allow_empty({Mode::SUM, 0, {}}, {0}, 0);
check_allow_empty({Mode::SUM, -1, {}}, {2, 0, 3}, 0);
check_allow_empty({Mode::SUM, 1, {}}, {2, 0, 3}, 0);
check_allow_empty({Mode::PRODUCT, 0, {}}, {0, 1, 2}, 1);
check_allow_empty({Mode::PRODUCT, 1, {}}, {0, 0, 0}, 1);
check_allow_empty({Mode::PRODUCT, 2, {}}, {0, 0, 0}, 1);
check_forbid_empty({Mode::MAX, 0, {}}, {0});
check_forbid_empty({Mode::MIN, -1, {}}, {0, 1, 2});
check_forbid_empty({Mode::MEAN, 0, {}}, {0, 0});
check_forbid_empty({Mode::SUM_SQR, 1, {}}, {2, 1, 0});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册