Unverified commit 3a98dabd, authored by zhangshijin, committed by GitHub

Merge pull request #41 from Cambricon/develop-set-input-layout

Develop set input layout
@@ -36,16 +36,12 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
 #endif
 #ifdef LITE_WITH_MLU
   Env<TARGET(kMLU)>::Init();
-  mlu_core_version_ = config.mlu_core_version();
-  mlu_core_number_ = config.mlu_core_number();
-  use_first_conv_ = config.use_first_conv();
-  mean_vec_ = config.mean();
-  std_vec_ = config.std();
-  lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_,
-                                           mlu_core_number_,
-                                           use_first_conv_,
-                                           mean_vec_,
-                                           std_vec_);
+  lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
+                                           config.mlu_core_number(),
+                                           config.mlu_use_first_conv(),
+                                           config.mlu_first_conv_mean(),
+                                           config.mlu_first_conv_std(),
+                                           config.mlu_input_layout());
 #endif  // LITE_WITH_MLU
   auto places = config.valid_places();
   std::vector<std::string> passes{};
...
@@ -203,6 +203,37 @@ void ConfigBase::set_threads(int threads) {
 #endif
 }
+void CxxConfig::set_mlu_core_version(lite_api::MLUCoreVersion core_version) {
+  mlu_core_version_ = core_version;
+}
+void CxxConfig::set_mlu_core_number(int core_number) {
+  mlu_core_number_ = core_number;
+}
+void CxxConfig::set_mlu_input_layout(DataLayoutType layout) {
+  mlu_input_layout_ = layout;
+}
+void CxxConfig::set_mlu_use_first_conv(bool use_first_conv) {
+  mlu_use_first_conv_ = use_first_conv;
+}
+void CxxConfig::set_mlu_first_conv_mean(const std::vector<float> &mean) {
+  mlu_first_conv_mean_ = mean;
+}
+void CxxConfig::set_mlu_first_conv_std(const std::vector<float> &std) {
+  mlu_first_conv_std_ = std;
+}
+lite_api::MLUCoreVersion CxxConfig::mlu_core_version() const {
+  return mlu_core_version_;
+}
+int CxxConfig::mlu_core_number() const { return mlu_core_number_; }
+DataLayoutType CxxConfig::mlu_input_layout() const { return mlu_input_layout_; }
+bool CxxConfig::mlu_use_first_conv() const { return mlu_use_first_conv_; }
+std::vector<float> CxxConfig::mlu_first_conv_mean() const {
+  return mlu_first_conv_mean_;
+}
+std::vector<float> CxxConfig::mlu_first_conv_std() const {
+  return mlu_first_conv_std_;
+}
 // set model data in combined format, `set_model_from_file` refers to loading
 // model from file, set_model_from_buffer refers to loading model from memory
 // buffer
...
@@ -106,11 +106,6 @@ class LITE_API PaddlePredictor {
  protected:
   int threads_{1};
   lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND};
-  lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLU_270};
-  int mlu_core_number_{1};
-  bool use_first_conv_{false};
-  std::vector<float> mean_vec_;
-  std::vector<float> std_vec_;
 };
 /// Base class for all the configs.
@@ -141,11 +136,12 @@ class LITE_API CxxConfig : public ConfigBase {
 #ifdef LITE_WITH_X86
   int x86_math_library_math_threads_ = 1;
 #endif
-  bool use_firstconv_{false};
-  std::vector<float> mean_ = {0.0f};
-  std::vector<float> std_ = {1.0f};
   lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLUCoreVersion::MLU_270};
   int mlu_core_number_{1};
+  DataLayoutType mlu_input_layout_{DATALAYOUT(kNCHW)};
+  bool mlu_use_first_conv_{false};
+  std::vector<float> mlu_first_conv_mean_;
+  std::vector<float> mlu_first_conv_std_;
  public:
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
@@ -173,20 +169,20 @@ class LITE_API CxxConfig : public ConfigBase {
     return x86_math_library_math_threads_;
   }
 #endif
-  void set_use_firstconv(const bool firstconv) { use_firstconv_ = firstconv; }
-  void set_mean(const std::vector<float> mean) { mean_ = mean; }
-  void set_std(const std::vector<float> std) { std_ = std; }
-  void set_mlu_core_version(lite_api::MLUCoreVersion core_version) {
-    mlu_core_version_ = core_version;
-  }
-  void set_mlu_core_number(int core_number) { mlu_core_number_ = core_number; }
-  bool use_first_conv() const { return use_firstconv_; }
-  std::vector<float> mean() const { return mean_; }
-  std::vector<float> std() const { return std_; }
-  lite_api::MLUCoreVersion mlu_core_version() const {
-    return mlu_core_version_;
-  }
-  int mlu_core_number() const { return mlu_core_number_; }
+  void set_mlu_core_version(lite_api::MLUCoreVersion core_version);
+  void set_mlu_core_number(int core_number);
+  void set_mlu_input_layout(DataLayoutType layout);
+  void set_mlu_use_first_conv(bool use_first_conv);
+  void set_mlu_first_conv_mean(const std::vector<float>& mean);
+  void set_mlu_first_conv_std(const std::vector<float>& std);
+
+  lite_api::MLUCoreVersion mlu_core_version() const;
+  int mlu_core_number() const;
+  DataLayoutType mlu_input_layout() const;
+  bool mlu_use_first_conv() const;
+  std::vector<float> mlu_first_conv_mean() const;
+  std::vector<float> mlu_first_conv_std() const;
 };
 /// MobileConfig is the config for the light weight predictor, it will skip
...
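Note: a minimal sketch of how the CxxConfig MLU options added here could be used from application code. The model directory, the chosen valid places, and the mean/std values are illustrative assumptions, not part of this change:

// Sketch only: exercises the MLU config API introduced in this commit.
#include "lite/api/paddle_api.h"

using namespace paddle::lite_api;  // NOLINT

std::shared_ptr<PaddlePredictor> BuildMluPredictor() {
  CxxConfig config;
  config.set_model_dir("./mobilenet_v1");  // hypothetical model directory
  config.set_valid_places(
      {Place{TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC)},
       Place{TARGET(kX86), PRECISION(kFloat)}});
  config.set_mlu_core_version(MLUCoreVersion::MLU_270);
  config.set_mlu_core_number(1);
  config.set_mlu_input_layout(DATALAYOUT(kNHWC));
  config.set_mlu_use_first_conv(true);
  config.set_mlu_first_conv_mean({124.0f, 117.0f, 104.0f});  // example values
  config.set_mlu_first_conv_std({59.0f, 57.0f, 57.0f});      // example values
  return CreatePaddlePredictor<CxxConfig>(config);
}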
@@ -128,11 +128,12 @@ void BindLiteCxxConfig(py::module *m) {
       .def("power_mode", &CxxConfig::power_mode);
 #endif
 #ifdef LITE_WITH_MLU
-  cxx_config.def("set_use_firstconv", &CxxConfig::set_use_firstconv)
-      .def("set_mean", &CxxConfig::set_mean)
-      .def("set_std", &CxxConfig::set_std)
-      .def("set_mlu_core_version", &CxxConfig::set_mlu_core_version)
-      .def("set_mlu_core_number", &CxxConfig::set_mlu_core_number);
+  cxx_config.def("set_mlu_core_version", &CxxConfig::set_mlu_core_version)
+      .def("set_mlu_core_number", &CxxConfig::set_mlu_core_number)
+      .def("set_mlu_input_layout", &CxxConfig::set_mlu_input_layout)
+      .def("set_mlu_use_first_conv", &CxxConfig::set_mlu_use_first_conv)
+      .def("set_mlu_first_conv_mean", &CxxConfig::set_mlu_first_conv_mean)
+      .def("set_mlu_first_conv_std", &CxxConfig::set_mlu_first_conv_std);
 #endif
 }
...
@@ -72,6 +72,7 @@ thread_local int DeviceInfo::mlu_core_number_{1};
 thread_local bool DeviceInfo::use_first_conv_{false};
 thread_local std::vector<float> DeviceInfo::mean_vec_;
 thread_local std::vector<float> DeviceInfo::std_vec_;
+thread_local DataLayoutType DeviceInfo::input_layout_{DATALAYOUT(kNCHW)};
 #endif
 #ifdef TARGET_IOS
@@ -1093,7 +1094,8 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
                                int core_number,
                                bool use_first_conv,
                                const std::vector<float>& mean_vec,
-                               const std::vector<float>& std_vec) {
+                               const std::vector<float>& std_vec,
+                               DataLayoutType input_layout) {
   switch (core_version) {
     case (lite_api::MLUCoreVersion::MLU_220):
       mlu_core_version_ = CNML_MLU220;
@@ -1109,6 +1111,7 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
   use_first_conv_ = use_first_conv;
   mean_vec_ = mean_vec;
   std_vec_ = std_vec;
+  input_layout_ = input_layout;
 }
 cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }
@@ -1121,6 +1124,8 @@ const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
 const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
+DataLayoutType DeviceInfo::InputLayout() const { return input_layout_; }
 #endif  // LITE_WITH_MLU
 void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) {
...
@@ -60,12 +60,14 @@ class DeviceInfo {
                      int core_number,
                      bool use_first_conv,
                      const std::vector<float>& mean_vec,
-                     const std::vector<float>& std_vec);
+                     const std::vector<float>& std_vec,
+                     DataLayoutType input_layout);
   cnmlCoreVersion_t MLUCoreVersion();
   int MLUCoreNumber();
   bool UseFirstConv();
   const std::vector<float>& MeanVec() const;
   const std::vector<float>& StdVec() const;
+  DataLayoutType InputLayout() const;
 #endif
   void SetCache(int l1size, int l2size, int l3size);
   void SetArch(ARMArch arch) { arch_ = arch; }
@@ -124,6 +126,7 @@ class DeviceInfo {
   static thread_local bool use_first_conv_;
   static thread_local std::vector<float> mean_vec_;
   static thread_local std::vector<float> std_vec_;
+  static thread_local DataLayoutType input_layout_;
 #endif
   void SetDotInfo(int argc, ...);
...
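Note: the accessors above expose the thread-local MLU state populated by SetMLURunMode to the rest of the runtime. A rough sketch of a consumer (the helper function and include path are assumptions for illustration, valid only in LITE_WITH_MLU builds):

// Sketch only: reads back the thread-local MLU settings.
#include <iostream>
#include "lite/core/device_info.h"

void DumpMluRunMode() {  // hypothetical helper, not part of this commit
  auto& dev = paddle::lite::DeviceInfo::Global();
  std::cout << "MLU core number: " << dev.MLUCoreNumber() << "\n";
  std::cout << "input layout is NHWC: "
            << (dev.InputLayout() == DATALAYOUT(kNHWC)) << "\n";
  if (dev.UseFirstConv()) {
    std::cout << "first conv mean/std sizes: " << dev.MeanVec().size() << "/"
              << dev.StdVec().size() << "\n";
  }
}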
@@ -74,7 +74,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
     const Type* in_arg_ty = kernel->GetInputDeclType("Input");
     const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
     if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
-        DataLayoutCompatible(*out_arg_ty, *cast_type)) {
+        DataLayoutCompatible(*out_arg_ty, *cast_type) &&
+        // for first conv
+        PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
       is_found = true;
     }
   } else if (op_type == "io_copy") {
@@ -121,7 +123,7 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
   cast_arg->AsArg().type = cast_type;
   auto* var = inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
   // for CastAfter manully set the tensor's type
-  var->GetMutable<::paddle::lite::Tensor>();
+  var->GetMutable<paddle::lite::Tensor>();
   // create the stmt node
   auto* cast_inst = graph->NewInstructNode();
@@ -215,23 +217,23 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
                 first_conv_nodes_.end(),
                 head_node->AsArg().name) != first_conv_nodes_.end();
-  // layout cast node
-  if (head_type->layout() != inst_type->layout()) {
+  // precision cast node
+  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
     cur_node = InsertCastBefore(
-        "layout",
-        name_prefix + "layout",
+        "cast",
+        name_prefix + "cast",
         graph,
         cur_node,
         inst_node,
         LiteType::GetTensorTy(
-            head_type->target(), head_type->precision(), inst_type->layout()));
+            head_type->target(), inst_type->precision(), head_type->layout()));
   }
-  // precision cast node
-  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
+  // layout cast node
+  if (head_type->layout() != inst_type->layout()) {
     cur_node = InsertCastBefore(
-        "cast",
-        name_prefix + "cast",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
@@ -281,7 +283,7 @@ void MLUPostprocessPass::GetSubgraphOpArgType(Node* inst_node,
   // get subgraph's valid precision
   const auto& places = graph->valid_places();
-  std::set<::paddle::lite_api::PrecisionType> prec_set;
+  std::set<paddle::lite_api::PrecisionType> prec_set;
   for (const auto& place : places) {
     if (place.target == TARGET(kMLU)) {
       prec_set.insert(place.precision);
@@ -364,23 +366,23 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
   const auto name_prefix =
       tail_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
-  // layout cast node
-  if (tail_type->layout() != inst_type->layout()) {
+  // precision cast node
+  if (tail_type->precision() != inst_type->precision()) {
     cur_node = InsertCastAfter(
-        "layout",
-        name_prefix + "layout",
+        "cast",
+        name_prefix + "cast",
         graph,
         cur_node,
         inst_node,
         LiteType::GetTensorTy(
-            tail_type->target(), tail_type->precision(), inst_type->layout()));
+            tail_type->target(), inst_type->precision(), tail_type->layout()));
   }
-  // precision cast node
-  if (tail_type->precision() != inst_type->precision()) {
+  // layout cast node
+  if (tail_type->layout() != inst_type->layout()) {
     cur_node = InsertCastAfter(
-        "cast",
-        name_prefix + "cast",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
@@ -474,13 +476,20 @@ bool MLUPostprocessPass::IsFirstConvNode(Node* arg_node) {
   return false;
 }
-void MLUPostprocessPass::GatherFirstConvNodes(SSAGraph* graph) {
+void MLUPostprocessPass::GatherAndModifyFirstConvNodes(SSAGraph* graph) {
   for (auto& node : graph->mutable_nodes()) {
     if (!node.IsStmt()) continue;
     if (node.AsStmt().op_type() == "feed") {
      for (auto& out : node.outlinks) {
         if (IsFirstConvNode(out)) {
           first_conv_nodes_.insert(out->AsArg().name);
+          // modify first conv nodes' type
+          const auto* old_type = out->AsArg().type;
+          out->AsArg().type =
+              LiteType::GetTensorTy(old_type->target(),
+                                    paddle::lite_api::PrecisionType::kInt8,
+                                    old_type->layout(),
+                                    old_type->device());
         }
       }
     }
@@ -504,7 +513,7 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
         out->AsArg().type =
             LiteType::GetTensorTy(old_type->target(),
                                   old_type->precision(),
-                                  ::paddle::lite_api::DataLayoutType::kNHWC,
+                                  paddle::lite_api::DataLayoutType::kNHWC,
                                   old_type->device());
       }
     }
@@ -523,7 +532,7 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
        inp->AsArg().type =
            LiteType::GetTensorTy(old_type->target(),
                                  old_type->precision(),
-                                 ::paddle::lite_api::DataLayoutType::kNHWC,
+                                 paddle::lite_api::DataLayoutType::kNHWC,
                                  old_type->device());
      }
    }
@@ -539,10 +548,12 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
   // arg_in and arg_out are assumed to be NHWC which user should be aware of.
   // Thus here we change these args' layout to NHWC
-  ModifyLayout(graph.get());
+  if (lite::DeviceInfo::Global().InputLayout() == DATALAYOUT(kNHWC)) {
+    ModifyLayout(graph.get());
+  }
   if (lite::DeviceInfo::Global().UseFirstConv()) {
-    GatherFirstConvNodes(graph.get());
+    GatherAndModifyFirstConvNodes(graph.get());
   }
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
...
@@ -109,7 +109,7 @@ class MLUPostprocessPass : public ProgramPass {
   void RecreateOp(Node* inst_node, SSAGraph* graph);
-  void GatherFirstConvNodes(SSAGraph* graph);
+  void GatherAndModifyFirstConvNodes(SSAGraph* graph);
   bool IsFirstConvNode(Node* arg_node);
...
@@ -84,7 +84,7 @@ struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  typedef ::paddle::lite::fluid::float16 T;
+  typedef paddle::lite::fluid::float16 T;
 };
 }  // namespace mlu
...
@@ -48,11 +48,11 @@ REGISTER_LITE_KERNEL(
     def_layout_nhwc2nchw_fp16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNHWC))})
     .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNCHW))})
     .Finalize();
@@ -82,10 +82,27 @@ REGISTER_LITE_KERNEL(
     def_layout_nchw2nhwc_fp16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNCHW))})
     .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(
+    layout,
+    kMLU,
+    kInt8,
+    kNHWC,
+    paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kInt8)>,
+    def_layout_nchw2nhwc_fp32_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC))})
+    .Finalize();
...
@@ -29,6 +29,24 @@ namespace lite {
 namespace kernels {
 namespace mlu {
+template <paddle::lite_api::PrecisionType>
+struct FPTypeTraits {};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
+  typedef float T;
+};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
+  typedef paddle::lite::fluid::float16 T;
+};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
+  typedef int8_t T;
+};
+
 template <lite::TargetType Target, typename T>
 inline void LayoutTransCompute(const int dim,
                                const lite::Context<Target>& context,
@@ -63,7 +81,7 @@ class LayoutNchwToNhwcCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<float>();
+    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
@@ -88,7 +106,8 @@ class LayoutNchwToNhwcCompute
       CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc";
     }
-    LayoutTransCompute<lite::TargetType::kX86, float>(
+    LayoutTransCompute<lite::TargetType::kX86,
+                       typename FPTypeTraits<Precision>::T>(
         x_dims, context, *x, out, axis);
     if (x_dims > 2) {
@@ -111,7 +130,7 @@ class LayoutNhwcToNchwCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<float>();
+    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
@@ -136,7 +155,8 @@ class LayoutNhwcToNchwCompute
      CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw";
     }
-    LayoutTransCompute<lite::TargetType::kX86, float>(
+    LayoutTransCompute<lite::TargetType::kX86,
+                       typename FPTypeTraits<Precision>::T>(
        x_dims, context, *x, out, axis);
    if (x_dims > 2) {
...
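Note: the FPTypeTraits specializations added above let the layout kernels resolve their element type from the precision template parameter at compile time, instead of hard-coding float. A self-contained sketch of the same dispatch pattern; the names here (PrecisionTag, TypeTraits, MakeBuffer) are illustrative and not part of the patch:

// Sketch only: mirrors the FPTypeTraits dispatch used by the layout kernels.
#include <cstddef>
#include <cstdint>
#include <vector>

enum class PrecisionTag { kFloat, kInt8 };

template <PrecisionTag P>
struct TypeTraits;
template <>
struct TypeTraits<PrecisionTag::kFloat> { typedef float T; };
template <>
struct TypeTraits<PrecisionTag::kInt8> { typedef int8_t T; };

// Like out->mutable_data<typename FPTypeTraits<Precision>::T>(), the element
// type is picked when the template is instantiated with a precision tag.
template <PrecisionTag P>
std::vector<typename TypeTraits<P>::T> MakeBuffer(size_t n) {
  return std::vector<typename TypeTraits<P>::T>(n);
}

// Usage: MakeBuffer<PrecisionTag::kInt8>(16) yields a std::vector<int8_t>.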