From f4788442e8ced1d04cfbdc682589d6bc920b17da Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sat, 29 Oct 2022 12:52:01 +0800 Subject: [PATCH] [JITLayer]Enable OneDNN on CPU and Fix zero shape (#47428) (#47436) * [JITLayer]Enable OneDNN on CPU and Fix zero shape --- .../fluid/inference/api/analysis_predictor.cc | 5 +- paddle/fluid/jit/engine/predictor_engine.cc | 68 ++++++++++--------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 469baa92706..8ec16dad3c1 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -171,8 +171,9 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, // NOTE(Aurelius84): Some kernels support zero shape input // without memory holder, we should skip enforce logic. bool has_zero_dim = (phi::product(ddim) == 0); - if (has_zero_dim) { - VLOG(3) << "Found zero dim from input with ddim: " << ddim; + VLOG(3) << "Found zero dim: " << has_zero_dim + << " from input with ddim: " << ddim; + if (!has_zero_dim) { PADDLE_ENFORCE_NOT_NULL( input_ptr, paddle::platform::errors::Fatal( diff --git a/paddle/fluid/jit/engine/predictor_engine.cc b/paddle/fluid/jit/engine/predictor_engine.cc index d6bdf42b041..6a44c192c16 100644 --- a/paddle/fluid/jit/engine/predictor_engine.cc +++ b/paddle/fluid/jit/engine/predictor_engine.cc @@ -34,12 +34,16 @@ PredictorEngine::PredictorEngine(const std::shared_ptr &info, utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get()); VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get()); + // TODO(Aurelius84): Expose AnalysisConfig to user. AnalysisConfig config; config.SetProgFile(info->ProgramFilePath()); if (platform::is_gpu_place(place_)) { config.EnableUseGpu(100, place_.GetDeviceId()); } else if (platform::is_cpu_place(place_)) { config.DisableGpu(); + config.EnableMKLDNN(); + config.EnableMkldnnInt8(); + config.SetMkldnnCacheCapacity(0); } config.SetSkipLoadParams(true); config.SetApplyOptim(true); @@ -59,10 +63,6 @@ std::vector PredictorEngine::operator()( std::vector PredictorEngine::operator()( const std::vector &inputs) { - for (auto t : inputs) { - VLOG(1) << "inputs is init: " << t.initialized(); - } - std::vector pt_inputs; std::vector pt_outputs; for (auto &t : inputs) { @@ -84,22 +84,23 @@ std::vector PredictorEngine::operator()( static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) { PaddleTensor pt; - - if (framework::TransToProtoVarType(t->dtype()) == - framework::proto::VarType::INT32) { - pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); - pt.dtype = PaddleDType::INT32; - } else if (framework::TransToProtoVarType(t->dtype()) == - framework::proto::VarType::INT64) { - pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); - pt.dtype = PaddleDType::INT64; - } else if (framework::TransToProtoVarType(t->dtype()) == - framework::proto::VarType::FP32) { - pt.data.Reset(t->data(), t->numel() * sizeof(float)); - pt.dtype = PaddleDType::FLOAT32; - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported tensor date type. Now only supports INT64, FP32, INT32.")); + switch (framework::TransToProtoVarType(t->dtype())) { + case framework::proto::VarType::INT32: { + pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); + pt.dtype = PaddleDType::INT32; + } break; + case framework::proto::VarType::INT64: { + pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); + pt.dtype = PaddleDType::INT64; + } break; + case framework::proto::VarType::FP32: { + pt.data.Reset(t->data(), t->numel() * sizeof(float)); + pt.dtype = PaddleDType::FLOAT32; + } break; + default: + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported tensor date type. Now " + "only supports INT64, FP32, INT32.")); } pt.shape = phi::vectorize(t->dims()); return pt; @@ -110,17 +111,22 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt, const platform::Place &place) { framework::DDim ddim = phi::make_ddim(pt.shape); void *input_ptr; - if (pt.dtype == PaddleDType::INT64) { - input_ptr = t->mutable_data(ddim, place); - } else if (pt.dtype == PaddleDType::FLOAT32) { - input_ptr = t->mutable_data(ddim, place); - } else if (pt.dtype == PaddleDType::INT32) { - input_ptr = t->mutable_data(ddim, place); - } else if (pt.dtype == PaddleDType::FLOAT16) { - input_ptr = t->mutable_data(ddim, place); - } else { - LOG(ERROR) << "unsupported feed type " << pt.dtype; - return false; + switch (pt.dtype) { + case PaddleDType::INT64: + input_ptr = t->mutable_data(ddim, place); + break; + case PaddleDType::FLOAT32: + input_ptr = t->mutable_data(ddim, place); + break; + case PaddleDType::INT32: + input_ptr = t->mutable_data(ddim, place); + break; + case PaddleDType::FLOAT16: + input_ptr = t->mutable_data(ddim, place); + break; + default: + LOG(ERROR) << "unsupported feed type " << pt.dtype; + return false; } PADDLE_ENFORCE_NOT_NULL( -- GitLab