未验证提交 c56fffb4 · 编写者：zyfncg · 提交者：GitHub

optimize performance of dygraph (#42137)

上级提交：79ac8870
...@@ -402,12 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween( ...@@ -402,12 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) { const std::string& op_type) {
// 1. get kernel args // 1. get kernel args
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
InferShapeArgumentMappingContext arg_map_context(*ctx); InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context); KernelSignature signature =
arg_map_fn ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context // 2. build infermeta context
......
...@@ -2117,8 +2117,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2117,8 +2117,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx); ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) { if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn( auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_);
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type()))); if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else {
auto func =
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
};
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
}
} }
return (*arg_map_fn_)(arg_mapping_ctx); return (*arg_map_fn_)(arg_mapping_ctx);
} }
......
...@@ -37,6 +37,8 @@ namespace paddle { ...@@ -37,6 +37,8 @@ namespace paddle {
namespace imperative { namespace imperative {
static const phi::Kernel empty_kernel; static const phi::Kernel empty_kernel;
// Shared immutable singletons reused by every PrepareImpl call, so dygraph op
// preparation no longer constructs a fresh RuntimeContext / Scope per op
// (the per-call `framework::RuntimeContext ctx({}, {})` below is removed in
// this change) — this is the performance optimization the commit targets.
// NOTE(review): safe only as long as callees never mutate these objects —
// confirm DygraphExecutionContext/PreparedOp treat them as read-only.
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) { const std::shared_ptr<paddle::imperative::VarBase>& var) {
...@@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there. // GetKernelType functions, so we need to copy the attributes there.
...@@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
// 1. get expected kernel key // 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>( auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs); op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
framework::KernelSignature pt_kernel_signature; framework::KernelSignature pt_kernel_signature;
...@@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
paddle::platform::is_in_xpu_black_list(op.Type()); paddle::platform::is_in_xpu_black_list(op.Type());
#endif #endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;
bool has_phi_kernel = false;
const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
const auto* kernel_sig =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) {
has_phi_kernel = true;
pt_kernel_signature = *kernel_sig;
}
}
if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
...@@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
return PreparedOp(op, ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx); std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
...@@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel; << " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel, std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx); cpu_ctx);
} }
...@@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......
...@@ -193,7 +193,7 @@ void PhiOpConvertPass::convertStage() { ...@@ -193,7 +193,7 @@ void PhiOpConvertPass::convertStage() {
op->replaceAllUsesWith(kernel_op.getResults()); op->replaceAllUsesWith(kernel_op.getResults());
} else { } else {
::phi::KernelSignature kernel_sign = ::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( (*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))(
infrt::ProtoArgumentMappingContext(op)); infrt::ProtoArgumentMappingContext(op));
VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel(" VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
<< kernel_sign.name << ")"; << kernel_sign.name << ")";
......
...@@ -86,6 +86,14 @@ class DefaultKernelSignatureMap { ...@@ -86,6 +86,14 @@ class DefaultKernelSignatureMap {
return it->second; return it->second;
} }
// Returns the registered default signature for `op_type`, or nullptr when no
// entry exists. Non-throwing counterpart of Get(): callers use the nullptr
// result to fall back to another lookup path instead of failing.
const KernelSignature* GetNullable(const std::string& op_type) const {
  const auto iter = map_.find(op_type);
  return iter == map_.end() ? nullptr : &iter->second;
}
void Insert(std::string op_type, KernelSignature signature) { void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
Has(op_type), Has(op_type),
...@@ -148,16 +156,13 @@ class OpUtilsMap { ...@@ -148,16 +156,13 @@ class OpUtilsMap {
} }
} }
ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const { const ArgumentMappingFn* GetArgumentMappingFn(
const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type); auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) { if (it == arg_mapping_fn_map_.end()) {
auto func = return nullptr;
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
} else { } else {
return it->second; return &it->second;
} }
} }
......
Markdown is supported (upload progress: 0%).
You are about to add 0 people to the discussion. Proceed with caution.
请先完成此消息的编辑！
想要评论，请先 注册。