提交 38b49272 编写于 作者: M Megvii Engine Team

fix(opr): fix pointer not being updated in the reduce operator when the input changes

GitOrigin-RevId: a443a79ac0a73cd984c163c56771dfc49f9cdebd
上级 4cce2480
......@@ -1648,6 +1648,9 @@ void Reduce::scn_do_execute() {
m_kern_scheduler->check_shapes(inp.shape(), out_ptr->shape());
if (m_kern_scheduler->has_actual_computing()) {
m_kern_scheduler->update_ptr(
inp, *out_ptr,
output(1)->shape()[0] ? output(1)->dev_tensor() : DeviceTensorND{});
m_kern_scheduler->execute(
static_cast<megdnn::Reduce*>(megdnn_opr()), inp, *out_ptr);
} else {
......
......@@ -416,6 +416,31 @@ TEST(TestBasicArithReduction, NonContFwd) {
}
}
//! Regression test: a compiled function containing a reduce operator must
//! keep working after the host input tensor's underlying storage is
//! replaced between executions (the device pointer must be refreshed).
TEST(TestBasicArithReduction, ResetMemory) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto host_x = gen({3, 2});

    // target shape tensor holding {1}: reduce everything to one element
    auto host_tshp =
            std::make_shared<HostTensorND>(host_x->comp_node(), dtype::Int32());
    host_tshp->resize({1});
    host_tshp->ptr<int>()[0] = 1;

    auto tshp = opr::Host2DeviceCopy::make(*graph, host_tshp, {"tshp"});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::reduce_max(x, tshp);

    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    func->wait();

    //! swap in fresh storage for host_x only (shape/layout unchanged) and
    //! verify a second execution still runs normally with the new memory
    auto fresh = gen({3, 2});
    host_x->reset(fresh->storage(), fresh->layout());
    func->execute();
    func->wait();
}
TEST(TestBasicArithReduction, NonContPerform) {
DeviceTensorND x{CompNode::default_cpu(), dtype::Float32()},
y{x.comp_node(), x.dtype()}, workspace;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册