Unverified commit afa0e82c authored by Leo Chen, committed by GitHub

[new-exec] fit for mkldnn and inplace op (#40955)

* fit for mkldnn and inplace op

* fix compile

* refine ut

* register op version

* fix inplace op

* fix transfer_layout

Parent: de8962bd
@@ -149,7 +149,8 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
   // 2. Construct VariableNameMap
   VariableNameMap in_name_map = {{"X", {var_name}}};
   VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
-  AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};
+  AttributeMap attr_map = {{"src_layout", static_cast<int>(in_layout)},
+                          {"dst_layout", static_cast<int>(out_layout)}};

   // 3. Create transfer_layout_op
   std::string op_type("transfer_layout");
@@ -157,8 +158,9 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
   auto op = std::shared_ptr<OperatorBase>(
       op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));

-  VLOG(3) << string::Sprintf("Insert %s(%s) with %s -> %s(%s).", op_type,
-                             var_name, in_layout, *new_var_name, out_layout);
+  VLOG(3) << string::Sprintf("Insert %s for variable %s(%s) -> %s(%s).",
+                             op_type, var_name, in_layout, *new_var_name,
+                             out_layout);
   return op;
 }
@@ -242,6 +244,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
 void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                         const platform::Place& place,
                         VariableValueMap* ins_map_temp,
+                        VariableValueMap* outs_map_temp,
                         VariableScope* var_scope, OpFuncNode* op_func_node,
                         std::vector<OpFuncNode>* new_op_func_nodes,
                         bool use_local_scope) {
@@ -251,6 +254,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                           "op_base in apply_data_transform."));

   VariableNameMap new_ins(op_base->Inputs());
+  VariableNameMap new_outs(op_base->Outputs());
   // record the no need transform variable index.
   std::unordered_set<int> no_data_transform_index;
@@ -258,7 +262,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
   for (auto& var_name_item : *ins_map_temp) {
     for (size_t i = 0; i < var_name_item.second.size(); ++i) {
       auto var = var_name_item.second[i];
-      auto& var_name = new_ins[var_name_item.first].at(i);
+      auto var_name = new_ins[var_name_item.first].at(i);
       const Tensor* tensor_in;
       if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>()) {
         tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
@@ -287,6 +291,28 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
             var_scope->VarId(new_var_name);
         var_name_item.second[i] = var_scope->Var(new_var_name);
         new_ins[var_name_item.first][i] = new_var_name;
+        for (auto& pair : new_outs) {
+          for (size_t j = 0; j < pair.second.size(); ++j) {
+            VLOG(4) << pair.second[j] << " " << var_name;
+            if (pair.second[j] == var_name) {
+              VLOG(4) << "Found inplace between input(" << var_name_item.first
+                      << ") and output(" << pair.first
+                      << "), the variable name is " << var_name;
+              (*outs_map_temp)[pair.first][j] = var_scope->Var(new_var_name);
+              new_outs[pair.first][j] = new_var_name;
+              op_func_node
+                  ->inplace_back_map[var_scope->GetIdByName(new_var_name)] =
+                  var_scope->GetIdByName(var_name);
+              op_func_node->output_index[pair.first][j] =
+                  var_scope->VarId(new_var_name);
+              // NOTE(zhiqiu): an inplace op with `transfer` also changes
+              // the original output afterwards, so add the original output
+              // as well.
+              op_func_node->output_index[pair.first].push_back(
+                  var_scope->VarId(var_name));
+            }
+          }
+        }
         // NOTE(Aurelius84): avoid deepcopy twice if we already insert data
         // transfer op.
         if (op_base->Type() == "fetch_v2") {
@@ -306,7 +332,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
   // with instruction. (hot fix, it is not good design here)
   op_func_node->operator_base_ =
       std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
-          op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
+          op_base->Type(), new_ins, new_outs, op_base->Attrs()));
   op_func_node->no_data_transform_index = std::move(no_data_transform_index);
 }
......
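The nested loop added in ApplyDataTransform above is the heart of the inplace fix: when a transfer op has replaced an input variable that the op also lists as an output (e.g. `increment`, which writes its result back to `X`), the output slot must be redirected to the transferred variable, and a reverse id mapping recorded so the result can be shared back afterwards. Below is a minimal, self-contained sketch of that detection logic; `NameMap` and the hard-coded id table are hypothetical stand-ins for Paddle's `VariableNameMap` and `VariableScope`, not the real API.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for VariableNameMap: slot name -> variable names.
using NameMap = std::map<std::string, std::vector<std::string>>;

int main() {
  // increment(X) writes its result back to X, so "x" is also an output.
  NameMap outs = {{"Out", {"x"}}};
  std::map<std::string, int> var_id = {{"x", 0}, {"x_transferred", 1}};

  const std::string var_name = "x";                  // original input
  const std::string new_var_name = "x_transferred";  // made by a transfer op
  std::map<int, int> inplace_back_map;               // transferred -> original

  for (auto& pair : outs) {
    for (std::size_t j = 0; j < pair.second.size(); ++j) {
      if (pair.second[j] == var_name) {
        // Redirect the output slot to the transferred variable, and record
        // how to share the result back once the kernel has run.
        pair.second[j] = new_var_name;
        inplace_back_map[var_id[new_var_name]] = var_id[var_name];
      }
    }
  }
  std::cout << "share back: var " << var_id[new_var_name] << " -> var "
            << inplace_back_map.at(var_id[new_var_name]) << "\n";  // 1 -> 0
}
```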
@@ -54,6 +54,7 @@ class DataTranferHelper {
 void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                         const platform::Place& place,
                         VariableValueMap* ins_map_temp,
+                        VariableValueMap* outs_map_temp,
                         VariableScope* var_scope, OpFuncNode* op_func_node,
                         std::vector<OpFuncNode>* op_func_nodes,
                         bool use_local_scope = true);
......
@@ -457,6 +457,21 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_);

+  if (!instr_node.InplaceBackMap().empty()) {
+    auto& m = instr_node.InplaceBackMap();
+    // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
+    for (auto& p : m) {
+      auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          global_scope_->Var(p.first));
+      auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          global_scope_->Var(p.second));
+      original_tensor->ShareDataWith(*transformed_tensor);
+      VLOG(4) << "Transfer inplace variable back from "
+              << global_scope_->GetNameById(p.first) << " to "
+              << global_scope_->GetNameById(p.second);
+    }
+  }
+
   /*For profiling/benchmark only*/
   if (FLAGS_benchmark) {
     instr_node.DeviceContext().Wait();
......
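`ShareDataWith` is what keeps this transfer-back cheap: the original variable's tensor adopts the transformed tensor's holder instead of copying elements, so writes made by the kernel become visible through the original variable. A toy illustration of the aliasing semantics (this `ToyTensor` is an illustrative stand-in, not Paddle's `Tensor`):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tensor whose ShareDataWith mimics holder aliasing, not copying.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> holder;
  void ShareDataWith(const ToyTensor& other) { holder = other.holder; }
};

int main() {
  ToyTensor transformed;
  transformed.holder = std::make_shared<std::vector<float>>(4, 1.0f);

  ToyTensor original;
  original.ShareDataWith(transformed);     // aliasing, no element copy

  (*transformed.holder)[0] = 42.0f;        // the kernel writes here
  assert((*original.holder)[0] == 42.0f);  // visible through the original
  return 0;
}
```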
@@ -138,7 +138,9 @@ get_unused_vars(const BlockDesc& block,
     size_t op_idx = name_op_idx_pair.second;

     result[ops[op_idx].get()].emplace_back(name);
+    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
   }
+  VLOG(4) << "gc map size:" << result.size();
   return result;
 }
@@ -311,8 +313,8 @@ void build_op_func_list(const platform::Place& place,
   operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
       main_program, block.ID(), ops_unique);

-  std::vector<std::shared_ptr<OperatorBase>>
-      ops;  // its elements will be moved to vec_func_list
+  // its elements will be moved to vec_func_list
+  std::vector<std::shared_ptr<OperatorBase>> ops;
   for (auto& op_unique : ops_unique) {
     ops.emplace_back(std::move(op_unique));
   }
@@ -348,34 +350,28 @@ void build_op_func_list(const platform::Place& place,
     op_func_node.operator_base_ = ops[i];
     op_func_node.input_index = ins_name2id;
     op_func_node.output_index = outs_name2id;
+    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

-    if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
+    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
       // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
       deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
+      VLOG(4) << "End run " << place << " "
+              << op_func_node.operator_base_->DebugStringEx(local_scope);
     } else {
-      auto op_with_kernel =
-          static_cast<const framework::OperatorWithKernel*>(op);
+      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
+          static_cast<const framework::OperatorWithKernel*>(op));
       // construct RuntimeContext and analyze KernelType
       RuntimeContext runtime_context({}, {});
       runtime_context.inputs.swap(ins_map);
       runtime_context.outputs.swap(outs_map);
-      // see OperatorWithKernel::RunImpl in operator.cc for why
-      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
-            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
-        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
-        // TODO(Aurelius84): In case of control flow ops, they are NOT
-        // inherited from OperatorWithKernel.
-        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
-      }
       platform::DeviceContextPool& pool =
           platform::DeviceContextPool::Instance();
       auto* dev_ctx = pool.Get(place);
       Scope scope;
       auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
           ExecutionContext(*op, scope, *dev_ctx, runtime_context));
+      op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

       // change device by the device_guard()
       apply_device_guard(op, place, &expected_kernel_key);
@@ -383,13 +379,16 @@ void build_op_func_list(const platform::Place& place,
       // step 3. apply data transforms and insert data transfer ops
       VariableValueMap& ins_map_temp = runtime_context.inputs;
+      VariableValueMap& outs_map_temp = runtime_context.outputs;

       // NOTE(zhiqiu): op_func_node->operator_base_ may be changed in
       // ApplyDataTransform
-      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
-                         &op_func_node, vec_func_list, use_local_scope);
-      op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
-          op_func_node.operator_base_.get());
+      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp,
+                         &outs_map_temp, var_scope, &op_func_node,
+                         vec_func_list, use_local_scope);
+      op_with_kernel = const_cast<framework::OperatorWithKernel*>(
+          static_cast<const framework::OperatorWithKernel*>(
+              op_func_node.operator_base_.get()));

       // step 4. Run op kernel
       VLOG(3) << op_with_kernel->Type()
@@ -412,6 +411,16 @@ void build_op_func_list(const platform::Place& place,
       auto exec_ctx =
           ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

+      // see OperatorWithKernel::RunImpl in operator.cc for why
+      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
+            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
+        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
+        // TODO(Aurelius84): In case of control flow ops, they are NOT
+        // inherited from OperatorWithKernel.
+        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
+      }
+
       auto run_phi_kernel = false;
       if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
               op_with_kernel->Type())) {
@@ -476,9 +485,28 @@ void build_op_func_list(const platform::Place& place,
             op_func_node, place, outputs_names, &runtime_context.outputs,
             var_scope, vec_func_list, local_scope);
       }
+      if (!op_func_node.inplace_back_map.empty()) {
+        auto& m = op_func_node.inplace_back_map;
+        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
+        for (auto& p : m) {
+          auto* transformed_tensor =
+              GetMutableLoDTensorOrSelectedRowsValueFromVar(
+                  var_scope->Var(p.first));
+          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
+              var_scope->Var(p.second));
+          original_tensor->ShareDataWith(*transformed_tensor);
+          VLOG(4) << "Transfer inplace variable back from "
+                  << var_scope->GetNameById(p.first) << " to "
+                  << var_scope->GetNameById(p.second);
+        }
+      }
     }
+    VLOG(4) << "End run " << place << " "
+            << op_func_node.operator_base_->DebugStringEx(local_scope);

     vec_func_list->emplace_back(op_func_node);

     // gc---------------------------------------------------------------------------
     auto iter = unused_var_map.find(op);
     if (iter == unused_var_map.end()) {
@@ -514,10 +542,7 @@ void build_op_func_list(const platform::Place& place,
                 framework::ToTypeName(var->Type()), var_name));
       }
     }
     delete garbages;  // free mem
-
-    VLOG(3) << "run " << op->Type() << " done.";
-
   }
 }
......
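One consequence of the reordering in build_op_func_list is that InferShape now runs after ApplyDataTransform, so shape inference sees the layout the kernel will actually consume. That matters for ops like batch_norm (further below), where the channel count C is read from a different axis depending on the layout. A toy illustration of why the ordering matters, not Paddle code:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

enum class Layout { kNCHW, kNHWC };

// The channel axis depends on the layout the kernel will actually see.
int64_t ChannelOf(const std::array<int64_t, 4>& dims, Layout layout) {
  return layout == Layout::kNCHW ? dims[1] : dims[3];
}

int main() {
  std::array<int64_t, 4> nchw = {8, 3, 32, 32};
  std::array<int64_t, 4> nhwc = {8, 32, 32, 3};
  assert(ChannelOf(nchw, Layout::kNCHW) == 3);
  assert(ChannelOf(nhwc, Layout::kNHWC) == 3);
  // Inferring shapes with a stale layout picks the wrong axis: C = 32.
  assert(ChannelOf(nhwc, Layout::kNCHW) == 32);
  return 0;
}
```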
@@ -692,6 +692,10 @@ phi::Kernel* Instruction::PhiKernel() const { return op_func_node_.pt_kernel_; }

 OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }

+const std::map<int, int>& Instruction::InplaceBackMap() const {
+  return op_func_node_.inplace_back_map;
+}
+
 OperatorBase* Instruction::OpBase() const {
   auto op_base = op_func_node_.operator_base_;
   PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
......
@@ -297,6 +297,8 @@ struct OpFuncNode {
   std::map<std::string, std::vector<int>> output_index;
   std::unordered_set<int> no_data_transform_index;

+  std::map<int, int> inplace_back_map;
+
   OpKernelComputeFunc kernel_func_;
   platform::DeviceContext* dev_ctx_;  // not owned
@@ -325,6 +327,8 @@ class Instruction {
   OpFuncType KernelType() const;

+  const std::map<int, int>& InplaceBackMap() const;
+
   OperatorBase* OpBase() const;

   NextInstruction& NextInstructions();
......
@@ -664,6 +664,10 @@ class OperatorWithKernel : public OperatorBase {
   const OpKernelType* kernel_type() const { return kernel_type_.get(); }

+  void ResetKernelType(OpKernelType* kernel_type) {
+    kernel_type_.reset(kernel_type);
+  }
+
 private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
......
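ResetKernelType lets the interpreter cache the expected kernel type it just computed on the operator itself, so later calls to kernel_type() (for example from GetKernelTypeForVar during data transfer) agree with that decision. Since unique_ptr::reset adopts the raw pointer and frees any previous value, the ownership pattern looks roughly like the toy class below (not the real OperatorWithKernel):

```cpp
#include <memory>

struct ToyKernelType { int data_layout = 0; };  // e.g. 3 would mean kMKLDNN

class ToyOp {
 public:
  const ToyKernelType* kernel_type() const { return kernel_type_.get(); }
  // Adopts ownership of the heap-allocated kernel type, freeing any old one.
  void ResetKernelType(ToyKernelType* kt) { kernel_type_.reset(kt); }

 private:
  std::unique_ptr<ToyKernelType> kernel_type_;
};

int main() {
  ToyOp op;
  op.ResetKernelType(new ToyKernelType{3});
  return op.kernel_type()->data_layout == 3 ? 0 : 1;
}
```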
@@ -94,7 +94,8 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
           "must smaller than or equal to 5. But received: the shape of input X "
           "= [%s], the dimension of input X = [%d]",
           x_dims, x_dims.size()));
-
+  VLOG(4) << ctx->IsRunMKLDNNKernel();
+  VLOG(4) << data_layout;
   const int64_t C =
       ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
            ? x_dims[1]
@@ -136,6 +137,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
                           C, bias_dim[0]));
   }
   ctx->SetOutputDim("Y", x_dims);
+  VLOG(4) << x_dims;
   ctx->SetOutputDim("MeanOut", {C});
   ctx->SetOutputDim("VarianceOut", {C});
   ctx->SetOutputDim("SavedMean", {C});
......
@@ -203,14 +203,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto *y = ctx.Output<Tensor>("Y");
     auto *batch_mean = ctx.Output<Tensor>("SavedMean");
     auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
-
     BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
                                       test_mode);

     auto src_memory = handler.AcquireSrcMemory(x);
     auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
     auto dst_memory = handler.AcquireDstMemory(y);
-
     auto batch_norm_p = handler.AcquireForwardPrimitive();

     std::shared_ptr<memory> mean_memory;
@@ -300,7 +298,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
     auto diff_scaleshift_memory =
         handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());
-
     // finally create batch_norm backward primitive
     auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
......
@@ -16,6 +16,8 @@
 #include <string>

+#include "paddle/fluid/framework/op_version_registry.h"
+
 namespace paddle {
 namespace framework {
 class OpDesc;
@@ -95,8 +97,9 @@ class TransferLayoutKernel {
     auto *x = ctx.InputVar("X");
     auto *out = ctx.OutputVar("Out");
     auto &dev_ctx = ctx.device_context();
+    auto src_layout = ctx.Attr<int>("src_layout");
     auto dst_layout = ctx.Attr<int>("dst_layout");
-    TransferLayoutFunctor(x, out, dev_ctx, dst_layout)();
+    TransferLayoutFunctor(x, out, dev_ctx, src_layout, dst_layout)();
   }
 };
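The kernel above now reads the new src_layout attribute; its default of -1 (declared in the proto maker just below) means "unspecified, use the tensor's own layout". A sketch of how that default would plausibly resolve, an assumption drawn from the attribute docstring rather than traced through the framework:

```cpp
#include <iostream>

enum class DataLayout { kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3 };

// Assumed resolution rule for the -1 default of src_layout.
DataLayout ResolveSrcLayout(int src_layout_attr, DataLayout tensor_layout) {
  return src_layout_attr == -1 ? tensor_layout
                               : static_cast<DataLayout>(src_layout_attr);
}

int main() {
  // Unspecified (-1): take the layout recorded on the tensor.
  std::cout << static_cast<int>(ResolveSrcLayout(-1, DataLayout::kNCHW))
            << "\n";  // 2 (kNCHW)
  // Explicit attribute wins, e.g. when GetKernelTypeForVar computed kMKLDNN.
  std::cout << static_cast<int>(ResolveSrcLayout(3, DataLayout::kNCHW))
            << "\n";  // 3 (kMKLDNN)
  return 0;
}
```

The REGISTER_OP_VERSION checkpoint at the end of this file bumps the op's compatibility version so programs saved before this change, which carry no src_layout attribute, still load with this default.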
@@ -105,6 +108,14 @@ class TransferLayoutOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "(LoDTensor) The input Tensor");
     AddOutput("Out", "(LoDTensor) The Output Tensor with desired layout");
+    // NOTE(zhiqiu): in most cases src_layout is not needed, since the op can
+    // use the layout of input X. However, in some mkldnn kernels, the src
+    // layout computed by GetKernelTypeForVar differs from the layout of
+    // tensor X.
+    AddAttr<int>("src_layout",
+                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3, default "
+                 "-1 means unspecified and use the tensor's layout.")
+        .SetDefault(-1);
     AddAttr<int>("dst_layout",
                  "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3");
     AddComment(R"DOC(
@@ -126,3 +137,8 @@ REGISTER_OPERATOR(
 // dtype is not important
 REGISTER_OP_CPU_KERNEL_FUNCTOR(transfer_layout, float,
                                ops::TransferLayoutKernel);
+REGISTER_OP_VERSION(transfer_layout)
+    .AddCheckpoint(
+        R"ROC(refine transfer_layout, add src_layout attribute)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "src_layout", "(int) the layout of the input tensor", -1));
@@ -39,8 +39,12 @@ class TransferLayoutFunctor {
  public:
   TransferLayoutFunctor(const framework::Variable *in, framework::Variable *out,
                         const platform::DeviceContext &dev_ctx,
-                        const int dst_layout)
-      : in_(in), out_(out), dev_ctx_(dev_ctx), dst_layout_(dst_layout) {}
+                        const int src_layout, const int dst_layout)
+      : in_(in),
+        out_(out),
+        dev_ctx_(dev_ctx),
+        src_layout_(src_layout),
+        dst_layout_(dst_layout) {}

   void operator()() const {
     auto &in_tensor = *framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_);
@@ -50,7 +54,8 @@ class TransferLayoutFunctor {
     out_tensor.set_layout(out_layout);

 #ifdef PADDLE_WITH_MKLDNN
-    auto in_layout = in_tensor.layout();
+    auto in_layout = static_cast<DataLayout>(src_layout_);
+    VLOG(4) << in_layout << "->" << out_layout << " " << in_tensor.layout();
     if (in_layout == DataLayout::kMKLDNN || out_layout == DataLayout::kMKLDNN) {
       PADDLE_ENFORCE_NE(
           in_layout, out_layout,
@@ -68,6 +73,7 @@ class TransferLayoutFunctor {
       // For NHWC data we need reshape of tensors as MKL-DNN
       // is expecting NHWC dims description order
       if (in_layout == DataLayout::kNHWC) {
+        VLOG(4) << "kNHWC";
         platform::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
         paddle::platform::MKLDNNDeviceContext::tls()
             .set_cur_paddle_data_layout(in_layout);
@@ -75,6 +81,7 @@ class TransferLayoutFunctor {
         out_tensor.set_layout(DataLayout::kMKLDNN);
         out_tensor.set_format(out_format);
       } else {
+        VLOG(4) << "kNCHW";
         // Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
         // Do transform via MKLDNN lib
         paddle::framework::innerTransDataLayoutFromMKLDNN(
@@ -123,6 +130,7 @@ class TransferLayoutFunctor {
   const framework::Variable *in_;
   framework::Variable *out_;
   const platform::DeviceContext &dev_ctx_;
+  const int src_layout_;
   const int dst_layout_;
 };
......
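In plain terms, the functor's MKLDNN branch works like this: any transfer touching the kMKLDNN domain must actually change layout (the PADDLE_ENFORCE_NE above), and the direction picks the mechanism. Entering kMKLDNN mostly retags the tensor, with NHWC sources additionally getting their dims matched to the NCHW description order; leaving kMKLDNN goes through a real oneDNN reorder. A compressed sketch of that dispatch, with stub helpers standing in for platform::MatchShapeToLayout and framework::innerTransDataLayoutFromMKLDNN:

```cpp
#include <iostream>
#include <stdexcept>

enum class DataLayout { kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3 };

// Stub helpers standing in for the real layout-transfer machinery.
void IntoMkldnn(bool nhwc_source) {
  std::cout << "case 1: retag tensor as kMKLDNN"
            << (nhwc_source ? " and match NHWC dims to NCHW order" : "")
            << "\n";
}
void OutOfMkldnn() { std::cout << "case 2: oneDNN reorder to plain layout\n"; }

void TransferLayoutSketch(DataLayout in, DataLayout out) {
  if (in != DataLayout::kMKLDNN && out != DataLayout::kMKLDNN) return;
  // Mirrors the PADDLE_ENFORCE_NE: kMKLDNN -> kMKLDNN would be a no-op.
  if (in == out) throw std::invalid_argument("layouts must differ");
  if (out == DataLayout::kMKLDNN) {
    IntoMkldnn(/*nhwc_source=*/in == DataLayout::kNHWC);
  } else {
    OutOfMkldnn();
  }
}

int main() {
  TransferLayoutSketch(DataLayout::kNHWC, DataLayout::kMKLDNN);  // case 1
  TransferLayoutSketch(DataLayout::kMKLDNN, DataLayout::kNCHW);  // case 2
  return 0;
}
```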
@@ -531,6 +531,7 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
 }

 void CUDADeviceContext::Wait() const {
+  VLOG(4) << "CUDA context(" << this << ") Wait";
   if (thread_ctx_.count(this)) {
     context()->Stream()->Wait();
     return;
......
@@ -352,5 +352,23 @@ class TestException(unittest.TestCase):
                                  self.fetch_vars.name))

+
+class TestInplaceApiWithDataTransform(unittest.TestCase):
+    def test_increment(self):
+        if paddle.fluid.core.is_compiled_with_cuda():
+            with paddle.fluid.device_guard("gpu:0"):
+                x = paddle.fluid.layers.fill_constant([1], "float32", 0)
+            with paddle.fluid.device_guard("cpu"):
+                x = paddle.increment(x)
+            exe = paddle.static.Executor(paddle.CUDAPlace(0))
+            os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+            for i in range(10):
+                a, = exe.run(paddle.static.default_main_program(),
+                             fetch_list=[x])
+                self.assertEqual(a[0], 1)
+            del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -30,6 +30,7 @@ class TestTransferLayoutOpkNCHWTokNHWC(OpTest):
         self.inputs = {'X': ipt.astype('float32')}
         self.outputs = {'Out': ipt.transpose([0, 2, 3, 1])}
         self.attrs = {
+            'src_layout': 0,
             'dst_layout': 1  # kNHWC
         }
         self.op_type = 'transfer_layout'
......