Commit eb3c6473, authored by Megvii Engine Team

fix(jit): fix the JIT shallow copy error

GitOrigin-RevId: 513b7acf52e95e100583cb91a6fe04be09642a2f
Parent 99040dbe
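Summary (inferred from the diff below): the shallow-copy helper for JITExecutor previously remapped only the input vars that the copied operator consumes directly. When the executor's internal graph also referenced a sibling output of a multi-output producer, such as the shape-infer var taken from `LayerNormForward` in the new test, the dependency iterator could descend into that producer and duplicate it. The helper now identity-maps every output of each input's owner operator and marks that operator as visited before installing the actual replacement var.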
@@ -479,6 +479,8 @@ void CompiledTransformation::compile() {
        }
    }
    m_executable = m_graph->compile(output_specs);
+   mgb_assert(m_executable != nullptr, "The compiled executable is nullptr.");
    m_var_accessors = var_accessors;
    m_output_spec = output_specs;
}
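The added `mgb_assert` turns a failed `m_graph->compile(output_specs)` into an immediate, descriptive error instead of a null-pointer dereference the first time the executable is used.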
......
@@ -35,8 +35,12 @@ cg::OperatorNodeBase* opr_shallow_copy_jit_executor_opr(
    };
    cg::DepOprIter iter{on_opr};
    for (size_t i = 0; i < inputs.size(); ++i) {
-       var_replace_map[opr.input(i)] = inputs[i];
+       auto input_opr = opr.input(i)->owner_opr();
+       for (size_t j = 0; j < input_opr->output().size(); j++) {
+           var_replace_map[input_opr->output(j)] = input_opr->output(j);
+       }
        iter.set_visited(opr.input(i)->owner_opr());
+       var_replace_map[opr.input(i)] = inputs[i];
    }
    if (shape_infer) {
        iter.add(shape_infer);
......
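The ordering inside the loop carries the fix: first seed `var_replace_map` with identity entries for every output of the input's owner operator and mark that operator visited so `cg::DepOprIter` stops there, then override the one var actually consumed. A minimal, MegBrain-free sketch of this seeding pattern follows; the `Var`/`Opr` structs and all names are hypothetical stand-ins for illustration, not the real `cg::` API.

#include <cstdio>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Opr;
struct Var {
    Opr* owner;        // operator producing this var
    const char* name;
};
struct Opr {
    const char* name;
    std::vector<Var> outputs;
};

int main() {
    // A producer with two outputs; the copied operator consumes only out0,
    // but its internal graph may also reference the sibling out1.
    Opr producer{"layer_norm", {{nullptr, "out0"}, {nullptr, "out1"}}};
    for (auto& v : producer.outputs)
        v.owner = &producer;
    Var replacement{nullptr, "replacement_for_out0"};

    std::unordered_map<const Var*, const Var*> var_replace_map;
    std::unordered_set<const Opr*> visited;

    const Var* input = &producer.outputs[0];
    // 1. Identity-map every output of the producer, so lookups of sibling
    //    outputs (e.g. a shape-infer var) succeed instead of triggering a
    //    second, dangling copy of the producer.
    for (const auto& out : producer.outputs)
        var_replace_map[&out] = &out;
    // 2. Mark the producer visited so the dependency walk stops at it.
    visited.insert(input->owner);
    // 3. Only now install the real replacement; this overrides the identity
    //    entry written in step 1, which is why the order matters.
    var_replace_map[input] = &replacement;

    std::printf("out0 -> %s\n", var_replace_map[&producer.outputs[0]]->name);  // replacement_for_out0
    std::printf("out1 -> %s\n", var_replace_map[&producer.outputs[1]]->name);  // out1 (identity)
}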
@@ -6,7 +6,9 @@
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/test/helper.h"
#include "megdnn/dtype.h"
@@ -554,6 +556,61 @@ TEST(TestJITMlirDimshuffle, BasicGPU) {
#endif // MGB_JIT_MLIR
+TEST(TestJITExecutor, TestJITExecutorShallowCopy) {
+    REQUIRE_GPU(1);
+    set_backend(Backend::NVRTC);
+    auto cn = CompNode::load("gpu0");
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({1, 42}, cn);
+    auto a = opr::Host2DeviceCopy::make(*graph, host_x0);
+    using Param = opr::LayerNormForward::Param;
+    Param param;
+    param.eps = 1e-5;
+    param.affine = false;
+    param.normalized_dim = 1;
+    param.normalized_size = 42;
+    auto out_array = opr::LayerNormForward::make(a, param);
+    a = out_array[1];
+    auto shape = out_array[2];
+    a = opr::TypeCvt::make(a, dtype::Float16{});
+    auto y = a + 2;
+    y = opr::TypeCvt::make(y, dtype::Float16{});
+    y = opr::TypeCvt::make((y + y.make_scalar_dt(1.f)), dtype::Float32{});
+    auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+    for (auto i : get_rev_topo_order(y)) {
+        if (!(i->same_type<opr::Host2DeviceCopy>() ||
+              i->same_type<opr::LayerNormForward>() ||
+              i->same_type<opr::SharedDeviceTensor>())) {
+            ig_gen->add_opr(i);
+        }
+    }
+    auto igraph_0 = ig_gen->generate();
+    auto igraph_1 = std::make_shared<InternalGraph>(
+            igraph_0->output(), shape.node(), igraph_0->value_infer(),
+            igraph_0->placeholders());
+    auto y_jit = JITExecutor::make(igraph_1, ig_gen->orig_inps());
+    auto opr_ori = y_jit.node()->owner_opr();
+    auto opr_copy = serialization::copy_opr_shallow(*opr_ori, opr_ori->input());
+    auto out_var = opr_copy->output(0);
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(out_var, host_y_jit)});
+    func->execute().wait();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, 5e-3);
+}
#endif // MGB_JIT
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
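The regression test constructs a JIT subgraph whose `InternalGraph` uses a sibling LayerNorm output (`out_array[2]`) as the shape-infer var, so that var is owned by a multi-output operator deliberately excluded from the fused subgraph. It then shallow-copies the `JITExecutor` with its original inputs through `serialization::copy_opr_shallow` and checks that the copy and the original produce the same result; the 5e-3 tolerance allows for the intermediate float16 conversions.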