Commit bc95e873 authored by Megvii Engine Team

fix(jit): fix jit grad

a) fix shape mismatch when taking the grad of a JITExecutor that includes Dimshuffle
b) avoid redundant computation in the grad of JITExecutor
c) do not pass unused vars as inputs to the grad of JITExecutor, to save device memory
d) traverse the internal graph only once in the JITExecutor ctor instead of
   traversing the whole graph in each call of setup_args()
e) expand the gradient graph into the original graph if all inputs are const

GitOrigin-RevId: ba6a2b29e975c7f63a21785efad87dbda76143d4
Parent fc1ce273
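Points (a)/(d) rely on the fact that a chain of Dimshuffle oprs between a JITPlaceholder and the fused output can be folded into a single pattern per placeholder, which is what prepare_dimshuffle() below computes. The following is a minimal standalone sketch of that composition rule; it is not part of the commit, and compose()/main() are hypothetical names used only for illustration.

// Standalone sketch (not from the commit): composing two Dimshuffle patterns,
// mirroring merge_dimshuffle() in prepare_dimshuffle().
// `src` is the pattern accumulated so far (closer to the fused output), `p` is
// a deeper Dimshuffle found while walking toward the placeholder;
// dimshuffle(dimshuffle(x, p), src) == dimshuffle(x, compose(p, src)).
#include <cassert>
#include <cstdio>
#include <vector>

static std::vector<int> compose(const std::vector<int>& p,
                                const std::vector<int>& src) {
    std::vector<int> out(src.size());
    for (size_t i = 0; i < src.size(); ++i) {
        // -1 means "insert a new axis of size 1"; it maps to no input axis
        out[i] = src[i] == -1 ? -1 : p[src[i]];
    }
    return out;
}

int main() {
    // y = dimshuffle(x, {1, 2, 3, 0}); z = dimshuffle(y, {3, 0, 1, 2}) gives
    // back z == x, so the composed pattern must be the identity {0, 1, 2, 3}.
    auto id = compose({1, 2, 3, 0}, {3, 0, 1, 2});
    for (size_t i = 0; i < id.size(); ++i)
        assert(id[i] == static_cast<int>(i));
    std::printf("composed pattern is the identity\n");
    return 0;
}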
@@ -88,15 +88,34 @@ JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
         cg::add_workspace_output(this);
     }
+    // check that the output of internal_graph depends on all placeholders
+    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
+    std::vector<bool> used(nr_placeholders, false);
     // check if there is reduce or dimshuffle opr
-    cg::DepOprIter{[this](cg::OperatorNodeBase* opr) {
+    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
         if (opr->same_type<opr::Reduce>()) {
             m_feature_bits |= JITFeatureBits::REDUCE;
         }
         if (opr->same_type<opr::Dimshuffle>()) {
             m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
         }
+        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
+            mgb_assert(ph->input_id() < nr_placeholders,
+                       "bad placeholders %s in JITExecutor %s",
+                       ph->cname(), cname());
+            used[ph->input_id()] = true;
+        }
     }}.add(internal_graph->output());
+    for (size_t i = 0; i < nr_placeholders; ++ i) {
+        mgb_assert(used[i],
+                   "placeholder %s is not depended on the output of %s",
+                   internal_graph_ptr()->placeholders()[i]->cname(), cname());
+    }
+
+    if (has_dimshuffle()) {
+        prepare_dimshuffle();
+    }
 }

 void JITExecutor::add_input_layout_constraint() {
@@ -151,14 +170,14 @@ void JITExecutor::scn_do_execute() {
 //! can be ignored
 void JITExecutor::do_dimshuffle() {
-    auto get_dimshuffled_layout = [](const TensorLayout& ily, int32_t* pattern,
-                                     size_t pattern_len) {
+    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
+                                            std::vector<int> pattern) {
         TensorLayout oly{ily.dtype};
-        oly.ndim = pattern_len;
+        oly.ndim = pattern.size();
         bool input_used[TensorLayout::MAX_NDIM] = {0};
-        for (uint32_t idx = 0; idx < pattern_len; ++idx) {
+        for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
             auto i = pattern[idx];
             if (i < 0) {
                 oly.shape[idx] = 1;
@@ -179,53 +198,20 @@ void JITExecutor::do_dimshuffle() {
         return oly;
     };

-    // DFS to make sure traverse the dimshuffles in one branch
-    std::unordered_set<VarNode*> visited;
-    std::vector<OperatorNodeBase*> stack(0);
-    std::vector<uint8_t> idx(0);  // input index
-    stack.push_back(m_internal_graph->output()->owner_opr());
-    idx.push_back(0);
-
-    while (!stack.empty()) {
-        if (idx.back() < stack.back()->input().size() &&
-            !visited.count(stack.back()->input(idx.back()))) {
-            visited.insert(stack.back()->input(idx.back()));
-            stack.push_back(stack.back()->input(idx.back())->owner_opr());
-            if (stack.back()->same_type<jit::JITPlaceholder>()) {
-                auto jitph = gopt::try_cast_as_op<JITPlaceholder>(stack.back());
-                size_t input_id = jitph->input_id();
-                auto&& input = m_args.inputs[input_id];
-                for (int i = stack.size() - 1; i >= 0; --i) {
-                    if (stack[i]->same_type<opr::Dimshuffle>()) {
-                        auto param =
-                                stack[i]->cast_final_safe<opr::Dimshuffle>()
-                                        .param();
-                        mgb_assert(input.layout.ndim == param.ndim,
-                                   "input ndim mismatch for Dimshuffle: "
-                                   "expect=%u "
-                                   "actual=%zu",
-                                   param.ndim, input.layout.ndim);
-                        auto dimshuffled_layout = get_dimshuffled_layout(
-                                input.layout, param.pattern, param.pattern_len);
-                        input.layout = dimshuffled_layout;
-                    }
-                }
-                stack.pop_back();
-                ++idx.back();
-            } else {
-                idx.push_back(0);
-            }
-        } else {
-            stack.pop_back();
-            idx.pop_back();
-            if (!stack.empty())
-                ++idx.back();
-        }
-    }
+    for (auto&& i : m_internal_graph->placeholders()) {
+        auto&& input = m_args.inputs[i->input_id()];
+        auto&& iter = m_jitph2dimshuffle.find(i);
+        if (iter == m_jitph2dimshuffle.end()) continue;
+        auto&& param = iter->second;
+        mgb_assert(input.layout.ndim == param.second,
+                   "input ndim mismatch for Dimshuffle: "
+                   "expect=%u "
+                   "actual=%zu",
+                   param.second, input.layout.ndim);
+        auto dimshuffled_layout = get_dimshuffled_layout(
+                input.layout, param.first);
+        input.layout = dimshuffled_layout;
+    }
 }

 void JITExecutor::update_args() {
@@ -259,7 +245,9 @@ void JITExecutor::update_args() {
     }

     //! dimshuffle opr need to change the input.
-    do_dimshuffle();
+    if (has_dimshuffle()) {
+        do_dimshuffle();
+    }

     if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
         // collective collapse datum layout, try to reduce the output ndim
@@ -304,6 +292,82 @@ void JITExecutor::update_args() {
     m_args.need_update = false;
 }

+void JITExecutor::prepare_dimshuffle() {
+    std::unordered_set<OperatorNodeBase*> visited;
+    std::vector<OperatorNodeBase*> stack(0);
+    std::vector<uint8_t> idx(0);  // input index
+    using Param = DimshuffleParam;
+    std::vector<Param> dimshuffle_stack;
+    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
+        if (dimshuffle_stack.empty()) {
+            dimshuffle_stack.emplace_back();
+            auto&& param = dimshuffle_stack.back();
+            param.first.insert(param.first.end(), p.pattern,
+                               p.pattern + p.pattern_len);
+            param.second = p.ndim;
+        } else {
+            // merge(p, src) -> param, i.e. performing
+            // dimshuffle(dimshuffle(x, p), src) is equivalent to
+            // dimshuffle(x, param)
+            dimshuffle_stack.emplace_back();
+            auto&& param = dimshuffle_stack.back();
+            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
+            mgb_assert(p.pattern_len == src.second);
+            param.first.resize(src.first.size());
+            for (size_t i = 0; i < src.first.size(); ++ i) {
+                if (src.first[i] == -1) {
+                    param.first[i] = -1;
+                } else {
+                    param.first[i] = p.pattern[src.first[i]];
+                }
+            }
+            param.second = p.ndim;
+        }
+    };
+    auto push_back = [&](cg::OperatorNodeBase* op) {
+        mgb_assert(!op->same_type<jit::JITPlaceholder>());
+        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
+            merge_dimshuffle(o->param());
+        }
+        stack.push_back(op);
+        idx.push_back(0);
+    };
+    auto pop_back = [&]() {
+        auto&& op = stack.back();
+        if (op->same_type<opr::Dimshuffle>()) {
+            dimshuffle_stack.pop_back();
+        }
+        stack.pop_back();
+        idx.pop_back();
+    };
+    push_back(m_internal_graph->output()->owner_opr());
+    while (!stack.empty()) {
+        if (idx.back() < stack.back()->input().size()) {
+            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
+            if (visited.insert(cur_opr).second) {
+                if (auto jitph = cur_opr->try_cast_final<jit::JITPlaceholder>()) {
+                    if (!dimshuffle_stack.empty()) {
+                        mgb_assert(
+                                m_jitph2dimshuffle.emplace(jitph, dimshuffle_stack.back()).second,
+                                "already visited JITPlaceholder %s",
+                                jitph->cname());
+                    }
+                    ++ idx.back();
+                } else {
+                    push_back(cur_opr);
+                }
+            } else {
+                ++ idx.back();
+            }
+        } else {
+            pop_back();
+            if (!stack.empty())
+                ++ idx.back();
+        }
+    }
+}
+
 const JITExecutor::Args& JITExecutor::args() const {
     if (m_args.need_update) {
         const_cast<JITExecutor*>(this)->update_args();
@@ -383,6 +447,56 @@ megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
 #if MGB_ENABLE_GRAD

+namespace {
+class InternalGraphRewriter {
+    ThinHashMap<VarNode*, VarNode*> m_var_map;
+    VarNode* m_dest_var;
+    VarNodeArray m_new_inp;
+    VarNode* get_var(VarNode* var) {
+        auto&& iter = m_var_map.find(var);
+        if (iter != m_var_map.end()) {
+            return iter->second;
+        }
+        return var;
+    }
+public:
+    InternalGraphRewriter(VarNode* dest_var)
+            : m_dest_var{dest_var} {}
+    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
+        m_var_map.clear();
+        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
+        m_dest_var = get_var(m_dest_var);
+    }
+    VarNode* dest_var() {
+        return m_dest_var;
+    }
+    void replace_var(VarNode* src, VarNode* dst) {
+        // Note: var replacement is not performed recursively.
+        // When used placeholders are extracted from the internal graph, the
+        // replacement pairs (a -> b), (b -> c) are not treated as a chain
+        // (a -> b -> c) but as an injective mapping from (a, b) to (b, c);
+        // in all other cases each var node is passed as \p src or \p dst at
+        // most once.
+        m_var_map[src] = dst;
+    }
+    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
+        // in the JIT internal graph, the output size of an opr is always 1
+        mgb_assert(opr->usable_output().size() == 1);
+        m_new_inp.clear();
+        bool need_replace = false;
+        for (auto&& i : opr->input()) {
+            auto inp = get_var(i);
+            m_new_inp.push_back(inp);
+            need_replace |= (inp != i);
+        }
+        if (need_replace) {
+            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
+            replace_var(opr->output(0), new_op->output(0));
+        }
+    }
+};
+}  // anonymous namespace
+
 MGB_IMPL_OPR_GRAD(JITExecutor) {
     VarNodeArray grad_inputs;
     for (auto input : opr.input())
@@ -404,49 +518,120 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
     if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
         return opr::InvalidGrad::make(opr, wrt_idx);
     }
+    // early return if the grad expression is a single node
+    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
+        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
+            return grad_inputs[i];
+        }
+    }
+    if (gx.node() == og_ph.node()) {
+        return out_grad[0];
+    }
+    if (gx.node() == fwd_igraph_ptr->output()) {
+        return opr.output(0);
+    }
+    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(gx.node()->owner_opr())) {
+        HostTensorND hval{grad_inputs[0]->comp_node()};
+        hval.copy_from(imm->value()).sync();
+        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
+    }
+    // replace the output var in the internal graph with the output
+    // placeholder, so we can forward opr.output (computed by the forward
+    // JITExecutor) into the placeholder to avoid redundant computation
+    InternalGraphRewriter rewriter{gx.node()};
+    rewriter.iter([&rewriter, &fwd_igraph_ptr,
+                   &output_ph](cg::OperatorNodeBase* opr) {
+        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
+            rewriter.replace_var(opr->output(0), output_ph.node());
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    });
+    static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
+            InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) {
+        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
+            rewriter.replace_var(
+                    opr->output(0), grad_inputs.at(ph->input_id()));
+            return;
+        }
+        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
+            HostTensorND hval{grad_inputs[0]->comp_node()};
+            hval.copy_from(imm->value()).sync();
+            rewriter.replace_var(opr->output(0),
+                    opr::ImmutableTensor::make(*opr->owner_graph(), hval).node());
+            return;
+        }
+        rewriter.auto_replace_outputs(opr);
+    };
     if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
         // expand the gradient graph into the original graph to handle bcast
         // oprs
-        ThinHashMap<VarNode*, VarNode*> old2new;
-        VarNodeArray new_inp;
-        auto on_opr = [&old2new, &grad_inputs,
-                       &new_inp](cg::OperatorNodeBase* opr) {
-            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
-                old2new[opr->output(0)] = grad_inputs.at(ph->input_id());
-                return;
-            }
-            if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
-                HostTensorND hval{grad_inputs[0]->comp_node()};
-                hval.copy_from(imm->value()).sync();
-                old2new[opr->output(0)] =
-                        opr::ImmutableTensor::make(*opr->owner_graph(), hval)
-                                .node();
-                return;
-            }
-            new_inp.clear();
-            for (auto inp : opr->input()) {
-                new_inp.push_back(old2new.at(inp));
-            }
-            auto new_opr = serialization::copy_opr_shallow(*opr, new_inp);
-            old2new[opr->output(0)] = new_opr->output(0);
-        };
-        cg::DepOprIter{on_opr}.add(gx.node());
-        return old2new.at(gx.node());
+        using namespace std::placeholders;
+        rewriter.iter(std::bind(expand_into_origin_graph, _1,
+                                std::ref(rewriter), std::cref(grad_inputs)));
+        return rewriter.dest_var();
     } else {
-        PlaceholderArray placeholders = fwd_igraph_ptr->placeholders();
-        for (SymbolVar i : {output_ph, og_ph}) {
-            placeholders.push_back(
-                    &i.node()->owner_opr()->cast_final_safe<JITPlaceholder>());
-        }
-        for (size_t i = 0; i < placeholders.size(); ++i) {
-            if (gx.node() == placeholders[i]->output(0)) {
-                return grad_inputs[i];
-            }
-        }
+        VarNodeArray new_grad_inputs;
+        PlaceholderArray placeholders;
+        bool all_inp_const = true;
+        // gx may not depend on all JITPlaceholders, so we need to extract the
+        // used placeholders and build a new internal graph
+        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
+                       &placeholders, &all_inp_const](cg::OperatorNodeBase* opr) {
+            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
+                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
+                auto new_ph = JITPlaceholder::make(
+                                      new_grad_inputs.back(), placeholders.size())
+                                      .node()->owner_opr();
+                placeholders.push_back(new_ph->try_cast_final<JITPlaceholder>());
+                mgb_assert(placeholders.back());
+                rewriter.replace_var(opr->output(0), new_ph->output(0));
+                if (!cg::is_const_var_value(new_grad_inputs.back())) {
+                    all_inp_const = false;
+                }
+                return;
+            }
+            rewriter.auto_replace_outputs(opr);
+        });
+        if (all_inp_const) {
+            // if all inputs are const, expand the grad graph into the origin
+            // graph by replacing placeholders with const inputs, so it can
+            // benefit from static infer and const folding
+            using namespace std::placeholders;
+            rewriter.iter(std::bind(expand_into_origin_graph, _1,
+                                    std::ref(rewriter), std::cref(new_grad_inputs)));
+            return rewriter.dest_var();
+        }
+        gx = rewriter.dest_var();
+
+        auto shape_infer = fwd_igraph_ptr->shape_infer();
+        if (opr.has_dimshuffle()) {
+            auto&& iter = opr.dimshuffle_params().find(
+                    fwd_igraph_ptr->placeholders()[wrt_idx]);
+            if (iter != opr.dimshuffle_params().end()) {
+                auto&& pattern = iter->second.first;
+                auto&& ndim = iter->second.second;
+                std::vector<int> back(ndim, -1);
+                for (size_t i = 0; i < pattern.size(); i ++) {
+                    // outdim[i] is indim[j]
+                    auto j = pattern[i];
+                    if (j >= 0) {
+                        mgb_assert(back[j] == -1,
+                                   "taking grad for Dimshuffle with duplicated "
+                                   "input axis unsupported");
+                        back[j] = i;
+                    }
+                }
+                shape_infer = opr::Dimshuffle::make(shape_infer, back, pattern.size()).node();
+            }
+        }
         auto grad_ig = std::make_shared<InternalGraph>(
-                gx.node(), fwd_igraph_ptr->shape_infer(), nullptr,
+                gx.node(), shape_infer, nullptr,
                 std::move(placeholders));
-        auto grad_jit = JITExecutor::make(grad_ig, grad_inputs);
+        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
         if (opr.input_broadcastable()[wrt_idx]) {
             grad_jit = opr::reduce_sum(
......
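The `back` pattern built in the gradient hunk above is the inverse of the forward Dimshuffle pattern; per point (a) of the commit message it is applied to the shape-infer var so that the grad JITExecutor's inferred output shape follows the wrt input's original axis order rather than the dimshuffled one. Below is a standalone sketch of that inversion; it is not part of the commit, and invert_pattern()/main() are hypothetical names used only for illustration.

// Standalone sketch (not from the commit): inverting a Dimshuffle pattern.
// `pattern` maps output axis i -> input axis pattern[i] (or -1 for a new
// size-1 axis); `back` maps input axis j -> output axis back[j], staying -1
// for input axes the forward pattern dropped.
#include <cassert>
#include <vector>

static std::vector<int> invert_pattern(const std::vector<int>& pattern,
                                       size_t ndim) {
    std::vector<int> back(ndim, -1);
    for (size_t i = 0; i < pattern.size(); ++i) {
        auto j = pattern[i];
        if (j >= 0) {
            // a duplicated input axis has no unique inverse
            assert(back[j] == -1 && "duplicated input axis unsupported");
            back[j] = static_cast<int>(i);
        }
    }
    return back;
}

int main() {
    // forward pattern {1, 2, 3, 0} on a 4-d input: its inverse is {3, 0, 1, 2}
    auto back = invert_pattern({1, 2, 3, 0}, 4);
    assert((back == std::vector<int>{3, 0, 1, 2}));
    return 0;
}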
@@ -26,7 +26,6 @@ JITPlaceholder::JITPlaceholder(VarNode* src_var, size_t id, InpType inp_type)
                    {}),
           m_inp_type{inp_type},
           m_id{id} {
-    add_equivalence_component<ScalarHash<size_t>>(m_id);
     mgb_assert(src_var->dtype().category() == DTypeCategory::FLOAT ||
                        src_var->dtype().category() == DTypeCategory::INT,
                "JIT can only be applied to float/int operators, got %s",
......
@@ -35,6 +35,7 @@ MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // {
     using ModeTrait = megdnn::Elemwise::ModeTrait;

     InternalGraphPtr m_internal_graph;
+    using DimshuffleParam = std::pair<std::vector<int>, uint32_t>;

 public:
     using Mode = opr::Elemwise::Mode;
@@ -112,6 +113,11 @@ public:
         return static_cast<bool>(m_feature_bits & JITFeatureBits::DIMSHUFFLE);
     }

+    const ThinHashMap<jit::JITPlaceholder*, DimshuffleParam>&
+    dimshuffle_params() const {
+        return m_jitph2dimshuffle;
+    }
+
     //! get broadcasted shape of inputs
     megdnn::TensorShape broadcasted_input_shape() const;

@@ -124,8 +130,14 @@ private:
     Compiler* const m_compiler = nullptr;
     Executable* m_executable = nullptr;
     std::vector<bool> m_input_broadcastable;
+    // JITPlaceholder -> pair of (dimshuffle pattern, ndim).
+    // The internal graph is DFS-traversed only once, in prepare_dimshuffle(),
+    // so the dimshuffle param that should be applied to a given JITPlaceholder
+    // can be looked up directly.
+    ThinHashMap<jit::JITPlaceholder*, DimshuffleParam> m_jitph2dimshuffle;

     void update_args();
     void do_dimshuffle();
+    void prepare_dimshuffle();

     NodeProp* do_make_node_prop() const override;
 };
......
@@ -61,8 +61,6 @@ public:
     const PlaceholderArray& placeholders() const { return m_placeholders; }

-    static InternalGraphPtr expand_excutor_op(const InternalGraphPtr&);
-
 private:
     // For compilation cache, if the output_for_cache is same means the
     // expression tree is same.
......
@@ -1435,6 +1435,16 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
         funcs.second->execute();
         MGB_ASSERT_TENSOR_NEAR(host_y1, host_y2, 1e-3);
     }
+    {
+        FusionChecker checker{2,
+                              [](const SymbolVarArray& inp) -> SymbolVar {
+                                  auto var = opr::Dimshuffle::make(inp[0], {1, 2, 3, 0});
+                                  return inp[1] * var;
+                              },
+                              CompNode::load("gpu0")};
+        checker.set_jit_level(1)
+                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}});
+    }
 }

 #endif  // MGB_JIT
......
@@ -98,7 +98,7 @@ void FusionChecker::ensure_init_graph() {
     } else {
         ComputingGraph::Options opt;
         opt.graph_opt_level = 3;
-        opt.graph_opt.jit = 2;
+        opt.graph_opt.jit = m_jit_level;
         unpack_vector(gopt::GraphOptimizer{}
                               .add_preset_passes(true, nullptr, &opt)
                               .apply({{m_truth_y}})
......
@@ -65,6 +65,13 @@ public:
         return *this;
     }

+    //! set jit level, default is 2, see graph_opt.jit in graph options
+    //! for more details
+    FusionChecker& set_jit_level(uint8_t jit_level) {
+        m_jit_level = jit_level;
+        return *this;
+    }
+
     /*!
      * \brief run and check correctness
      *
@@ -76,6 +83,7 @@ private:
     bool m_check_opr_type = true;
     bool m_direct_build = false;
     const size_t m_nr_input;
+    uint8_t m_jit_level = 2;
     const CompNode m_comp_node;
     HostTensorGenerator<> m_input_gen;
     SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
......