未验证 提交 3a0d7bf0 编写于 作者: C Chen Weihang 提交者: GitHub

Optimize dygraph GetExpectedKernelType perf (#42154)

* opt dygraph scheduling

* revert part impl
上级 13190707
...@@ -940,7 +940,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -940,7 +940,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
return ((op_with_kernel.kernel_type()) && return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ == (op_with_kernel.kernel_type()->data_layout_ ==
framework::DataLayout::kMKLDNN)); framework::DataLayout::kMKLDNN));
} catch (std::bad_cast exp) { } catch (const std::bad_cast& exp) {
return false; return false;
} }
} }
...@@ -1965,6 +1965,36 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1965,6 +1965,36 @@ Scope* OperatorWithKernel::PrepareData(
} }
void OperatorWithKernel::ParseInputDataType( void OperatorWithKernel::ParseInputDataType(
const Variable* var, const std::string& name,
proto::VarType::Type* data_type) const {
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
if (t_arr->at(j).IsInitialized()) {
t = &(t_arr->at(j));
}
}
}
if (t != nullptr) {
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(), name));
*data_type = paddle::framework::TransToProtoVarType(t->dtype());
}
}
}
void OperatorWithKernel::ParseMultiInputDataType(
const std::vector<Variable*>& vars, const std::string& name, const std::vector<Variable*>& vars, const std::string& name,
proto::VarType::Type* data_type) const { proto::VarType::Type* data_type) const {
proto::VarType::Type default_data_type = proto::VarType::Type default_data_type =
...@@ -2015,9 +2045,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -2015,9 +2045,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type dafault_data_type = proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
for (auto& input : ctx.InNameList()) { for (auto* name : ctx.InNameList()) {
const std::vector<Variable*> vars = ctx.MultiInputVar(input); if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(vars, input, &data_type); ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
} else {
ParseMultiInputDataType(ctx.MultiInputVar(*name), *name, &data_type);
}
} }
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
data_type, dafault_data_type, data_type, dafault_data_type,
...@@ -2031,7 +2064,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( ...@@ -2031,7 +2064,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto::VarType::Type dafault_data_type = proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
ParseInputDataType(ctx.MultiInputVar(name), name, &data_type); if (ctx.InputSize(name) == 1UL) {
ParseInputDataType(ctx.InputVar(name), name, &data_type);
} else {
ParseMultiInputDataType(ctx.MultiInputVar(name), name, &data_type);
}
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
data_type, dafault_data_type, data_type, dafault_data_type,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -333,12 +333,12 @@ class ExecutionContext { ...@@ -333,12 +333,12 @@ class ExecutionContext {
return it->second; return it->second;
} }
virtual std::vector<std::string> InNameList() const { virtual paddle::SmallVector<const std::string*> InNameList() const {
std::vector<std::string> vec_temp; paddle::SmallVector<const std::string*> vec_temp;
vec_temp.reserve(ctx_.inputs.size()); vec_temp.reserve(ctx_.inputs.size());
for (auto& input : ctx_.inputs) { for (auto& input : ctx_.inputs) {
vec_temp.push_back(input.first); vec_temp.push_back(&input.first);
} }
return vec_temp; return vec_temp;
...@@ -680,9 +680,11 @@ class OperatorWithKernel : public OperatorBase { ...@@ -680,9 +680,11 @@ class OperatorWithKernel : public OperatorBase {
// By default all input data must be same. // By default all input data must be same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
// used for IndicateDataType // used for IndicateDataType
void ParseInputDataType(const std::vector<Variable*>& vars, void ParseInputDataType(const Variable* vars, const std::string& name,
const std::string& name,
proto::VarType::Type* data_type) const; proto::VarType::Type* data_type) const;
void ParseMultiInputDataType(const std::vector<Variable*>& vars,
const std::string& name,
proto::VarType::Type* data_type) const;
// used for IndicateOrPromoteVarDataTypes // used for IndicateOrPromoteVarDataTypes
Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const; const std::string& name) const;
......
...@@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext {
return it->second; return it->second;
} }
std::vector<std::string> InNameList() const override { paddle::SmallVector<const std::string*> InNameList() const override {
std::vector<std::string> vec_temp; paddle::SmallVector<const std::string*> vec_temp;
vec_temp.reserve(var_map_in_.size()); vec_temp.reserve(var_map_in_.size());
for (auto& v : var_map_in_) { for (auto& v : var_map_in_) {
vec_temp.push_back(v.first); vec_temp.push_back(&v.first);
} }
return vec_temp; return vec_temp;
...@@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
return InputNames(name).size(); auto it = var_map_in_.find(name);
PADDLE_ENFORCE_NE(
it, var_map_in_.end(),
platform::errors::NotFound("Can not find [%s] in Input", name));
return it->second.size();
} }
size_t OutputSize(const std::string& name) const override { size_t OutputSize(const std::string& name) const override {
return OutputNames(name).size(); auto it = var_map_out_.find(name);
PADDLE_ENFORCE_NE(
it, var_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name));
return it->second.size();
} }
const Variable* InputVar(const std::string& name) const override { const Variable* InputVar(const std::string& name) const override {
......
...@@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
...@@ -139,10 +140,11 @@ class KernelContext { ...@@ -139,10 +140,11 @@ class KernelContext {
paddle::SmallVector<const TensorBase*> inputs_; paddle::SmallVector<const TensorBase*> inputs_;
paddle::SmallVector<TensorBase*> outputs_; paddle::SmallVector<TensorBase*> outputs_;
paddle::SmallVector<Attribute> attrs_; paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<std::pair<int, int>> input_range_; paddle::SmallVector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
paddle::SmallVector<std::pair<int, int>> output_range_; paddle::SmallVector<std::pair<int, int>, kOutputSmallVectorSize>
output_range_;
}; };
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册