From 3a0d7bf0d9612b8e69f71f5c352d03e50bd95065 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 25 Apr 2022 07:48:03 +0800 Subject: [PATCH] Optimize dygraph GetExpectedKernelType perf (#42154) * opt dygraph scheduling * revert part impl --- paddle/fluid/framework/operator.cc | 47 ++++++++++++++++++--- paddle/fluid/framework/operator.h | 12 +++--- paddle/fluid/imperative/execution_context.h | 18 +++++--- paddle/fluid/operators/transpose_op.cc | 2 +- paddle/phi/core/kernel_context.h | 8 ++-- 5 files changed, 68 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index da082f5d26..945b8a8984 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -940,7 +940,7 @@ class RuntimeInferShapeContext : public InferShapeContext { return ((op_with_kernel.kernel_type()) && (op_with_kernel.kernel_type()->data_layout_ == framework::DataLayout::kMKLDNN)); - } catch (std::bad_cast exp) { + } catch (const std::bad_cast& exp) { return false; } } @@ -1965,6 +1965,36 @@ Scope* OperatorWithKernel::PrepareData( } void OperatorWithKernel::ParseInputDataType( + const Variable* var, const std::string& name, + proto::VarType::Type* data_type) const { + if (var != nullptr) { + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &(var->Get().value()); + } else if (var->IsType()) { + auto t_arr = &var->Get(); + for (size_t j = 0; j < t_arr->size(); j++) { + if (t_arr->at(j).IsInitialized()) { + t = &(t_arr->at(j)); + } + } + } + if (t != nullptr) { + PADDLE_ENFORCE_EQ( + t->IsInitialized(), true, + platform::errors::InvalidArgument("The %s Op's Input Variable `%s` " + "contains uninitialized Tensor.", + Type(), name)); + *data_type = paddle::framework::TransToProtoVarType(t->dtype()); + } + } +} + +void OperatorWithKernel::ParseMultiInputDataType( const std::vector& vars, const std::string& name, proto::VarType::Type* data_type) const { proto::VarType::Type default_data_type = @@ -2015,9 +2045,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; - for (auto& input : ctx.InNameList()) { - const std::vector vars = ctx.MultiInputVar(input); - ParseInputDataType(vars, input, &data_type); + for (auto* name : ctx.InNameList()) { + if (ctx.InputSize(*name) == 1UL) { + ParseInputDataType(ctx.InputVar(*name), *name, &data_type); + } else { + ParseMultiInputDataType(ctx.MultiInputVar(*name), *name, &data_type); + } } PADDLE_ENFORCE_NE( data_type, dafault_data_type, @@ -2031,7 +2064,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; - ParseInputDataType(ctx.MultiInputVar(name), name, &data_type); + if (ctx.InputSize(name) == 1UL) { + ParseInputDataType(ctx.InputVar(name), name, &data_type); + } else { + ParseMultiInputDataType(ctx.MultiInputVar(name), name, &data_type); + } PADDLE_ENFORCE_NE( data_type, dafault_data_type, platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d85e812505..dd21be12f4 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -333,12 +333,12 @@ class ExecutionContext { return it->second; } - virtual std::vector InNameList() const { - std::vector vec_temp; + virtual paddle::SmallVector InNameList() const { + paddle::SmallVector vec_temp; vec_temp.reserve(ctx_.inputs.size()); for (auto& input : ctx_.inputs) { - vec_temp.push_back(input.first); + vec_temp.push_back(&input.first); } return vec_temp; @@ -680,9 +680,11 @@ class OperatorWithKernel : public OperatorBase { // By default all input data must be same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; // used for IndicateDataType - void ParseInputDataType(const std::vector& vars, - const std::string& name, + void ParseInputDataType(const Variable* vars, const std::string& name, proto::VarType::Type* data_type) const; + void ParseMultiInputDataType(const std::vector& vars, + const std::string& name, + proto::VarType::Type* data_type) const; // used for IndicateOrPromoteVarDataTypes Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, const std::string& name) const; diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index fbc47f81fd..330a5a0cfa 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext { return it->second; } - std::vector InNameList() const override { - std::vector vec_temp; + paddle::SmallVector InNameList() const override { + paddle::SmallVector vec_temp; vec_temp.reserve(var_map_in_.size()); for (auto& v : var_map_in_) { - vec_temp.push_back(v.first); + vec_temp.push_back(&v.first); } return vec_temp; @@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext { } size_t InputSize(const std::string& name) const override { - return InputNames(name).size(); + auto it = var_map_in_.find(name); + PADDLE_ENFORCE_NE( + it, var_map_in_.end(), + platform::errors::NotFound("Can not find [%s] in Input", name)); + return it->second.size(); } size_t OutputSize(const std::string& name) const override { - return OutputNames(name).size(); + auto it = var_map_out_.find(name); + PADDLE_ENFORCE_NE( + it, var_map_out_.end(), + platform::errors::NotFound("Can not find [%s] in Output", name)); + return it->second.size(); } const Variable* InputVar(const std::string& name) const override { diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 1a297e7238..a45d32b34b 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); + auto &data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 9e5660d9dc..a06efb573a 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -22,6 +22,7 @@ #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/type_defs.h" #include "paddle/utils/optional.h" #include "paddle/utils/small_vector.h" @@ -139,10 +140,11 @@ class KernelContext { paddle::SmallVector inputs_; paddle::SmallVector outputs_; - paddle::SmallVector attrs_; + paddle::SmallVector attrs_; - paddle::SmallVector> input_range_; - paddle::SmallVector> output_range_; + paddle::SmallVector, kInputSmallVectorSize> input_range_; + paddle::SmallVector, kOutputSmallVectorSize> + output_range_; }; } // namespace phi -- GitLab