From 672d4ad0e06e8f12fa30dd29c8ab2d57d8de8258 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 15 Jun 2020 13:58:46 +0800
Subject: [PATCH] fix(jit): more testcases on the grad of JITExecutor

GitOrigin-RevId: c3bb40597934b65375e1b90f051fabaf01309815
---
 src/jit/impl/executor_opr.cpp |   8 +--
 src/jit/test/fusion.cpp       | 106 +++++++++++++++++++++++++++++++++-
 2 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/src/jit/impl/executor_opr.cpp b/src/jit/impl/executor_opr.cpp
index 744af62f9..f76506d1b 100644
--- a/src/jit/impl/executor_opr.cpp
+++ b/src/jit/impl/executor_opr.cpp
@@ -549,8 +549,8 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
                 rewriter.auto_replace_outputs(opr);
             });
 
-    static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
-            InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) {
+    auto expand_into_origin_graph = [&rewriter](
+            cg::OperatorNodeBase* opr, const VarNodeArray& grad_inputs) {
         if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
             rewriter.replace_var(
                     opr->output(0), grad_inputs.at(ph->input_id()));
@@ -571,7 +571,7 @@
         // oprs
         using namespace std::placeholders;
         rewriter.iter(std::bind(expand_into_origin_graph, _1,
-                std::ref(rewriter), std::cref(grad_inputs)));
+                std::cref(grad_inputs)));
         return rewriter.dest_var();
     } else {
         VarNodeArray new_grad_inputs;
@@ -602,7 +602,7 @@
         // infer and const folding mechanism
         using namespace std::placeholders;
         rewriter.iter(std::bind(expand_into_origin_graph, _1,
-                std::ref(rewriter), std::cref(new_grad_inputs)));
+                std::cref(new_grad_inputs)));
         return rewriter.dest_var();
     }
     gx = rewriter.dest_var();
diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp
index 82606b5b5..ec7f71d4b 100644
--- a/src/jit/test/fusion.cpp
+++ b/src/jit/test/fusion.cpp
@@ -1443,7 +1443,111 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
                 },
                 CompNode::load("gpu0")};
         checker.set_jit_level(1)
-                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}});
+                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}})
+                .run({TensorShape{3, 4, 1, 2}, {4, 1, 2, 3}})
+                .run({TensorShape{4, 6, 3, 5}, {6, 3, 5, 4}});
     }
 }
 
+TEST(TestJITExecutor, GradBehavior) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+    {
+        set_backend(Backend::NVRTC);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
+             x = opr::exp(a + 1);
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        SmallVector<JITExecutor*> jits;
+        auto on_opr = [&jits](cg::OperatorNodeBase* op) {
+            if (auto jit = op->try_cast_final<JITExecutor>()) {
+                jits.push_back(jit);
+            }
+        };
+        auto grad_a = cg::grad(x, a);
+        cg::DepOprIter{on_opr}.add(grad_a);
+        ASSERT_EQ(jits.size(), 2);
+        // input of forward jit executor: host_a
+        ASSERT_EQ(jits[0]->input().size(), 1);
+        // inputs of grad jit executor:
+        //     output of forward jit executor, output grad
+        ASSERT_EQ(jits[1]->input().size(), 2);
+        // internal graph is (input: og, out | output: og * out)
+        size_t nr_ph = 0, nr_mul = 0;
+        cg::DepOprIter{
+                [&nr_ph, &nr_mul](cg::OperatorNodeBase* op) {
+                    if (op->same_type<JITPlaceholder>()) {
+                        ++ nr_ph;
+                        return;
+                    }
+                    if (auto mul = op->try_cast_final<opr::Elemwise>()) {
+                        using Mode = opr::Elemwise::Mode;
+                        if (mul->param().mode == Mode::MUL) {
+                            ++ nr_mul;
+                            return;
+                        }
+                    }
+                    mgb_throw(MegBrainError, "unexpected op %s", op->cname());
+                }}
+                .add(jits[1]->internal_graph_ptr()->output());
+        ASSERT_EQ(nr_ph, 2);
+        ASSERT_EQ(nr_mul, 1);
+    }
+    {
+        set_backend(Backend::HALIDE);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
+             x = opr::exp(a + 1);
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        size_t nr_ops = 0, nr_jits = 0;
+        auto on_opr = [&nr_jits, &nr_ops](cg::OperatorNodeBase* op) {
+            if (op->same_type<JITExecutor>()) {
+                ++ nr_jits;
+            }
+            ++ nr_ops;
+        };
+        auto grad_a = cg::grad(x, a);
+        cg::DepOprIter{on_opr}.add(grad_a);
+        // in the Halide backend, the grad internal graph is expanded into
+        // the original graph, so there is only one JITExecutor
+        ASSERT_EQ(nr_jits, 1);
+        // the grad of a is broadcast(JITExecutor.output(0), a.shape()),
+        // so the oprs that grad_a depends on are H2D(a), JITExecutor,
+        // GetVarShape(a) and Broadcast
+        ASSERT_EQ(nr_ops, 4);
+    }
+    {
+        set_backend(Backend::NVRTC);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::SharedDeviceTensor::make(*graph, *host_a),
+             x = a * 2 + 1;
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        auto grad_a = cg::grad(x, a);
+        // all inputs of the grad jit executor are const, so its internal
+        // graph is expanded into the original graph for further
+        // optimizations and no JITExecutor can be found
+        cg::DepOprIter{[](cg::OperatorNodeBase* op) {
+            ASSERT_FALSE(op->same_type<JITExecutor>());
+        }}.add(grad_a);
+    }
+}
-- 
GitLab
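
As context for the executor_opr.cpp hunks above: the change swaps a static,
capture-less lambda, whose rewriter argument was forwarded through
std::bind/std::ref, for a local lambda that captures rewriter by reference, so
std::bind only has to forward the remaining arguments. Below is a minimal,
standalone C++ sketch of the two equivalent call styles; the Rewriter type and
its iter()/replace() members are hypothetical stand-ins for illustration only,
not MegBrain's InternalGraphRewriter.

    // Hypothetical stand-in for InternalGraphRewriter; illustration only.
    #include <functional>
    #include <iostream>
    #include <vector>

    struct Rewriter {
        std::vector<int> replaced;
        void replace(int v) { replaced.push_back(v); }
        // invokes the callback once per pretend "operator"
        template <typename Cb>
        void iter(Cb&& cb) {
            for (int op : {1, 2, 3})
                cb(op);
        }
    };

    int main() {
        Rewriter rewriter;
        using namespace std::placeholders;

        // before: capture-less callable, rewriter forwarded via std::ref
        auto expand_v1 = [](int op, Rewriter& rw) { rw.replace(op); };
        rewriter.iter(std::bind(expand_v1, _1, std::ref(rewriter)));

        // after: rewriter captured by reference, only the operator is bound
        auto expand_v2 = [&rewriter](int op) { rewriter.replace(op); };
        rewriter.iter(std::bind(expand_v2, _1));

        std::cout << rewriter.replaced.size() << '\n';  // prints 6
        return 0;
    }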