Commit 8ef12bdf authored by Megvii Engine Team

feat(mgb/gopt): add user inferface for global layout transform

GitOrigin-RevId: b71d6c60ef322de44cd5b3189420130213db6fd3
Parent a3cd3fc7
......@@ -247,8 +247,20 @@ R"__usage__(
Execute operators with kernels implemented in MegDNN with NCHW64 tensor format. Can only be used
on Nvidia GPUs, which natively support fast int4 tensorcore inference.
)__usage__"
R"__usage__(
--layout-transform [cuda|x86|arm|opencl|unspec]
Enable the global layout transform optimization for the computing graph. The user should specify the target device for the optimization, and a series of passes will be applied to the computing graph. These passes benchmark the elapsed time of operators on different tensor layouts and select the fastest implementation for each operator, so the optimization process takes some time. The default target is unspec, in which case all available layouts of the operators are profiled and the optimization takes longer.
--layout-transform-dump <dump_path>
The computing graph after the global layout transform will be dumped to the given file path.
--layout-transform-verify
After applying the layout transform optimization, the results of the computing graph before and after the layout transform passes are compared to verify the correctness of the passes.
)__usage__"
R"__usage__(
)__usage__"
;
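// A typical invocation combining the new options might look like the
// following (sketch; the binary name and the model paths are illustrative):
//   load_and_run model.mge --layout-transform cuda \
//       --layout-transform-dump model_opt.mge --layout-transform-verify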
struct DataParser {
struct Brace {
std::weak_ptr<Brace> parent;
......@@ -566,6 +578,11 @@ struct Args {
serialization::GraphLoader::LoadConfig load_config;
thin_function<void(size_t)> affinity_cb;
bool layout_transform = false;
gopt::GraphTuningOptions::Target layout_transform_target =
gopt::GraphTuningOptions::Target::UNSPEC;
std::string layout_transform_dump_path;
static Args from_argv(int argc, char **argv);
};
......@@ -712,50 +729,9 @@ void run_test_st(Args &env) {
printf("load model: %.3fms\n", timer.get_msecs_reset());
// compile function to compute all outputs
ComputingGraph::OutputSpec out_spec;
std::string output_names;
if (env.display_model_info) {
format_and_print("Original Model Info", env);
}
OutputDumper output_dumper(env);
for (auto&& i : env.load_ret.output_var_list) {
if (&i != env.load_ret.output_var_list.data()) {
output_names += " ";
}
output_names.append(i.node()->name() + i.shape().to_string());
ComputingGraph::Callback cb;
if (!env.bin_out_dump.empty()) {
cb = output_dumper.bind();
} else if (env.copy_to_host) {
HostTensorND val;
cb = [val](const DeviceTensorND& dv) mutable {
val.copy_from(dv);
};
}
out_spec.emplace_back(i, std::move(cb));
}
if (env.disable_assert_throw) {
auto on_opr = [](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::AssertEqual>()) {
opr->cast_final<opr::AssertEqual>().disable_throw_on_error();
}
};
cg::DepOprIter iter{on_opr};
for (auto&& i : out_spec) {
iter.add(i.first.node()->owner_opr());
}
}
SymbolVarArray vars;
for (auto i : out_spec) {
vars.push_back(i.first);
}
mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit);
auto& output_var_list = env.load_ret.output_var_list;
mgb::gopt::set_opr_algo_workspace_limit_inplace(output_var_list,
env.workspace_limit);
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = static_cast<S>(0);
if (env.reproducible) {
......@@ -772,7 +748,7 @@ void run_test_st(Args &env) {
#else
strategy = S::HEURISTIC | strategy;
#endif
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
mgb::gopt::modify_opr_algo_strategy_inplace(output_var_list, strategy);
if (!env.fast_run_cache_path.empty()) {
#if MGB_ENABLE_FASTRUN
if (!access(env.fast_run_cache_path.c_str(), F_OK)) {
......@@ -799,7 +775,86 @@ void run_test_st(Args &env) {
}
if (!env.use_full_run && !env.use_fast_run)
#endif
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
mgb::gopt::enable_opr_use_profiling_cache_inplace(output_var_list);
}
// load testcase
decltype(env.load_ret) testcase;
if (nr_test) {
loader = serialization::GraphLoader::make(loader->reset_file(),
loader->format());
testcase = loader->load(env.load_config, false);
}
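// apply the global layout transform to the loaded graph and, when a dump path
// is given, serialize the optimized graph (re-appending the loaded test cases,
// if any) to that path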
if (env.layout_transform) {
env.load_ret.output_var_list = gopt::layout_transform(
env.load_ret.output_var_list, env.layout_transform_target);
if (!env.layout_transform_dump_path.empty()) {
auto out_file = serialization::OutputFile::make_fs(
env.layout_transform_dump_path.c_str(), 'w');
if (nr_test) {
const char* magic = "mgbtest0";
constexpr size_t len = sizeof(magic);
out_file->write(magic, len);
uint32_t nr_inp_tensors = testcase.output_var_list.size();
out_file->write(&nr_inp_tensors, sizeof(nr_inp_tensors));
}
auto dumper = serialization::GraphDumper::make(std::move(out_file),
format.val());
using DumpConfig = serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
dumper->dump(env.load_ret.output_var_list, config);
if (nr_test) {
out_file = serialization::OutputFile::make_fs(
env.layout_transform_dump_path.c_str(), 'a');
auto testdumper = serialization::GraphDumper::make(
std::move(out_file), format.val());
testdumper->dump(testcase.output_var_list, config);
}
}
}
// compile function to compute all outputs
ComputingGraph::OutputSpec out_spec;
std::string output_names;
if (env.display_model_info) {
format_and_print("Original Model Info", env);
}
OutputDumper output_dumper(env);
for (auto&& i : env.load_ret.output_var_list) {
if (&i != env.load_ret.output_var_list.data()) {
output_names += " ";
}
output_names.append(i.node()->name() + i.shape().to_string());
ComputingGraph::Callback cb;
if (!env.bin_out_dump.empty()) {
cb = output_dumper.bind();
} else if (env.copy_to_host) {
HostTensorND val;
cb = [val](const DeviceTensorND& dv) mutable {
val.copy_from(dv);
};
}
out_spec.emplace_back(i, std::move(cb));
}
if (env.disable_assert_throw) {
auto on_opr = [](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::AssertEqual>()) {
opr->cast_final<opr::AssertEqual>().disable_throw_on_error();
}
};
cg::DepOprIter iter{on_opr};
for (auto&& i : out_spec) {
iter.add(i.first.node()->owner_opr());
}
}
SymbolVarArray vars;
for (auto i : out_spec) {
vars.push_back(i.first);
}
auto func = env.load_ret.graph_compile(out_spec);
......@@ -924,9 +979,6 @@ void run_test_st(Args &env) {
env.c_opr_args.copr_param_device_ptr_malloc(c_opr_param.get());
}
loader = serialization::GraphLoader::make(
loader->reset_file(), loader->format());
auto testcase = loader->load(env.load_config, false);
mgb_assert(testcase.output_var_list.size() == inp_tensors.size());
for (size_t i = 0; i < inp_tensors.size(); ++ i) {
auto &&opr = testcase.output_var_list[i].node()->owner_opr()->
......@@ -1522,7 +1574,45 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.enable_weight_preprocess();
continue;
}
if (!strcmp(argv[i], "--layout-transform")) {
ret.layout_transform = true;
++i;
if (i >= argc) {
--i;
continue;
}
using Target = gopt::GraphTuningOptions::Target;
if (!strcmp(argv[i], "cuda")) {
ret.layout_transform_target = Target::CUDA;
} else if (!strcmp(argv[i], "x86")) {
ret.layout_transform_target = Target::X86;
} else if (!strcmp(argv[i], "arm")) {
ret.layout_transform_target = Target::ARM;
} else if (!strcmp(argv[i], "opencl")) {
ret.layout_transform_target = Target::OPENCL;
} else if (!strncmp(argv[i], "--", 2)) {
--i;
} else {
mgb_assert(false,
"unsupported target(got:%s) for global layout "
"transform",
argv[i]);
}
continue;
}
if (!strcmp(argv[i], "--layout-transform-dump")) {
++i;
mgb_assert(i < argc,
"dump path not given for --layout-transform-dump");
mgb_assert(strncmp(argv[i], "--", 2),
"dump path not given for --layout-transform-dump");
ret.layout_transform_dump_path = argv[i];
continue;
}
fprintf(stderr, "invalid arg: %s\n", argv[i]);
ret.args_parse_ret = -1;
return ret;
......
......@@ -30,6 +30,8 @@
#include "megbrain/tensorrt/opr_replace.h"
#endif
#include "megbrain/gopt/global_layout_transform.h"
using namespace mgb;
using namespace gopt;
......@@ -812,6 +814,40 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
return *this;
}
const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
const GraphTuningOptions& options) {
bool need_param_fuse = false;
#define cb(_options, _passes) \
if (options.has_set_##_options()) { \
_passes need_param_fuse = true; \
}
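// cb(<option>, {<passes>;}) registers the given passes only when the
// corresponding graph tuning option has been set, and records that a
// parameter fuse/merge step is required afterwards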
cb(layout_transform, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
auto profiler = ProfilerBase::make_profiler();
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto ctx = LayoutTransformContext::make(options.target);
add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver));
add_pass<ShuffleShuffleRemovePass>();
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
add_pass<FoldingConvBiasDimshufflePass>();
#endif
});
#undef cb
if (need_param_fuse) {
add_pass<ParamFusePass>();
add_pass<ParamMergePass>();
}
return *this;
}
/* ================ ConstVarPropogateBase ================ */
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
......
......@@ -118,6 +118,17 @@ SymbolVarArray gopt::optimize_for_inference(
.endpoint_vars();
}
SymbolVarArray gopt::layout_transform(const SymbolVarArray& dest_vars,
GraphTuningOptions::Target target) {
GraphTuningOptions options;
options.target = target;
options.enable_layout_transform();
return gopt::GraphOptimizer{}
.add_passes_for_graph_tuning_options(options)
.apply({dest_vars})
.endpoint_vars();
}
namespace {
void modify_conv_strategy(
opr::mixin::AlgoChooserHelper& conv,
......
......@@ -12,10 +12,74 @@
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
using namespace mgb;
using namespace gopt;
namespace {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
const char* target_to_string(Target target) {
#define cb(_target) \
case Target::_target: \
return #_target
switch (target) {
cb(CUDA);
cb(X86);
cb(ARM);
cb(UNSPEC);
default:
mgb_assert(false, "unsupported target (got:%u)",
static_cast<uint32_t>(target));
}
#undef cb
}
std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
OprFormat base_opr_format, TensorFormats base_tensor_format) {
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ConvolutionBackwardData::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
.add_opr_config(opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
OprFormat::NCHW64, OprFormat::CHWN4})
.add_opr_config(
opr::WarpPerspectiveForward::typeinfo(),
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
return ctx;
}
} // namespace
/* ================= LayoutTransformContext ==================*/
LayoutTransformContext& LayoutTransformContext::add_opr_config(
Typeinfo* opr, OprFormat opr_format) {
......@@ -37,4 +101,16 @@ LayoutTransformContext& LayoutTransformContext::add_opr_config(
return *this;
}
std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
Target target, OprFormat base_opr_format,
TensorFormats base_tensor_format) {
switch (target) {
case Target::CUDA:
return make_cuda_ctx(base_opr_format, base_tensor_format);
default:
mgb_assert(false, "unsupported target %s\n",
target_to_string(target));
}
}
// vim: syntax=cpp.doxygen
......@@ -46,7 +46,8 @@ void LayoutTransformPass::apply(OptState& opt) const {
auto&& opr_configs = m_ctx->opr_configs();
auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
auto&& reformat_attribute =
ReformatManager::ReformatKey::Attribute::DEFAULT;
ThinHashMap<VarNode*, TensorFormats> var2fmts;
static ThinHashSet<Typeinfo*> format_aware_oprs = {
#define cb(_Opr) opr::_Opr::typeinfo(),
......
......@@ -404,6 +404,11 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
NamedTensorShape orig_shape =
tensor_formats_to_named_tensor_shape(orig_format);
size_t orig_channel = 0;
mgb_assert(orig_var->shape().ndim == orig_shape.ndim,
"incompatible NamedTensorShape for "
"feature(var:%s;shape:%s)",
cg::dump_var_info({const_cast<VarNode*>(orig_var)}).c_str(),
orig_shape.to_string().c_str());
for (size_t i = 0; i < orig_shape.ndim; ++i) {
if (orig_shape[i].name() == Dimension::Name::C &&
orig_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
......@@ -412,7 +417,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
}
}
mgb_assert(orig_channel > 0,
"incompatible NamedTensorShape for feature(got:%s)",
"incompatible NamedTensorShape for "
"feature(var:%s;shape:%s)",
cg::dump_var_info({const_cast<VarNode*>(orig_var)}).c_str(),
orig_shape.to_string().c_str());
size_t aligned_in_channel =
divup(orig_channel, input_alignment) * input_alignment;
......
......@@ -24,6 +24,9 @@ namespace gopt {
//! forward declaration for structs in inference.h
struct OptimizeForInferenceOptions;
//! forward declaration for GraphTuningOptions
struct GraphTuningOptions;
/*!
* \brief represent a computing graph to be optimized by specifying its
* endpoints
......@@ -479,6 +482,14 @@ namespace gopt {
const GraphOptimizer& add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options);
/**
* \brief add passes indicated by the graph tuning options
*
* \param options graph tuning options
*/
const GraphOptimizer& add_passes_for_graph_tuning_options(
const GraphTuningOptions& options);
};
/*!
......
......@@ -16,6 +16,7 @@
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/gopt/inference.h"
namespace mgb {
namespace gopt {
......@@ -52,6 +53,7 @@ public:
using OprConfigTrait =
ThinHashMap<Typeinfo*,
ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>;
using Target = GraphTuningOptions::Target;
using ReformatAttribute = ReformatManager::ReformatKey::Attribute;
struct Attribute {
OprFormat base_opr_format; /// the base opr format indicates that the
......@@ -66,11 +68,13 @@ public:
/// (like elemwise, elemwise multi type,
/// typecvt etc.) are built in the base
/// tensor format.
ReformatAttribute
reformat_attribute; /// additional reformat attribute, which
/// indicates whether to pad nhwc layout
/// automatically or to enable nhwcd4 format
/// on opencl platform to use image object
Target target; /// target which indicates the device type
ReformatAttribute reformat_attribute =
ReformatAttribute::DEFAULT; /// additional reformat attribute,
/// which indicates whether to pad
/// nhwc layout automatically or to
/// enable nhwcd4 format on opencl
/// platform to use image object
};
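// Example (sketch): with the new `target` member, aggregate initialization
// now supplies the device target before the optional reformat attribute:
//     Attribute attr{OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA,
//                    ReformatAttribute::AUTO_PADDING_NHWC};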
LayoutTransformContext() = delete;
LayoutTransformContext(OprList opr_list,
......@@ -108,6 +112,10 @@ public:
*/
LayoutTransformContext& add_opr_config(Typeinfo* opr,
SmallVector<OprFormat> opr_formats);
static std::unique_ptr<LayoutTransformContext> make(
Target target = Target::UNSPEC,
OprFormat base_opr_format = OprFormat::NCHW,
TensorFormats base_tensor_format = TensorFormats::NCHW);
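// Example (sketch): a predefined context for the CUDA backend with the
// default base formats can be obtained via
//     auto ctx = LayoutTransformContext::make(Target::CUDA);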
private:
OprList m_opr_list; /// supported operator list
......
......@@ -353,6 +353,39 @@ namespace gopt {
}
};
/**
* \brief graph-level tuning options.
* GraphTuningOptions corresponds to graph-level optimizations. Unlike
* GraphCommonOptimizeOptions, these optimization options are usually
* target-dependent and profiling-based, and the optimization usually takes
* place at runtime. GraphTuningOptions currently covers the layout
* optimization; more optimization options will be introduced in the future.
*/
struct GraphTuningOptions {
enum class Target : uint32_t {
UNSPEC = 0, ///< unspecified device target
CUDA = 1,   ///< CUDA device, usually refers to Nvidia GPUs
X86 = 2,    ///< x86 CPU
ARM = 3,    ///< ARM CPU
OPENCL = 4, ///< OpenCL, usually runs on mobile devices
};
Target target;
bool layout_transform = false; ///< whether to enable graph level
///< tuning for layouts of tensors
#define SET(n) \
GraphTuningOptions& enable_##n() { \
n = true; \
return *this; \
} \
GraphTuningOptions& disable_##n() { \
n = false; \
return *this; \
} \
bool has_set_##n() const { return n == true; }
SET(layout_transform);
#undef SET
};
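// The SET macro above generates a fluent enable_/disable_ setter pair and a
// has_set_ query for each option; a minimal usage sketch (mirroring what
// gopt::layout_transform does internally) is:
//     GraphTuningOptions options;
//     options.target = GraphTuningOptions::Target::CUDA;
//     options.enable_layout_transform();
//     if (options.has_set_layout_transform()) { /* add tuning passes */ }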
/*!
* \brief optimize a computing graph for inference
*
......@@ -363,6 +396,16 @@ namespace gopt {
const SymbolVarArray& dest_vars,
const OptimizeForInferenceOptions& opt = {});
/*!
* \brief optimize the layout selection for a computing graph
*
* The layout selection optimizers are target-dependent. This function
* applies a set of predefined optimizer passes designed for the specified
* device.
*/
SymbolVarArray layout_transform(const SymbolVarArray& dest_vars,
GraphTuningOptions::Target target =
GraphTuningOptions::Target::UNSPEC);
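// Example (sketch; `output_vars` stands for the endpoint vars of a loaded
// model): run the layout selection optimization for CUDA and continue with
// the returned endpoints:
//     SymbolVarArray optimized = gopt::layout_transform(
//             output_vars, GraphTuningOptions::Target::CUDA);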
/*!
* \brief modify execution strategy for oprs with multiple
* algorithms
......
......@@ -45,7 +45,6 @@ size_t find_opr_num(SymbolVar endpoint) {
size_t opr_num = 0;
auto cb = [&opr_num](cg::OperatorNodeBase* opr) {
if (opr->same_type<T>()) {
printf("%s, %s\n", opr->cname(), opr->dyn_typeinfo()->name);
opr_num++;
}
};
......@@ -78,6 +77,7 @@ TEST(TestLayoutTransform, Resnet18_QS8) {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
OprList opr_list = {
......@@ -91,7 +91,7 @@ TEST(TestLayoutTransform, Resnet18_QS8) {
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
......@@ -167,8 +167,9 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
......@@ -181,7 +182,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
......@@ -288,8 +289,9 @@ TEST(TestLayoutTransform, Detection_QS8) {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
......@@ -302,7 +304,7 @@ TEST(TestLayoutTransform, Detection_QS8) {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
......@@ -362,6 +364,7 @@ TEST(TestLayoutTransform, Detection_QS4) {
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
......@@ -374,7 +377,7 @@ TEST(TestLayoutTransform, Detection_QS4) {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
......@@ -443,13 +446,14 @@ TEST(TestLayoutTransform, Wide) {
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::Elemwise::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {TensorFormats::NCHW,
TensorFormats::NHWC};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::DEFAULT};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
......@@ -571,8 +575,8 @@ TEST(TestLayoutTransform, DetectionHead) {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
......@@ -588,7 +592,7 @@ TEST(TestLayoutTransform, DetectionHead) {
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
ReformatAttribute::DEFAULT};
Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
......
......@@ -28,8 +28,8 @@ namespace {
std::unique_ptr<LayoutTransformContext> make_ctx() {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
......@@ -45,8 +45,7 @@ std::unique_ptr<LayoutTransformContext> make_ctx() {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
ReformatAttribute::DEFAULT};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
......