Unverified commit 02cf6764, authored by Leo Chen, committed by GitHub

[new-exec] fit mkldnn op (#41058)

* fix a bug where some ops have no op_role attr

* add mkldnn support for the new executor

* fit mkldnn data_transfer
Parent dc0702fe
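In brief, the data-transfer pass now gives transfer variables deterministic names keyed by the conversion (layout, dtype, or place pair) instead of a running counter, and reuses an already-initialized variable with the same name, with TransferLayout()/TransferDtype()/TransferDevice() returning a null op in that case. A minimal standalone sketch of that caching idea in plain C++, with hypothetical names (DummyVar, scope, GetOrCreateTransferVar are stand-ins, not Paddle APIs):

#include <iostream>
#include <string>
#include <unordered_map>

// Simplified model of the caching scheme: a transfer variable gets a
// deterministic name derived from the source var and the conversion
// (e.g. "x_layout_1_2"), so a second request for the same conversion
// can reuse the variable instead of inserting another transfer op.
struct DummyVar { bool initialized = false; };

std::unordered_map<std::string, DummyVar> scope;  // stand-in for VariableScope

// Returns nullptr when a usable cached variable already exists,
// mirroring how TransferLayout()/TransferDtype() return a null op.
DummyVar* GetOrCreateTransferVar(const std::string& var_name, int in_layout,
                                 int out_layout, std::string* new_var_name) {
  *new_var_name = var_name + "_layout_" + std::to_string(in_layout) + "_" +
                  std::to_string(out_layout);
  auto it = scope.find(*new_var_name);
  if (it != scope.end() && it->second.initialized) {
    std::cout << "Use cached variable: " << *new_var_name << "\n";
    return nullptr;  // no new transfer op needed
  }
  return &scope[*new_var_name];  // caller runs a transfer op to fill this var
}

int main() {
  std::string name;
  auto* v = GetOrCreateTransferVar("x", 1, 2, &name);  // creates x_layout_1_2
  v->initialized = true;                               // the transfer op ran
  GetOrCreateTransferVar("x", 1, 2, &name);            // hits the cache
}

The real code additionally guards with var_scope->HasVar() and checks tensor initialization via the new IsTensorOfVarInitialized() helper, as the diff below shows.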
@@ -24,7 +24,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
                               const std::string& var_name,
                               std::string* new_var_name,
                               std::vector<OpFuncNode>* op_func_nodes,
-                              bool use_local_scope) {
+                              bool use_local_scope, bool is_fetch_v2) {
   bool is_transferred = false;
   auto* src_var_name = &var_name;
@@ -35,8 +35,11 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
   if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
     auto op = TransferLayout(
         *src_var_name, new_var_name, kernel_type_for_var.data_layout_,
-        expected_kernel_key.data_layout_, var_scope_, local_scope);
-    RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
+        expected_kernel_key.data_layout_, var_scope_, local_scope, is_fetch_v2);
+    if (op) {
+      RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name,
+                                op_func_nodes);
+    }
     // update src_var_name
     src_var_name = new_var_name;
     is_transferred = true;
@@ -46,7 +49,10 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
     auto op = TransferDtype(
         *src_var_name, new_var_name, kernel_type_for_var.data_type_,
         expected_kernel_key.data_type_, var_scope_, local_scope);
-    RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
+    if (op) {
+      RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name,
+                                op_func_nodes);
+    }
     // update src_var_name
     src_var_name = new_var_name;
     is_transferred = true;
@@ -55,9 +61,13 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
   if (need_device_transform(kernel_type_for_var, expected_kernel_key)) {
     auto src_place = kernel_type_for_var.place_;
     auto dst_place = expected_kernel_key.place_;
+
     auto op = TransferDevice(*src_var_name, new_var_name, src_place, dst_place,
                              var_scope_, local_scope);
-    RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name, op_func_nodes);
+    if (op) {
+      RunAndConstructOpFuncNode(op, *src_var_name, *new_var_name,
+                                op_func_nodes);
+    }
     is_transferred = true;
   }
   return is_transferred;
@@ -128,17 +138,44 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
   new_op_func_nodes->emplace_back(std::move(new_op_func_node));
 }
 
-std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
-                                             std::string* new_var_name,
-                                             DataLayout in_layout,
-                                             DataLayout out_layout,
-                                             VariableScope* var_scope,
-                                             framework::Scope* local_scope) {
+// Var is initialized && var contains tensor && tensor is initialized
+bool IsTensorOfVarInitialized(Variable* var) {
+  if (var->IsInitialized()) {
+    if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>()) {
+      return GetLoDTensorOrSelectedRowsValueFromVar(*var)->IsInitialized();
+    } else if (var->IsType<LoDTensorArray>()) {
+      return static_cast<const Tensor*>(&(var->Get<LoDTensorArray>()[0]))
+          ->IsInitialized();
+    }
+  }
+  return false;
+}
+
+std::shared_ptr<OperatorBase> TransferLayout(
+    const std::string& var_name, std::string* new_var_name,
+    DataLayout in_layout, DataLayout out_layout, VariableScope* var_scope,
+    framework::Scope* local_scope, bool is_fetch_v2) {
+#ifdef PADDLE_WITH_MKLDNN
+  // NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in fetch_op.cc
+  if (in_layout == framework::DataLayout::kMKLDNN &&
+      var_name == framework::GradVarName("Filter") && is_fetch_v2) {
+    out_layout = framework::DataLayout::kNCHW;
+  }
+#endif
+
   // 1. Generate new_var_name and Initialize it
-  *new_var_name =
-      var_name + "_layout_" + std::to_string(var_scope->VarSize() + 1);
-  auto* ptr = local_scope->Var(*new_var_name);
+  *new_var_name = var_name + "_layout_" +
+                  std::to_string(static_cast<int>(in_layout)) + "_" +
+                  std::to_string(static_cast<int>(out_layout));
+
+  if (var_scope->HasVar(*new_var_name) &&
+      IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
+    // already has same var
+    VLOG(4) << "Use cached variable: " << *new_var_name;
+    return nullptr;
+  }
+
+  auto* ptr = local_scope->Var(*new_var_name);
   auto var_type = var_scope->Var(var_name)->Type();
   InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
   VLOG(3) << "Create Variable " << *new_var_name
@@ -171,10 +208,17 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
                                             VariableScope* var_scope,
                                             framework::Scope* local_scope) {
   // 1. Generate new_var_name and Initialize it
-  *new_var_name =
-      var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
-  auto* ptr = local_scope->Var(*new_var_name);
+  *new_var_name = var_name + "_dtype_" +
+                  std::to_string(static_cast<int>(in_dtype)) + "_" +
+                  std::to_string(static_cast<int>(out_dtype));
+
+  if (var_scope->HasVar(*new_var_name) &&
+      IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
+    // already has same var
+    VLOG(4) << "Use cached variable: " << *new_var_name;
+    return nullptr;
+  }
+
+  auto* ptr = local_scope->Var(*new_var_name);
   auto var_type = var_scope->Var(var_name)->Type();
   InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
@@ -211,10 +255,17 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
                                              VariableScope* var_scope,
                                              framework::Scope* local_scope) {
   // 1. Generate new_var_name and Initialize it
-  *new_var_name =
-      var_name + "_device_" + std::to_string(var_scope->VarSize() + 1);
-  auto* ptr = local_scope->Var(*new_var_name);
+  *new_var_name = var_name + "_device_" + src_place.DebugString() + "_" +
+                  dst_place.DebugString();
+
+  if (var_scope->HasVar(*new_var_name) &&
+      IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
+    // already has same var
+    VLOG(4) << "Use cached variable: " << *new_var_name;
+    return nullptr;
+  }
+
+  auto* ptr = local_scope->Var(*new_var_name);
   auto var_type = var_scope->Var(var_name)->Type();
   InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
   VLOG(3) << "Create Variable " << *new_var_name
@@ -258,12 +309,28 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
   // record the no need transform variable index.
   std::unordered_set<int> no_data_transform_index;
 
+  const std::unordered_set<std::string>* no_buffer_ins = nullptr;
+  auto& no_buffer_inferer = op_base->Info().NoNeedBufferVarsInferer();
+  if (no_buffer_inferer) {
+    no_buffer_ins = &(no_buffer_inferer(op_base->Inputs(), op_base->Outputs(),
+                                        op_base->Attrs()));
+    if (no_buffer_ins->empty()) {
+      no_buffer_ins = nullptr;
+    }
+  }
+
   DataTranferHelper data_transfer_helper(place, var_scope);
   for (auto& var_name_item : *ins_map_temp) {
+    bool should_skip_input =
+        no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0;
+
     for (size_t i = 0; i < var_name_item.second.size(); ++i) {
       auto var = var_name_item.second[i];
       auto var_name = new_ins[var_name_item.first].at(i);
       const Tensor* tensor_in;
+      std::string new_var_name;
+      bool is_transferred = false;
+
       if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>()) {
         tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
       } else if (var->IsType<LoDTensorArray>()) {
@@ -272,18 +339,54 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
       } else {
         continue;
       }
+
+      // special case
       if (!tensor_in->IsInitialized()) {
+        if (should_skip_input == true) {
+#ifdef PADDLE_WITH_MKLDNN
+          // Var without buffer may be needed
+          // for some situation like InferShape().
+          // In this situation we cannot skip Var analysis, as
+          // MKL-DNN shape of Var may differ from kNHWC Var.
+          // In such situation the corresponding resized Var
+          // has to be created and registered
+          if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
+              (var->IsType<LoDTensor>() == true) &&
+              (expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
+              (paddle::platform::MKLDNNDeviceContext::tls()
+                   .get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
+            VLOG(7) << "Created reshaped dummy input based on MKL-DNN Tensor, "
+                       "but kNHWC layout "
+                    << var_name_item.first << " in Operator "
+                    << op_base->Type();
+            Scope* local_scope = use_local_scope
+                                     ? var_scope->GetMutableLocalScope()
+                                     : var_scope->GetMutableScope();
+            auto op = TransferLayout(
+                var_name, &new_var_name, tensor_in->layout(),
+                DataLayout::kNHWC, var_scope, local_scope,
+                op_base->Type() == "fetch_v2");
+            if (op) {
+              data_transfer_helper.RunAndConstructOpFuncNode(
+                  op, var_name, new_var_name, new_op_func_nodes);
+            }
+            is_transferred = true;
+          } else {
+            VLOG(7) << "Skip scanning input " << var_name_item.first
+                    << " in Operator " << op_base->Type();
+          }
+#endif
+        } else {
+          continue;
+        }
+      } else {
         auto kernel_type_for_var =
             static_cast<const framework::OperatorWithKernel*>(op_base)
                 ->GetKernelTypeForVar(var_name_item.first, *tensor_in,
                                       expected_kernel_key);
         // apply data transform
-        std::string new_var_name;
-        bool is_transferred = data_transfer_helper.apply(
+        is_transferred = data_transfer_helper.apply(
            kernel_type_for_var, expected_kernel_key, var_name, &new_var_name,
-           new_op_func_nodes, use_local_scope);
+           new_op_func_nodes, use_local_scope, op_base->Type() == "fetch_v2");
+      }
+
       if (is_transferred) {
         // update RuntimeContext.inputs and original op_func_node inputs
......
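The ApplyDataTransform() change above can be read as a predicate over uninitialized inputs: they are skipped as before, unless they are no-buffer inputs holding an MKL-DNN tensor that an NHWC model still needs reshaped so InferShape() sees consistent dimensions. A simplified, self-contained restatement of that decision (the Layout enum and NeedsDummyNHWCTransfer are illustrative, not the actual code):

#include <iostream>

enum class Layout { kNCHW, kNHWC, kMKLDNN, kAnyLayout };

// Sketch of the new branch: uninitialized inputs are skipped, except a
// no-buffer MKL-DNN input whose shape may differ from the kNHWC shape
// that InferShape() expects; that one gets a reshaped dummy transfer.
bool NeedsDummyNHWCTransfer(bool tensor_initialized, bool is_no_buffer_input,
                            Layout tensor_layout, Layout expected_layout,
                            Layout cur_paddle_layout) {
  if (tensor_initialized) return false;   // normal transform path applies
  if (!is_no_buffer_input) return false;  // plain uninitialized input: skip
  return tensor_layout == Layout::kMKLDNN &&
         expected_layout != Layout::kMKLDNN &&
         cur_paddle_layout == Layout::kNHWC;
}

int main() {
  std::cout << NeedsDummyNHWCTransfer(false, true, Layout::kMKLDNN,
                                      Layout::kAnyLayout, Layout::kNHWC)
            << "\n";  // prints 1: create the reshaped dummy input
}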
@@ -35,7 +35,8 @@ class DataTranferHelper {
   bool apply(const OpKernelType& kernel_type_for_var,
              const OpKernelType& expected_kernel_key,
              const std::string& var_name, std::string* new_var_name,
-             std::vector<OpFuncNode>* new_op_func_nodes, bool use_local_scope);
+             std::vector<OpFuncNode>* new_op_func_nodes, bool use_local_scope,
+             bool is_fetch_v2);
 
   void RunAndConstructShareNode(const std::string& src_var_name,
                                 const std::string& dst_var_name,
@@ -94,12 +95,10 @@ inline bool need_layout_transform(const OpKernelType& kernel_type_for_var,
                                   expected_kernel_key.data_layout_);
 }
 
-std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
-                                             std::string* new_var_name,
-                                             DataLayout in_layout,
-                                             DataLayout out_layout,
-                                             VariableScope* var_scope,
-                                             framework::Scope* local_scope);
+std::shared_ptr<OperatorBase> TransferLayout(
+    const std::string& var_name, std::string* new_var_name,
+    DataLayout in_layout, DataLayout out_layout, VariableScope* var_scope,
+    framework::Scope* local_scope, bool is_fetch_v2);
 
 std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
                                             std::string* new_var_name,
......
@@ -22,6 +22,9 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/os_info.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
                             "Use inplace in new executor");
@@ -55,6 +58,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       block_(block),
       global_scope_(global_scope),
       stream_analyzer_(place) {
+  VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
   is_build_ = false;
   async_work_queue_.reset(new interpreter::AsyncWorkQueue(
       kHostNumThreads, kDeviceNumThreads, &main_thread_blocker_));
@@ -92,6 +96,14 @@ InterpreterCore::~InterpreterCore() {
   gc_.reset(nullptr);
   async_work_queue_.reset(nullptr);
+  VLOG(4) << "~InterpreterCore(): " << this;
+  VLOG(4) << " on" << place_;
+
+#ifdef PADDLE_WITH_MKLDNN
+  // Clear mkl-dnn cache,
+  // this is needed to have mkl-dnn unit tests working
+  platform::ClearMKLDNNCache(place_, this);
+#endif
 }
 
 void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
@@ -101,6 +113,9 @@ void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
 paddle::framework::FetchList InterpreterCore::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors) {
+#ifdef PADDLE_WITH_MKLDNN
+  platform::AttachPointerHashToMKLDNNKey(this, place_);
+#endif
   bool is_build = is_build_;
   global_scope_->SetLocalScope(local_scope_);
   Prepare(feed_names, feed_tensors, is_build);
@@ -120,6 +135,9 @@ paddle::framework::FetchList InterpreterCore::Run(
 paddle::framework::FetchList InterpreterCore::Run(
     const std::vector<std::string>& feed_names) {
+#ifdef PADDLE_WITH_MKLDNN
+  platform::AttachPointerHashToMKLDNNKey(this, place_);
+#endif
   if (!is_build_) {
     if (create_local_scope_ &&
         global_scope_->GetMutableLocalScope() !=
......
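The interpreter changes above tie the oneDNN blob cache to a particular InterpreterCore: AttachPointerHashToMKLDNNKey(this, place_) at the start of Run(), and ClearMKLDNNCache(place_, this) in the destructor. A rough standard-C++ model of that ownership pattern, with hypothetical Executor/blob_cache names chosen only to show the keying idea:

#include <iostream>
#include <map>
#include <string>
#include <utility>

// Hypothetical model: blobs are keyed by (owner pointer, name), so one
// executor's destruction can drop exactly its own cached entries.
std::map<std::pair<const void*, std::string>, int> blob_cache;

struct Executor {
  void Run() {
    // Attach this executor's identity to every key it creates,
    // loosely mirroring AttachPointerHashToMKLDNNKey(this, place_).
    blob_cache[{this, "conv_primitive"}] = 42;
  }
  ~Executor() {
    // Mirror ClearMKLDNNCache(place_, this): erase only our own entries.
    for (auto it = blob_cache.begin(); it != blob_cache.end();) {
      it = (it->first.first == this) ? blob_cache.erase(it) : ++it;
    }
  }
};

int main() {
  { Executor e; e.Run(); }                 // entries created, then cleared
  std::cout << blob_cache.size() << "\n";  // prints 0
}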
@@ -21,6 +21,10 @@
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/phi/core/kernel_factory.h"
 
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");
@@ -312,6 +316,10 @@ void build_op_func_list(const platform::Place& place,
   operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
       main_program, block.ID(), ops_unique);
 
+#ifdef PADDLE_WITH_MKLDNN
+  platform::RegisterModelLayout(ops_unique, place);
+#endif
+
   // its elements will be moved to vec_func_list
   std::vector<std::shared_ptr<OperatorBase>> ops;
   for (auto& op_unique : ops_unique) {
......
@@ -112,7 +112,8 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
   auto iter = interpretercores_.find(oss.str());
 
   if (iter == interpretercores_.end()) {
-    VLOG(3) << "create interpreter_core for " << oss.str();
+    VLOG(3) << "create interpreter_core for " << oss.str() << " on place "
+            << place_;
     VLOG(3) << "add fetch op: " << add_fetch_op;
     std::shared_ptr<InterpreterCore> core = nullptr;
     if (add_fetch_op) {
......
@@ -63,7 +63,7 @@ class StandaloneExecutor : public ExecutorBase {
                          const std::vector<std::string>& feed_names,
                          const std::vector<std::string>& fetch_names, bool add_fetch_op);
 
-  const platform::Place& place_;
+  platform::Place place_;
   const ProgramDesc& startup_prog_;
   const ProgramDesc& main_prog_;
   VariableScope global_scope_;
......
@@ -33,6 +33,7 @@ static void DataCopy(const framework::LoDTensor &src_item,
       framework::Tensor out;
       // Convert to desired Paddle layout, apart from grads of filter
       // as params are not a subject to paddle's data_format
+      VLOG(4) << "innerTransDataLayoutFromMKLDNN";
       framework::innerTransDataLayoutFromMKLDNN(
           src_item.layout(), fetch_var_name == framework::GradVarName("Filter")
                                  ? framework::DataLayout::kNCHW
......
@@ -67,19 +67,25 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
     // kernel's device type is decided by input tensor place
     auto *in = ctx.InputVar("X");
     auto *in_tensor = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in);
-    PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(), true,
-                      platform::errors::PreconditionNotMet(
-                          "The tensor of Input(X) is not initialized."));
+    // NOTE(zhiqiu): hot fix, allow empty tensor of kMKLDNN layout to run this
+    // op
+    if (in_tensor->layout() != DataLayout::kMKLDNN) {
+      PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(), true,
+                        platform::errors::PreconditionNotMet(
+                            "The tensor of Input(X) is not initialized."));
+    }
+    auto place =
+        in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();
+
     // dtype is not important
-    return framework::OpKernelType(framework::proto::VarType::FP32,
-                                   in_tensor->place());
+    return framework::OpKernelType(framework::proto::VarType::FP32, place);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
     return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(),
+                                   expected_kernel_type.place_,
                                    expected_kernel_type.data_layout_);
   }
 };
@@ -99,7 +105,9 @@ class TransferLayoutKernel {
     auto &dev_ctx = ctx.device_context();
     auto src_layout = ctx.Attr<int>("src_layout");
     auto dst_layout = ctx.Attr<int>("dst_layout");
-    TransferLayoutFunctor(x, out, dev_ctx, src_layout, dst_layout)();
+    auto input_name = ctx.InputName("X");
+    TransferLayoutFunctor(x, out, dev_ctx, src_layout, dst_layout,
+                          input_name)();
   }
 };
......
@@ -39,12 +39,14 @@ class TransferLayoutFunctor {
  public:
   TransferLayoutFunctor(const framework::Variable *in, framework::Variable *out,
                         const platform::DeviceContext &dev_ctx,
-                        const int src_layout, const int dst_layout)
+                        const int src_layout, const int dst_layout,
+                        std::string in_name)
       : in_(in),
         out_(out),
         dev_ctx_(dev_ctx),
         src_layout_(src_layout),
-        dst_layout_(dst_layout) {}
+        dst_layout_(dst_layout),
+        in_name_(in_name) {}
 
   void operator()() const {
     auto &in_tensor = *framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_);
@@ -54,8 +56,18 @@ class TransferLayoutFunctor {
     out_tensor.set_layout(out_layout);
 
 #ifdef PADDLE_WITH_MKLDNN
+    // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
+    // data_transfer.cc
     auto in_layout = static_cast<DataLayout>(src_layout_);
+    auto *tensor_out = out_->GetMutable<framework::LoDTensor>();
     VLOG(4) << in_layout << "->" << out_layout << " " << in_tensor.layout();
+    if (!in_tensor.IsInitialized() && in_layout == DataLayout::kMKLDNN &&
+        out_layout == DataLayout::kNHWC) {
+      tensor_out->Resize(in_tensor.dims());
+      tensor_out->set_layout(out_layout);
+      platform::MatchShapeToLayout(tensor_out, in_layout, out_layout);
+      return;
+    }
     if (in_layout == DataLayout::kMKLDNN || out_layout == DataLayout::kMKLDNN) {
       PADDLE_ENFORCE_NE(
           in_layout, out_layout,
@@ -81,13 +93,21 @@ class TransferLayoutFunctor {
         out_tensor.set_layout(DataLayout::kMKLDNN);
         out_tensor.set_format(out_format);
       } else {
-        VLOG(4) << "kNCHW";
+        auto target_layout = paddle::platform::MKLDNNDeviceContext::tls()
+                                 .get_cur_paddle_data_layout();
+        // NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in
+        // fetch_op.cc
+        if (out_layout == DataLayout::kNCHW &&
+            in_name_ == framework::GradVarName("Filter")) {
+          target_layout = out_layout;
+        }
+        VLOG(4) << "innerTransDataLayoutFromMKLDNN: " << in_layout << "->"
+                << target_layout;
         // Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
         // Do transform via MKLDNN lib
-        paddle::framework::innerTransDataLayoutFromMKLDNN(
-            in_layout, paddle::platform::MKLDNNDeviceContext::tls()
-                           .get_cur_paddle_data_layout(),
-            in_tensor, &out_tensor, dev_ctx_.GetPlace());
+        paddle::framework::innerTransDataLayoutFromMKLDNN(
+            in_layout, target_layout, in_tensor, &out_tensor,
+            dev_ctx_.GetPlace());
       }
     } else {
       // Case3 - transform between Non-MKLDNN OPKernels
@@ -132,6 +152,7 @@ class TransferLayoutFunctor {
   const platform::DeviceContext &dev_ctx_;
   const int src_layout_;
   const int dst_layout_;
+  std::string in_name_;
 };
 
 }  // namespace operators
......
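TransferLayoutFunctor now picks its target layout at run time: the thread-local current paddle layout by default, except that filter gradients keep kNCHW, mirroring DataCopy() in fetch_op.cc. A compact sketch of just that choice (the enum and helper are illustrative; "Filter@GRAD" spells out what framework::GradVarName("Filter") expands to):

#include <iostream>
#include <string>

enum class Layout { kNCHW, kNHWC, kMKLDNN };

// Default to the session's current paddle layout, but grads of filters
// keep kNCHW because parameters are not subject to data_format.
Layout ChooseTargetLayout(Layout out_layout, const std::string& in_name,
                          Layout cur_paddle_layout) {
  if (out_layout == Layout::kNCHW && in_name == "Filter@GRAD") {
    return out_layout;
  }
  return cur_paddle_layout;
}

int main() {
  bool kept = ChooseTargetLayout(Layout::kNCHW, "Filter@GRAD",
                                 Layout::kNHWC) == Layout::kNCHW;
  std::cout << kept << "\n";  // prints 1: filter grad stays kNCHW
}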
@@ -742,6 +742,7 @@ dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
 }
 
 void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
+  VLOG(4) << tls().get_curr_exec() << " " << ptr;
   std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
   if (!block_next_cache_clearing_) {
     VLOG(3) << "Clearing DNNL cache.";
......
@@ -563,6 +563,7 @@ inline void RegisterModelLayout(
     std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
     const platform::Place& place) {
   if (platform::is_cpu_place(place)) {
+    VLOG(4) << "RegisterModelLayout for mkldnn";
     auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                            const std::string& attrib_name) -> bool {
       if (op->HasAttr(attrib_name)) {
......