Commit 3e4e4c46 authored by Megvii Engine Team, committed by huangxinda

feat(mgb/jit): add graph_opt_config and jit_config interfaces

GitOrigin-RevId: 170d9eeab295d3f1553b9991e507f877af478251
Parent 1c7d0802
@@ -11,6 +11,7 @@ from ..core._imperative_rt.core2 import (
    set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
    apply_const_with_tracing,
...
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
class GraphOptimizationConfig:
    r"""
    Configuration for graph optimization: False for OFF, True for ON. The default value
    None means that opt_level will decide whether this optimization will be applied or not.

    :param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
    :param jit_fuse_reduce: whether to fuse reduce in JIT optimization
    """

    def __init__(self):
        self.jit_fuse_dimshuffle = None
        self.jit_fuse_reduce = None

    def __repr__(self):
        val2str = {None: "UNSET", False: "OFF", True: "ON"}
        return (
            "GraphOptimizationConfig {"
            + " jit_fuse_dimshuffle = "
            + val2str[self.jit_fuse_dimshuffle]
            + ", jit_fuse_reduce = "
            + val2str[self.jit_fuse_reduce]
            + " }"
        )
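For reference, a minimal usage sketch of the new config together with the graph_opt_config argument added to trace below (illustrative only; add_one is a toy function, not part of this commit):

    from megengine import tensor
    from megengine.jit import GraphOptimizationConfig, trace

    config = GraphOptimizationConfig()
    config.jit_fuse_reduce = True      # force reduce fusion ON
    # jit_fuse_dimshuffle stays None (UNSET), so opt_level decides

    @trace(opt_level=2, graph_opt_config=config)
    def add_one(x):
        return x + 1

    add_one(tensor(2))
    print(config)
    # GraphOptimizationConfig { jit_fuse_dimshuffle = UNSET, jit_fuse_reduce = ON }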
@@ -38,6 +38,7 @@ from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig
@@ -129,6 +130,7 @@ class trace:
        If not None, it enables sublinear memory optimization with given setting.
    :param profiling: whether to profile compiled trace. Default: False
    :param opt_level: optimization level for compiling trace. Default: 2
    :param graph_opt_config: configuration for graph optimization. Default: None
    :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
    """
@@ -146,6 +148,7 @@ class trace:
        dtr_config: DTRConfig = None,
        profiling: bool = False,
        opt_level: int = 2,
        graph_opt_config: GraphOptimizationConfig = None,
        symbolic_shape: bool = True,
    ):
        self.__wrapped__ = function
@@ -156,6 +159,7 @@ class trace:
        self._profiling = profiling
        self._profiler = None
        self._graph_opt_level = opt_level
        self._graph_opt_config = graph_opt_config
        self._symbolic_shape = symbolic_shape
        self._output_handles = set()
@@ -502,7 +506,14 @@ class trace:
            graph.options.dtr_config.evictee_minimum_size = (
                self._dtr_config.evictee_minimum_size
            )
        # graph optimization
        if self._graph_opt_config is not None:
            mapping = {None: 0, False: 1, True: 2}
            jit_config = graph.options.graph_opt.jit_config
            jit_config.fuse_dimshuffle = mapping[
                self._graph_opt_config.jit_fuse_dimshuffle
            ]
            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
...
@@ -421,12 +421,20 @@ void init_graph_rt(py::module m) {
#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
-    py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
+    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
+            PyComputingGraphOptions, "GraphOpt")
            DEF_READWRITE(jit)
            DEF_READWRITE(jit_config)
            DEF_READWRITE(tensorrt);
#undef CURRENT_CLASS

#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
            DEF_READWRITE(fuse_dimshuffle)
            DEF_READWRITE(fuse_reduce);
#undef CURRENT_CLASS

#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
    py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
...
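With these bindings, the nested option becomes a plain attribute on the Python side; a small sketch (assuming `graph` is the computing-graph object used in tracing.py above, and reusing the UNSET=0 / OFF=1 / ON=2 encoding from cg.h):

    opts = graph.options
    opts.graph_opt.jit_config.fuse_dimshuffle = 1   # OFF
    opts.graph_opt.jit_config.fuse_reduce = 2       # ON
    assert opts.graph_opt.jit_config.fuse_reduce == 2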
@@ -25,7 +25,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
-from megengine.jit import exclude_from_trace, trace
+from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming
@@ -605,3 +605,30 @@ def test_trace_advance_indexing(shape_mode):
    for _ in range(3):
        result_trace = f_traced(**params)
        np.testing.assert_equal(expected, result_trace.numpy())


@pytest.mark.require_ngpu(1)  # nvrtc backend
def test_trace_jit_config():
    def run(fuse_dimshuffle, fuse_reduce):
        config = GraphOptimizationConfig()
        config.jit_fuse_dimshuffle = fuse_dimshuffle
        config.jit_fuse_reduce = fuse_reduce

        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
        @trace(opt_level=1, graph_opt_config=config)
        def func(x):
            return x + 1

        x = tensor(2)
        y = func(x)
        func._compile()

        options = func._graph.options
        mapping = {None: 0, False: 1, True: 2}
        assert options.graph_opt.jit == 0
        assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
        assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]

    for fuse_dimshuffle in [None, False, True]:
        for fuse_reduce in [None, False, True]:
            run(fuse_dimshuffle, fuse_reduce)
@@ -145,6 +145,24 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
}
#endif

/* ========================== JITConfig ========================== */

bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
    if (fuse_dimshuffle != UNSET) return true;
    if (fuse_reduce != UNSET) return true;
    return false;
}

void ComputingGraph::Options::GraphOpt::JITConfig::update(
        const JITConfig& modifier) {
    if (modifier.fuse_dimshuffle != UNSET) {
        this->fuse_dimshuffle = modifier.fuse_dimshuffle;
    }
    if (modifier.fuse_reduce != UNSET) {
        this->fuse_reduce = modifier.fuse_reduce;
    }
}

/* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                     SingleCNOperatorNodeBase) // {
@@ -538,12 +556,18 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
#if MGB_JIT
-    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
-        setenv("MGB_JIT_BACKEND","NVRTC",1);
+    if (std::abs(options().graph_opt_level) == 0 &&
+        (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
+        // Deprecated usage added previously. It allows NVRTC JIT optimization
+        // when graph_opt_level is 0. This usage is not recommended any more.
+        mgb_log_warn(
+                "It is not recommended to enable JIT optimization when "
+                "graph_opt_level is 0.");
+        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
        gopt::GraphOptimizer optimizer;
-        optimizer.add_pass<gopt::JITFusionPass>(
-                sopr_stat.has_virtual_grad,
-                std::max<uint8_t>(options().graph_opt.jit, 1));
+        optimizer.add_pass<gopt::JITFusionPass>(sopr_stat.has_virtual_grad,
+                                                options().graph_opt.jit,
+                                                options().graph_opt.jit_config);
        optimizer.apply_inplace(dest_vars);
    }
#endif
...
@@ -338,6 +338,20 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
            //! this value indicates JIT level: 1 for basic elemwise opr; 2
            //! for including reduce oprs
            uint8_t jit = 0;

            //! jit configurations
            struct JITConfig {
                static const int UNSET = 0;
                static const int OFF = 1;
                static const int ON = 2;

                int fuse_dimshuffle = UNSET;
                int fuse_reduce = UNSET;

                bool enabled() const;
                void update(const JITConfig& modifier);
            } jit_config;

            //! whether to enable fine-grained TensorRT opr replace
            bool tensorrt = false;
        } graph_opt;
...
@@ -645,11 +645,21 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
    add_pass<RemoveRedundantCopyPass>();
#if MGB_JIT
-    bool need_jit = false;
-    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
-                           comp_graph_opt->graph_opt.jit)) {
-        need_jit = true;
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    int jit_opt_level = 0;
+    JITConfig jit_config;
+
+    // for more detail on what is happening here, see comments on the
+    // constructor of class JITFusionPass in fusion_pass.h
+    if (comp_graph_opt) {
+        jit_opt_level = comp_graph_opt->graph_opt.jit;
+        if (comp_graph_opt->graph_opt_level >= 3) {
+            jit_opt_level = std::max(jit_opt_level, 1);
+        }
+        jit_config = comp_graph_opt->graph_opt.jit_config;
    }
+    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
@@ -662,9 +672,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
#if MGB_JIT
    if (need_jit) {
-        add_pass<gopt::JITFusionPass>(
-                after_grad,
-                std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
+        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
    }
#endif
...
@@ -428,14 +428,33 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
    return false;
}

-JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
+JITFusionPass::JITFusionPass(bool after_grad, int jit_opt_level,
+                             const JITConfig& jit_config)
        : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
-    // TODO reduce and dimshuffle can not coexsit now.
-    if (jit_opt_level >= 2) {
-        m_feature_bits |= JITFeatureBits::REDUCE;
-    } else {
+    // get default config from jit_opt_level
+    JITConfig config;
+    if (jit_opt_level == 1) {
+        config.fuse_dimshuffle = JITConfig::ON;
+        config.fuse_reduce = JITConfig::OFF;
+    } else if (jit_opt_level >= 2) {
+        config.fuse_dimshuffle = JITConfig::OFF;
+        config.fuse_reduce = JITConfig::ON;
+    }
+
+    // overwrite default config with custom settings
+    config.update(jit_config);
+    bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
+    bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
+
+    if (fuse_dimshuffle && fuse_reduce) {
+        mgb_assert(false, "reduce and dimshuffle can not coexist now");
+    }
+
+    if (fuse_dimshuffle) {
        m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
    }
+    if (fuse_reduce) {
+        m_feature_bits |= JITFeatureBits::REDUCE;
+    }
}

const char* JITFusionPass::name() const {
...
@@ -39,7 +39,40 @@ class JITFusionPass final : public Pass {
    JITFeatureBits m_feature_bits;

public:
-    JITFusionPass(bool after_grad = true, int8_t jit_opt_level = 1);
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+
+    /*
+     * Explanation of how graph_opt_level, jit_opt_level and jit_config
+     * control the behavior of JIT optimization:
+     *
+     * The design of this API is restricted by the historical burden of
+     * jit_opt_level and we have to support the old interface jit_opt_level and
+     * the new interface jit_config at the same time.
+     *
+     * How JITFusionPass decides its behavior:
+     * (1) When graph_opt_level is 3, it sets jit_opt_level to 1
+     * (2) When the user-defined jit_opt_level is greater than 1, it overwrites
+     *     the previous value of jit_opt_level
+     * (3) We get a default jit_config from jit_opt_level:
+     *     jit_opt_level = 0: JIT optimization OFF
+     *     jit_opt_level = 1: dimshuffle ON, reduce OFF
+     *     jit_opt_level = 2: dimshuffle OFF, reduce ON
+     * (4) The user-defined jit_config provides more precise control and
+     *     overwrites the default settings defined by jit_opt_level
+     *
+     * Situations in which JIT optimization is ON:
+     * (1) graph_opt_level = 3
+     * (2) graph_opt_level = 2, jit_opt_level > 0
+     * (3) graph_opt_level = 2, jit_opt_level = 0, jit_config is set
+     * (4) graph_opt_level = 0, jit_opt_level > 0 (deprecated usage)
+     *
+     * Situations in which JIT optimization is OFF:
+     * (1) graph_opt_level = 2, jit_opt_level = 0, jit_config is unset
+     * (2) graph_opt_level = 1
+     * (3) graph_opt_level = 0, jit_opt_level = 0
+     */
+    JITFusionPass(bool after_grad = true, int jit_opt_level = 0,
+                  const JITConfig& jit_config = {});

    const char* name() const override;
    void apply(OptState& opt) const override;
};
...
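The precedence rules documented in the comment above can be condensed into a short sketch; this is an illustrative Python mirror of the decision logic (the function effective_fusion is ours, not part of the sources), covering only how the two fuse flags are derived and not the graph_opt_level <= 1 case that disables the pass entirely:

    UNSET, OFF, ON = 0, 1, 2  # mirrors JITConfig::UNSET / OFF / ON

    def effective_fusion(graph_opt_level, jit_opt_level,
                         fuse_dimshuffle=UNSET, fuse_reduce=UNSET):
        # graph_opt_level >= 3 enables JIT with at least jit_opt_level = 1
        if graph_opt_level >= 3:
            jit_opt_level = max(jit_opt_level, 1)
        # default config derived from jit_opt_level
        if jit_opt_level == 1:
            default = (ON, OFF)        # dimshuffle ON, reduce OFF
        elif jit_opt_level >= 2:
            default = (OFF, ON)        # dimshuffle OFF, reduce ON
        else:
            default = (UNSET, UNSET)   # JIT stays off unless jit_config is set
        # user-defined jit_config overwrites the defaults (JITConfig::update)
        fd = fuse_dimshuffle if fuse_dimshuffle != UNSET else default[0]
        fr = fuse_reduce if fuse_reduce != UNSET else default[1]
        return fd == ON, fr == ON

    # graph_opt_level = 2, jit_opt_level = 2, dimshuffle fusion explicitly OFF:
    assert effective_fusion(2, 2, fuse_dimshuffle=OFF) == (False, True)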
@@ -27,6 +27,8 @@
#include "megbrain/test/helper.h"
#include "megbrain/opr/dnn/convolution.h"

#include "../../core/impl/graph/cg_impl_seq.h"

#if MGB_JIT
using namespace mgb;
@@ -1455,6 +1457,122 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
    }
}
TEST(TestJITNvrtc, JITConfig) {
using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
using CompSeq = cg::ComputingGraphImpl::ComputingSequence;
using ReduceMode = opr::Reduce::Param::Mode;
static const int UNSET = JITConfig::UNSET;
static const int OFF = JITConfig::OFF;
static const int ON = JITConfig::ON;
REQUIRE_GPU(1);
set_backend(Backend::NVRTC);
auto cn = CompNode::load("gpu0");
HostTensorGenerator<> gen;
auto run = [&](int graph_opt_level, int jit_opt_level,
const JITConfig& jit_config, bool expect_dimshuffle_fused,
bool expect_reduce_fused, bool expect_jit_enabled) {
auto cg = ComputingGraph::make();
cg->options().graph_opt_level = graph_opt_level;
cg->options().graph_opt.jit = jit_opt_level;
cg->options().graph_opt.jit_config = jit_config;
auto host_x = gen({2, 3, 4, 5}, cn);
auto x = opr::SharedDeviceTensor::make(*cg, *host_x);
// three types of operations to be fused by JIT
x = (2 * x + 3) * (3 * x - 1); // Elemwise
x = opr::Dimshuffle::make(x, {1, 2, 3, 0}); // Dimshuffle
x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2}); // Reduce
auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
auto comp_seq = dynamic_cast<CompSeq*>(func.get());
ASSERT_TRUE(comp_seq != nullptr);
bool dimshuffle_found = false, reduce_found = false,
jit_executor_found = false;
auto on_opr = [&](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::Dimshuffle>()) {
dimshuffle_found = true;
} else if (opr->same_type<opr::Reduce>()) {
reduce_found = true;
} else if (opr->same_type<JITExecutor>()) {
jit_executor_found = true;
}
return true;
};
comp_seq->iter_opr_seq(on_opr);
ASSERT_EQ(expect_dimshuffle_fused, !dimshuffle_found);
ASSERT_EQ(expect_reduce_fused, !reduce_found);
ASSERT_EQ(expect_jit_enabled, jit_executor_found);
};
// graph_opt_level = 1, always OFF
for (int jit_opt_level : {0, 1, 2}) {
for (int fuse_dimshuffle : {UNSET, OFF, ON}) {
for (int fuse_reduce : {UNSET, OFF, ON}) {
run(1, jit_opt_level, JITConfig{fuse_dimshuffle, fuse_reduce},
false, false, false);
}
}
}
// some test cases are commented because dimshuffle and reduce can not be
// fused at the same time
for (int graph_opt_level : {0, 2}) {
// jit_opt_level = 0, default = {OFF, OFF}
run(graph_opt_level, 0, JITConfig{UNSET, UNSET}, false, false, false);
run(graph_opt_level, 0, JITConfig{UNSET, OFF}, false, false, true);
run(graph_opt_level, 0, JITConfig{UNSET, ON}, false, true, true);
run(graph_opt_level, 0, JITConfig{OFF, UNSET}, false, false, true);
run(graph_opt_level, 0, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 0, JITConfig{OFF, ON}, false, true, true);
run(graph_opt_level, 0, JITConfig{ON, UNSET}, true, false, true);
run(graph_opt_level, 0, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 0, JITConfig{ON, ON}, true, true, true);
}
{
// graph_opt_level = 3, jit_opt_level = 0, default = {ON, OFF}
run(3, 0, JITConfig{UNSET, UNSET}, true, false, true);
run(3, 0, JITConfig{UNSET, OFF}, true, false, true);
// run(3, 0, JITConfig{UNSET, ON}, true, true, true);
run(3, 0, JITConfig{OFF, UNSET}, false, false, true);
run(3, 0, JITConfig{OFF, OFF}, false, false, true);
run(3, 0, JITConfig{OFF, ON}, false, true, true);
run(3, 0, JITConfig{ON, UNSET}, true, false, true);
run(3, 0, JITConfig{ON, OFF}, true, false, true);
// run(3, 0, JITConfig{ON, ON}, true, true, true);
}
for (int graph_opt_level : {0, 2, 3}) {
// jit_opt_level = 1, default = {ON, OFF}
run(graph_opt_level, 1, JITConfig{UNSET, UNSET}, true, false, true);
run(graph_opt_level, 1, JITConfig{UNSET, OFF}, true, false, true);
// run(graph_opt_level, 1, JITConfig{UNSET, ON}, true, true, true);
run(graph_opt_level, 1, JITConfig{OFF, UNSET}, false, false, true);
run(graph_opt_level, 1, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 1, JITConfig{OFF, ON}, false, true, true);
run(graph_opt_level, 1, JITConfig{ON, UNSET}, true, false, true);
run(graph_opt_level, 1, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 1, JITConfig{ON, ON}, true, true, true);
// jit_opt_level = 2, default = {OFF, ON}
run(graph_opt_level, 2, JITConfig{UNSET, UNSET}, false, true, true);
run(graph_opt_level, 2, JITConfig{UNSET, OFF}, false, false, true);
run(graph_opt_level, 2, JITConfig{UNSET, ON}, false, true, true);
run(graph_opt_level, 2, JITConfig{OFF, UNSET}, false, true, true);
run(graph_opt_level, 2, JITConfig{OFF, OFF}, false, false, true);
run(graph_opt_level, 2, JITConfig{OFF, ON}, false, true, true);
// run(graph_opt_level, 2, JITConfig{ON, UNSET}, true, true, true);
run(graph_opt_level, 2, JITConfig{ON, OFF}, true, false, true);
// run(graph_opt_level, 2, JITConfig{ON, ON}, true, true, true);
}
}
TEST(TestJITExecutor, GradBehavior) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
...