未验证 提交 f4788442 编写于 作者: A Aurelius84 提交者: GitHub

[JITLayer]Enable OneDNN on CPU and Fix zero shape (#47428) (#47436)

* [JITLayer]Enable OneDNN on CPU and Fix zero shape
上级 7618cbdc
......@@ -171,8 +171,9 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
// NOTE(Aurelius84): Some kernels support zero shape input
// without memory holder, we should skip enforce logic.
bool has_zero_dim = (phi::product(ddim) == 0);
if (has_zero_dim) {
VLOG(3) << "Found zero dim from input with ddim: " << ddim;
VLOG(3) << "Found zero dim: " << has_zero_dim
<< " from input with ddim: " << ddim;
if (!has_zero_dim) {
PADDLE_ENFORCE_NOT_NULL(
input_ptr,
paddle::platform::errors::Fatal(
......
......@@ -34,12 +34,16 @@ PredictorEngine::PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get());
VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get());
// TODO(Aurelius84): Expose AnalysisConfig to user.
AnalysisConfig config;
config.SetProgFile(info->ProgramFilePath());
if (platform::is_gpu_place(place_)) {
config.EnableUseGpu(100, place_.GetDeviceId());
} else if (platform::is_cpu_place(place_)) {
config.DisableGpu();
config.EnableMKLDNN();
config.EnableMkldnnInt8();
config.SetMkldnnCacheCapacity(0);
}
config.SetSkipLoadParams(true);
config.SetApplyOptim(true);
......@@ -59,10 +63,6 @@ std::vector<Tensor> PredictorEngine::operator()(
std::vector<DenseTensor> PredictorEngine::operator()(
const std::vector<DenseTensor> &inputs) {
for (auto t : inputs) {
VLOG(1) << "inputs is init: " << t.initialized();
}
std::vector<PaddleTensor> pt_inputs;
std::vector<PaddleTensor> pt_outputs;
for (auto &t : inputs) {
......@@ -84,22 +84,23 @@ std::vector<DenseTensor> PredictorEngine::operator()(
static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) {
PaddleTensor pt;
if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::INT32) {
pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} else if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::INT64) {
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} else if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::FP32) {
pt.data.Reset(t->data(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported tensor date type. Now only supports INT64, FP32, INT32."));
switch (framework::TransToProtoVarType(t->dtype())) {
case framework::proto::VarType::INT32: {
pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} break;
case framework::proto::VarType::INT64: {
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} break;
case framework::proto::VarType::FP32: {
pt.data.Reset(t->data(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported tensor date type. Now "
"only supports INT64, FP32, INT32."));
}
pt.shape = phi::vectorize<int>(t->dims());
return pt;
......@@ -110,17 +111,22 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape);
void *input_ptr;
if (pt.dtype == PaddleDType::INT64) {
input_ptr = t->mutable_data<int64_t>(ddim, place);
} else if (pt.dtype == PaddleDType::FLOAT32) {
input_ptr = t->mutable_data<float>(ddim, place);
} else if (pt.dtype == PaddleDType::INT32) {
input_ptr = t->mutable_data<int32_t>(ddim, place);
} else if (pt.dtype == PaddleDType::FLOAT16) {
input_ptr = t->mutable_data<float16>(ddim, place);
} else {
LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false;
switch (pt.dtype) {
case PaddleDType::INT64:
input_ptr = t->mutable_data<int64_t>(ddim, place);
break;
case PaddleDType::FLOAT32:
input_ptr = t->mutable_data<float>(ddim, place);
break;
case PaddleDType::INT32:
input_ptr = t->mutable_data<int32_t>(ddim, place);
break;
case PaddleDType::FLOAT16:
input_ptr = t->mutable_data<float16>(ddim, place);
break;
default:
LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false;
}
PADDLE_ENFORCE_NOT_NULL(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册