Unverified · Commit f4788442, authored by Aurelius84, committed by GitHub

[JITLayer]Enable OneDNN on CPU and Fix zero shape (#47428) (#47436)

* [JITLayer]Enable OneDNN on CPU and Fix zero shape
Parent 7618cbdc
@@ -171,8 +171,9 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
   // NOTE(Aurelius84): Some kernels support zero shape input
   // without memory holder, we should skip enforce logic.
   bool has_zero_dim = (phi::product(ddim) == 0);
-  if (has_zero_dim) {
-    VLOG(3) << "Found zero dim from input with ddim: " << ddim;
+  VLOG(3) << "Found zero dim: " << has_zero_dim
+          << " from input with ddim: " << ddim;
+  if (!has_zero_dim) {
     PADDLE_ENFORCE_NOT_NULL(
         input_ptr,
         paddle::platform::errors::Fatal(
...
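The hunk above inverts the guard: a tensor whose shape contains a 0 extent has zero elements and may legitimately carry no memory holder, so the non-null enforce now runs only for non-zero-size inputs. A minimal standalone sketch of the zero-size test, with a hypothetical `Product` helper standing in for `phi::product` (not part of the commit):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for phi::product: the element count of a shape
// is the product of its extents, so any 0 extent makes it zero-size.
int64_t Product(const std::vector<int64_t> &dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}

int main() {
  assert(Product({4, 0, 16}) == 0);    // zero-size: skip the null-pointer check
  assert(Product({4, 2, 16}) == 128);  // normal tensor: buffer must be non-null
  return 0;
}
```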
@@ -34,12 +34,16 @@ PredictorEngine::PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
   utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get());
   VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get());
 
+  // TODO(Aurelius84): Expose AnalysisConfig to user.
   AnalysisConfig config;
   config.SetProgFile(info->ProgramFilePath());
   if (platform::is_gpu_place(place_)) {
     config.EnableUseGpu(100, place_.GetDeviceId());
   } else if (platform::is_cpu_place(place_)) {
     config.DisableGpu();
+    config.EnableMKLDNN();
+    config.EnableMkldnnInt8();
+    config.SetMkldnnCacheCapacity(0);
   }
   config.SetSkipLoadParams(true);
   config.SetApplyOptim(true);
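For reference, a hypothetical standalone sketch (not part of the commit) of the CPU configuration this constructor now builds, assuming the public `paddle::AnalysisConfig` API from `paddle_inference_api.h`; the function name and the program-file path are placeholders:

```cpp
#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Builds the same CPU-side AnalysisConfig as the constructor above:
// GPU off, oneDNN (MKL-DNN) on, int8 oneDNN kernels allowed, and the
// oneDNN cache capacity set to 0 exactly as in the hunk.
paddle::AnalysisConfig MakeCpuConfig(const std::string &prog_file) {
  paddle::AnalysisConfig config;
  config.SetProgFile(prog_file);     // program description only; params are
  config.SetSkipLoadParams(true);    // shared into the scope separately
  config.DisableGpu();
  config.EnableMKLDNN();             // enable oneDNN CPU kernels
  config.EnableMkldnnInt8();         // also allow int8 oneDNN kernels
  config.SetMkldnnCacheCapacity(0);  // capacity value taken from the diff
  config.SetApplyOptim(true);
  return config;
}
```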
@@ -59,10 +63,6 @@ std::vector<Tensor> PredictorEngine::operator()(
 
 std::vector<DenseTensor> PredictorEngine::operator()(
     const std::vector<DenseTensor> &inputs) {
-  for (auto t : inputs) {
-    VLOG(1) << "inputs is init: " << t.initialized();
-  }
-
   std::vector<PaddleTensor> pt_inputs;
   std::vector<PaddleTensor> pt_outputs;
   for (auto &t : inputs) {
@@ -84,22 +84,23 @@ std::vector<DenseTensor> PredictorEngine::operator()(
 
 static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) {
   PaddleTensor pt;
-
-  if (framework::TransToProtoVarType(t->dtype()) ==
-      framework::proto::VarType::INT32) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
-    pt.dtype = PaddleDType::INT32;
-  } else if (framework::TransToProtoVarType(t->dtype()) ==
-             framework::proto::VarType::INT64) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
-    pt.dtype = PaddleDType::INT64;
-  } else if (framework::TransToProtoVarType(t->dtype()) ==
-             framework::proto::VarType::FP32) {
-    pt.data.Reset(t->data(), t->numel() * sizeof(float));
-    pt.dtype = PaddleDType::FLOAT32;
-  } else {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Unsupported tensor date type. Now only supports INT64, FP32, INT32."));
+  switch (framework::TransToProtoVarType(t->dtype())) {
+    case framework::proto::VarType::INT32: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
+      pt.dtype = PaddleDType::INT32;
+    } break;
+    case framework::proto::VarType::INT64: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
+      pt.dtype = PaddleDType::INT64;
+    } break;
+    case framework::proto::VarType::FP32: {
+      pt.data.Reset(t->data(), t->numel() * sizeof(float));
+      pt.dtype = PaddleDType::FLOAT32;
+    } break;
+    default:
+      PADDLE_THROW(
+          platform::errors::Unimplemented("Unsupported tensor date type. Now "
+                                          "only supports INT64, FP32, INT32."));
   }
   pt.shape = phi::vectorize<int>(t->dims());
   return pt;
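Both conversion helpers dispatch on a fixed dtype-to-byte-width mapping (the `sizeof` calls above), and the buffer size passed to `pt.data.Reset()` is `numel * width`. A tiny self-contained sketch of that mapping, with a hypothetical `DType` enum and `ByteWidth` helper (not Paddle API):

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for PaddleDType and the sizeof dispatch above.
enum class DType { INT32, INT64, FLOAT32 };

constexpr std::size_t ByteWidth(DType d) {
  switch (d) {
    case DType::INT32:
      return sizeof(int32_t);  // 4 bytes
    case DType::INT64:
      return sizeof(int64_t);  // 8 bytes
    case DType::FLOAT32:
      return sizeof(float);    // 4 bytes
  }
  return 0;  // unreachable for the enumerators above
}

static_assert(ByteWidth(DType::INT64) == 8, "buffer bytes = numel * width");
```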
@@ -110,15 +111,20 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
                                       const platform::Place &place) {
   framework::DDim ddim = phi::make_ddim(pt.shape);
   void *input_ptr;
-  if (pt.dtype == PaddleDType::INT64) {
-    input_ptr = t->mutable_data<int64_t>(ddim, place);
-  } else if (pt.dtype == PaddleDType::FLOAT32) {
-    input_ptr = t->mutable_data<float>(ddim, place);
-  } else if (pt.dtype == PaddleDType::INT32) {
-    input_ptr = t->mutable_data<int32_t>(ddim, place);
-  } else if (pt.dtype == PaddleDType::FLOAT16) {
-    input_ptr = t->mutable_data<float16>(ddim, place);
-  } else {
-    LOG(ERROR) << "unsupported feed type " << pt.dtype;
-    return false;
+  switch (pt.dtype) {
+    case PaddleDType::INT64:
+      input_ptr = t->mutable_data<int64_t>(ddim, place);
+      break;
+    case PaddleDType::FLOAT32:
+      input_ptr = t->mutable_data<float>(ddim, place);
+      break;
+    case PaddleDType::INT32:
+      input_ptr = t->mutable_data<int32_t>(ddim, place);
+      break;
+    case PaddleDType::FLOAT16:
+      input_ptr = t->mutable_data<float16>(ddim, place);
+      break;
+    default:
+      LOG(ERROR) << "unsupported feed type " << pt.dtype;
+      return false;
   }
...