提交 d2278f02 编写于 作者: M Megvii Engine Team

perf(imperative): speed up conv_transpose3d

GitOrigin-RevId: e741305446e926086c36affcb54d77f739133bbe
上级 3a5347ed
...@@ -784,6 +784,10 @@ public: ...@@ -784,6 +784,10 @@ public:
protected: protected:
void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
public:
MGE_WIN_DECLSPEC_FUC static void deduce_layout_impl(
const TensorLayout& src, const Param& param, TensorLayout& dst);
}; };
class PoolingForward : public PoolingBase, class PoolingForward : public PoolingBase,
...@@ -1241,6 +1245,8 @@ protected: ...@@ -1241,6 +1245,8 @@ protected:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) const; const TensorLayout& dst) const;
static CanonizedFilterMeta make_canonized_filter_meta_impl(
size_t src_ndim, const TensorLayout& filter, const Param& param);
CanonizedFilterMeta make_canonized_filter_meta( CanonizedFilterMeta make_canonized_filter_meta(
size_t src_ndim, const TensorLayout& filter) const; size_t src_ndim, const TensorLayout& filter) const;
}; };
...@@ -1286,6 +1292,10 @@ public: ...@@ -1286,6 +1292,10 @@ public:
* \param[in] diff (n, oc, od, oh, ow) * \param[in] diff (n, oc, od, oh, ow)
* \param[out] grad (n, ic, id, ih, iw) * \param[out] grad (n, ic, id, ih, iw)
*/ */
MGE_WIN_DECLSPEC_FUC static void deduce_layout_impl(
const TensorLayout& filter, const TensorLayout& diff, const Param& param,
TensorLayout& grad);
virtual void exec( virtual void exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0; _megdnn_workspace workspace) = 0;
......
...@@ -38,17 +38,18 @@ std::string get_errmsg( ...@@ -38,17 +38,18 @@ std::string get_errmsg(
} }
} // namespace } // namespace
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_meta( Convolution3DBase::CanonizedFilterMeta Convolution3DBase::
size_t src_ndim, const TensorLayout& filter) const { make_canonized_filter_meta_impl(
size_t src_ndim, const TensorLayout& filter, const Param& param) {
megdnn_assert_contiguous(filter); megdnn_assert_contiguous(filter);
auto img_ndim = src_ndim - 2; auto img_ndim = src_ndim - 2;
CanonizedFilterMeta ret; CanonizedFilterMeta ret;
ret.dtype_enum = filter.dtype.enumv(); ret.dtype_enum = filter.dtype.enumv();
ret.format = param().format; ret.format = param.format;
if (param().mode == Mode::CONVOLUTION) { if (param.mode == Mode::CONVOLUTION) {
ret.should_flip = true; ret.should_flip = true;
} else { } else {
megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode"); megdnn_assert(param.mode == Mode::CROSS_CORRELATION, "invalid conv mode");
ret.should_flip = false; ret.should_flip = false;
} }
size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos; size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
...@@ -56,7 +57,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_ ...@@ -56,7 +57,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
MEGDNN_MARK_USED_VAR(ocpg_pos); MEGDNN_MARK_USED_VAR(ocpg_pos);
MEGDNN_MARK_USED_VAR(icpg_pos); MEGDNN_MARK_USED_VAR(icpg_pos);
if (param().sparse == Param::Sparse::DENSE) { if (param.sparse == Param::Sparse::DENSE) {
megdnn_assert( megdnn_assert(
filter.ndim == img_ndim + 2, filter.ndim == img_ndim + 2,
"bad filter ndim for dense convolution: " "bad filter ndim for dense convolution: "
...@@ -66,7 +67,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_ ...@@ -66,7 +67,7 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start = 0; flt_start = 0;
} else { } else {
megdnn_assert( megdnn_assert(
param().sparse == Param::Sparse::GROUP, param.sparse == Param::Sparse::GROUP,
"invalid convolution sparse type"); "invalid convolution sparse type");
megdnn_assert( megdnn_assert(
filter.ndim == img_ndim + 3, filter.ndim == img_ndim + 3,
...@@ -77,14 +78,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_ ...@@ -77,14 +78,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
flt_start = 1; flt_start = 1;
} }
if (param().format == Param::Format::NCDHW) { if (param.format == Param::Format::NCDHW) {
// filter should be (oc, ic, fd, fh, fw) // filter should be (oc, ic, fd, fh, fw)
flt_spatial_start = 2; flt_spatial_start = 2;
ocpg_pos = 0; ocpg_pos = 0;
icpg_pos = 1; icpg_pos = 1;
} else { } else {
megdnn_assert( megdnn_assert(
param().format == Param::Format::NDHWC, "invalid conv tensor format"); param.format == Param::Format::NDHWC, "invalid conv tensor format");
// filter should be (oc, fd, fh, fw, ic) // filter should be (oc, fd, fh, fw, ic)
flt_spatial_start = 1; flt_spatial_start = 1;
ocpg_pos = 0; ocpg_pos = 0;
...@@ -96,15 +97,15 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_ ...@@ -96,15 +97,15 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
"only 3D convolution is supported, and input should be 5-dim; " "only 3D convolution is supported, and input should be 5-dim; "
"got input dim = %zu", "got input dim = %zu",
src_ndim); src_ndim);
ret.stride[0] = this->param().stride_d; ret.stride[0] = param.stride_d;
ret.stride[1] = this->param().stride_h; ret.stride[1] = param.stride_h;
ret.stride[2] = this->param().stride_w; ret.stride[2] = param.stride_w;
ret.padding[0] = this->param().pad_d; ret.padding[0] = param.pad_d;
ret.padding[1] = this->param().pad_h; ret.padding[1] = param.pad_h;
ret.padding[2] = this->param().pad_w; ret.padding[2] = param.pad_w;
ret.dilation[0] = param().dilate_d; ret.dilation[0] = param.dilate_d;
ret.dilation[1] = param().dilate_h; ret.dilation[1] = param.dilate_h;
ret.dilation[2] = param().dilate_w; ret.dilation[2] = param.dilate_w;
ret.ocpg = filter[flt_start + ocpg_pos]; ret.ocpg = filter[flt_start + ocpg_pos];
ret.icpg = filter[flt_start + icpg_pos]; ret.icpg = filter[flt_start + icpg_pos];
for (size_t i = 0; i < ret.spatial_ndim; ++i) { for (size_t i = 0; i < ret.spatial_ndim; ++i) {
...@@ -117,6 +118,11 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_ ...@@ -117,6 +118,11 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_
return ret; return ret;
} }
// Member-function front end: canonizes `filter` with this operator's own
// param() by delegating to the static make_canonized_filter_meta_impl().
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_meta(
        size_t src_ndim, const TensorLayout& filter) const {
    const auto& own_param = param();
    auto meta = make_canonized_filter_meta_impl(src_ndim, filter, own_param);
    return meta;
}
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd( Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) const { const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) const {
auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); }; auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
...@@ -213,12 +219,13 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec( ...@@ -213,12 +219,13 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec(
return ret; return ret;
} }
void Convolution3DBackwardData::deduce_layout( void Convolution3DBackwardData::deduce_layout_impl(
const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) { const TensorLayout& filter, const TensorLayout& diff, const Param& param,
TensorLayout& grad) {
megdnn_assert( megdnn_assert(
param().data_type == Param::DataType::FLOAT, param.data_type == Param::DataType::FLOAT,
"only float type is supported for conv backward"); "only float type is supported for conv backward");
auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); }; auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param); };
MEGDNN_MARK_USED_VAR(errmsg); MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert_contiguous(filter); megdnn_assert_contiguous(filter);
megdnn_assert_contiguous(diff); megdnn_assert_contiguous(diff);
...@@ -226,7 +233,7 @@ void Convolution3DBackwardData::deduce_layout( ...@@ -226,7 +233,7 @@ void Convolution3DBackwardData::deduce_layout(
megdnn_assert(diff.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(diff.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(filter.dtype == diff.dtype, "%s", errmsg().c_str()); megdnn_assert(filter.dtype == diff.dtype, "%s", errmsg().c_str());
auto cflt = make_canonized_filter_meta(diff.ndim, filter); auto cflt = make_canonized_filter_meta_impl(diff.ndim, filter, param);
megdnn_assert(cflt.ocpg * cflt.group == diff[1], "%s", errmsg().c_str()); megdnn_assert(cflt.ocpg * cflt.group == diff[1], "%s", errmsg().c_str());
auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) { auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
...@@ -247,6 +254,11 @@ void Convolution3DBackwardData::deduce_layout( ...@@ -247,6 +254,11 @@ void Convolution3DBackwardData::deduce_layout(
grad.init_contiguous_stride(); grad.init_contiguous_stride();
} }
// Instance-level layout deduction: hands this operator's param() to the
// static deduce_layout_impl(), which writes the deduced layout into `grad`.
void Convolution3DBackwardData::deduce_layout(
        const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
    const auto& own_param = param();
    deduce_layout_impl(filter, diff, own_param, grad);
}
Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardFilter::check_exec( Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardFilter::check_exec(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes) { size_t workspace_in_bytes) {
......
...@@ -15,22 +15,22 @@ ...@@ -15,22 +15,22 @@
namespace megdnn { namespace megdnn {
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { void PoolingBase::deduce_layout_impl(
auto& p = param(); const TensorLayout& src, const Param& param, TensorLayout& dst) {
auto pformat = p.format; auto pformat = param.format;
// the overhead of generating error message is about 18x of the other part of this // the overhead of generating error message is about 18x of the other part of this
// function so we use a function to wrap the error message and get it only when needed. // function so we use a function to wrap the error message and get it only when needed.
auto get_errmsg = [&](void) -> std::string { auto get_errmsg = [&](void) -> std::string {
std::string errmsg = std::string errmsg =
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
"pad_h=" + std::to_string(param().pad_h) + ", " + "pad_h=" + std::to_string(param.pad_h) + ", " +
"pad_w=" + std::to_string(param().pad_w) + ", " + "pad_w=" + std::to_string(param.pad_w) + ", " +
"stride_h=" + std::to_string(param().stride_h) + ", " + "stride_h=" + std::to_string(param.stride_h) + ", " +
"stride_w=" + std::to_string(param().stride_w) + ", " + "stride_w=" + std::to_string(param.stride_w) + ", " +
"window_h=" + std::to_string(param().window_h) + ", " + "window_h=" + std::to_string(param.window_h) + ", " +
"window_w=" + std::to_string(param().window_w) + ", " + "window_w=" + std::to_string(param.window_w) + ", " +
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + "is_max=" + std::to_string(param.mode == Mode::MAX) + ", " +
"is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " + "is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " +
"is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4); "is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4);
return errmsg; return errmsg;
...@@ -90,12 +90,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) ...@@ -90,12 +90,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
c *= 64; c *= 64;
} }
size_t oh, ow; size_t oh, ow;
size_t fh = p.window_h; size_t fh = param.window_h;
size_t fw = p.window_w; size_t fw = param.window_w;
size_t sh = p.stride_h; size_t sh = param.stride_h;
size_t sw = p.stride_w; size_t sw = param.stride_w;
size_t ph = p.pad_h; size_t ph = param.pad_h;
size_t pw = p.pad_w; size_t pw = param.pad_w;
// moving some python assert to here // moving some python assert to here
// megdnn_assert() // megdnn_assert()
...@@ -128,12 +128,15 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) ...@@ -128,12 +128,15 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
} }
} }
// Instance-level forward layout deduction: passes this operator's param() to
// the static deduce_layout_impl() so the result lands in `dst`.
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
    const auto& own_param = param();
    deduce_layout_impl(src, own_param, dst);
}
void PoolingBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { void PoolingBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
TensorLayout dst_expected; TensorLayout dst_expected;
megdnn_assert_eq_dtype(src, dst); megdnn_assert_eq_dtype(src, dst);
deduce_layout_fwd(src, dst_expected); deduce_layout_fwd(src, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst); megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert(src.dtype == dst.dtype);
megdnn_assert( megdnn_assert(
src.dtype.category() == DTypeCategory::FLOAT || src.dtype.category() == DTypeCategory::FLOAT ||
src.dtype == dtype::Int8() || src.dtype == dtype::Int8() ||
......
...@@ -93,12 +93,17 @@ __all__ = [ ...@@ -93,12 +93,17 @@ __all__ = [
def expand_hw(x): def expand_hw(x):
# judge int is 5 times faster than judge Sequence
if isinstance(x, int):
return x, x
if isinstance(x, Sequence): if isinstance(x, Sequence):
return int(x[0]), int(x[1]) return int(x[0]), int(x[1])
return int(x), int(x) return int(x), int(x)
def expand_dhw(x): def expand_dhw(x):
if isinstance(x, int):
return x, x, x
if isinstance(x, Sequence): if isinstance(x, Sequence):
return int(x[0]), int(x[1]), int(x[2]) return int(x[0]), int(x[1]), int(x[2])
return int(x), int(x), int(x) return int(x), int(x), int(x)
......
#pragma once
#include "megbrain/rdnn/algo_chooser.h" #include "megbrain/rdnn/algo_chooser.h"
#include "megdnn/heuristic_cache.h" #include "megdnn/heuristic_cache.h"
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
......
...@@ -579,6 +579,63 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D) ...@@ -579,6 +579,63 @@ OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
namespace { namespace {
namespace convolution3d_backward_data { namespace convolution3d_backward_data {
/*!
 * \brief Infer the output descriptor of conv_transpose3d without executing it.
 *
 * \param[in] def op definition; must be a Convolution3DBackwardData
 * \param[in] inputs exactly two descriptors: inputs[0] is the filter (weight)
 *      and inputs[1] is the diff
 * \return the single output descriptor plus a flag telling whether its layout
 *      was actually deduced. When either input layout is still unknown
 *      (ndim == 0) a shapeless layout carrying the weight dtype is returned
 *      together with false.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2,
            "inputs num of conv_transpose3d should be 2 but you give %zu",
            inputs.size());

    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto&& weight = inputs[0];
    auto&& diff = inputs[1];
    auto& cn = weight.comp_node;

    // deduce_layout_impl() asserts that diff is 5-dimensional, so an input whose
    // shape is not known yet must be reported as "not validated" rather than
    // being passed on to the deduction, for the diff as well as for the weight.
    if (weight.layout.ndim == 0 || diff.layout.ndim == 0) {
        return {{{TensorLayout{weight.layout.dtype}, cn, {}}}, false};
    }

    TensorLayout oup_layout;
    megdnn::Convolution3DBackwardData::deduce_layout_impl(
            weight.layout, diff.layout, op_def.param(), oup_layout);
    return {{{oup_layout, cn, {}}}, true};
}
/*!
 * \brief Run conv_transpose3d directly on physical tensors.
 *
 * \param[in] def op definition; must be a Convolution3DBackwardData
 * \param[in] inputs inputs[0] is the filter (weight), inputs[1] is the diff
 * \param[in,out] output_descs output_descs[0].layout is used as the output
 *      layout; it is (re)deduced here when \p validated is false
 * \param[in] validated whether output_descs already holds a deduced layout
 * \return a one-element vector holding the freshly allocated output tensor
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto cn = inputs[0]->comp_node();

    megdnn::TensorND weight = inputs[0]->dnn_tensor();
    megdnn::TensorND diff = inputs[1]->dnn_tensor();

    DnnOprCaller<megdnn::Convolution3DBackwardData> caller(cn);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = op_def.param();

    // Only deduce the layout when shape inference has not already done so.
    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        megdnn::Convolution3DBackwardData::deduce_layout_impl(
                weight.layout, diff.layout, op_def.param(), oup_layout);
    }
    DeviceTensorND oup =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);

    // The value returned by setup_algo is used as the workspace size in bytes.
    size_t wk_size = setup_algo<megdnn::Convolution3DBackwardData>(
            {weight.layout, diff.layout, oup_layout}, dnn_opr.get(), 0, false, false,
            cn, op_def.policy(), false);

    // Hold the workspace blob in function scope: dnn_wk only stores a raw
    // pointer into its storage, so the blob has to stay alive until exec() has
    // been issued instead of being destroyed at the end of an inner block.
    auto wk_blob = wk_size != 0 ? Blob::make(cn, wk_size) : nullptr;
    megdnn::Workspace dnn_wk;
    if (wk_blob) {
        dnn_wk.raw_ptr = wk_blob->storage().get();
        dnn_wk.size = wk_size;
    }

    dnn_opr->exec(weight, diff, oup.as_megdnn(), dnn_wk);
    return {Tensor::make(oup)};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution3DBackwardData&>(def); auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
OperatorNodeConfig config{conv.make_name()}; OperatorNodeConfig config{conv.make_name()};
...@@ -589,6 +646,8 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -589,6 +646,8 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData) OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback(); .fallback();
} // namespace convolution3d_backward_data } // namespace convolution3d_backward_data
} // namespace } // namespace
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/utility.h" #include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
...@@ -25,9 +26,6 @@ namespace mgb::imperative { ...@@ -25,9 +26,6 @@ namespace mgb::imperative {
namespace { namespace {
namespace pooling { namespace pooling {
// using OprHandle = opr::intl::UniqPtrWithCN<megdnn::Pooling>;
// static ThinHashMap<CompNode, OprHandle> dnn_oprs;
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def); auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()}; OperatorNodeConfig config{pool.make_name()};
...@@ -48,11 +46,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -48,11 +46,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false}; return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
} }
DnnOprCaller<megdnn::Pooling> caller(inp_cn);
auto&& dnn_opr = caller.op;
dnn_opr->param() = op_def.param();
TensorLayout oup_layout; TensorLayout oup_layout;
dnn_opr->deduce_layout(inp.layout, oup_layout); megdnn::Pooling::deduce_layout_impl(inp.layout, op_def.param(), oup_layout);
return {{{oup_layout, inp_cn, {}}}, true}; return {{{oup_layout, inp_cn, {}}}, true};
} }
...@@ -73,7 +69,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -73,7 +69,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout& oup_layout = output_descs[0].layout; TensorLayout& oup_layout = output_descs[0].layout;
if (!validated) { if (!validated) {
dnn_opr->deduce_layout(inp_tensornd.layout, oup_layout); megdnn::Pooling::deduce_layout_impl(
inp_tensornd.layout, op_def.param(), oup_layout);
} }
DeviceTensorND out_devtensor = DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout); BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册