提交 672d4ad0 编写于 作者: M Megvii Engine Team

fix(jit): more testcases on the grad of JITExecutor

GitOrigin-RevId: c3bb40597934b65375e1b90f051fabaf01309815
上级 bc95e873
...@@ -549,8 +549,8 @@ MGB_IMPL_OPR_GRAD(JITExecutor) { ...@@ -549,8 +549,8 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
rewriter.auto_replace_outputs(opr); rewriter.auto_replace_outputs(opr);
}); });
static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr, auto expand_into_origin_graph = [&rewriter](
InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) { cg::OperatorNodeBase* opr, const VarNodeArray& grad_inputs) {
if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) { if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
rewriter.replace_var( rewriter.replace_var(
opr->output(0), grad_inputs.at(ph->input_id())); opr->output(0), grad_inputs.at(ph->input_id()));
...@@ -571,7 +571,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) { ...@@ -571,7 +571,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
// oprs // oprs
using namespace std::placeholders; using namespace std::placeholders;
rewriter.iter(std::bind(expand_into_origin_graph, _1, rewriter.iter(std::bind(expand_into_origin_graph, _1,
std::ref(rewriter), std::cref(grad_inputs))); std::cref(grad_inputs)));
return rewriter.dest_var(); return rewriter.dest_var();
} else { } else {
VarNodeArray new_grad_inputs; VarNodeArray new_grad_inputs;
...@@ -602,7 +602,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) { ...@@ -602,7 +602,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
// infer and const folding mechanism // infer and const folding mechanism
using namespace std::placeholders; using namespace std::placeholders;
rewriter.iter(std::bind(expand_into_origin_graph, _1, rewriter.iter(std::bind(expand_into_origin_graph, _1,
std::ref(rewriter), std::cref(new_grad_inputs))); std::cref(new_grad_inputs)));
return rewriter.dest_var(); return rewriter.dest_var();
} }
gx = rewriter.dest_var(); gx = rewriter.dest_var();
......
...@@ -1443,7 +1443,111 @@ TEST(TestJITNvrtc, DimshuffleGrad) { ...@@ -1443,7 +1443,111 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
}, },
CompNode::load("gpu0")}; CompNode::load("gpu0")};
checker.set_jit_level(1) checker.set_jit_level(1)
.run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}}); .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}})
.run({TensorShape{3, 4, 1, 2}, {4, 1, 2, 3}})
.run({TensorShape{4, 6, 3, 5}, {6, 3, 5, 4}});
}
}
TEST(TestJITExecutor, GradBehavior) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
HostTensorGenerator<> gen;
{
set_backend(Backend::NVRTC);
auto graph = ComputingGraph::make();
auto host_a = gen({2, 3, 4}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_a),
x = opr::exp(a + 1);
gopt::GraphOptimizer gopt;
gopt.add_pass<gopt::JITFusionPass>();
VarNodeArray dest_vars{x.node()};
gopt.apply_inplace(dest_vars);
x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
SmallVector<jit::JITExecutor*> jits;
auto on_opr = [&jits](cg::OperatorNodeBase* op) {
if (auto jit = op->try_cast_final<jit::JITExecutor>()) {
jits.push_back(jit);
}
};
auto grad_a = cg::grad(x, a);
cg::DepOprIter{on_opr}.add(grad_a);
ASSERT_EQ(jits.size(), 2);
// input of forward jit executor: host_a
ASSERT_EQ(jits[0]->input().size(), 1);
// input of grad jit executor:
// output of forward jit executor, output grad
ASSERT_EQ(jits[1]->input().size(), 2);
// internal graph is (input: og, out | output: og * out)
size_t nr_ph = 0, nr_mul = 0;
cg::DepOprIter{
[&nr_ph, &nr_mul](cg::OperatorNodeBase* op) {
if (op->same_type<jit::JITPlaceholder>()) {
++ nr_ph;
return;
}
if(auto mul = op->try_cast_final<opr::Elemwise>()) {
using Mode = opr::Elemwise::Mode;
if (mul->param().mode == Mode::MUL) {
++ nr_mul;
return;
}
}
mgb_throw(MegBrainError, "unexpected op %s", op->cname());
}}
.add(jits[1]->internal_graph_ptr()->output());
ASSERT_EQ(nr_ph, 2);
ASSERT_EQ(nr_mul, 1);
}
{
set_backend(Backend::HALIDE);
auto graph = ComputingGraph::make();
auto host_a = gen({2, 3, 4}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_a),
x = opr::exp(a + 1);
gopt::GraphOptimizer gopt;
gopt.add_pass<gopt::JITFusionPass>();
VarNodeArray dest_vars{x.node()};
gopt.apply_inplace(dest_vars);
x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
size_t nr_ops = 0, nr_jits = 0;
auto on_opr = [&nr_jits, &nr_ops](cg::OperatorNodeBase* op) {
if (op->same_type<jit::JITExecutor>()) {
++ nr_jits;
}
++ nr_ops;
};
auto grad_a = cg::grad(x, a);
cg::DepOprIter{on_opr}.add(grad_a);
// in Halide backend, grad internal graph would be expanded into
// original graph, so there was only one JITExecutor
ASSERT_EQ(nr_jits, 1);
// the grad of a is broadcast(JITExecutor.output(0), a.shape()),
// so the oprs depended by grad_a are H2D(a), JITExecutor,
// GetVarShape(a) and broadcast
ASSERT_EQ(nr_ops, 4);
}
{
set_backend(Backend::NVRTC);
auto graph = ComputingGraph::make();
auto host_a = gen({2, 3, 4}, cn);
auto a = opr::SharedDeviceTensor::make(*graph, *host_a),
x = a * 2 + 1;
gopt::GraphOptimizer gopt;
gopt.add_pass<gopt::JITFusionPass>();
VarNodeArray dest_vars{x.node()};
gopt.apply_inplace(dest_vars);
x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
auto grad_a = cg::grad(x, a);
// all inputs of grad jit executor are const, its internal graph
// would be expanded into original graph for more optimizations,
// so no JITExecutor can be found
cg::DepOprIter{[](cg::OperatorNodeBase* op) {
ASSERT_FALSE(op->same_type<jit::JITExecutor>());}
}.add(grad_a);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册