未验证 提交 2ea56c90 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Optimize performance of dygraph (#42196) (#42329)

* Optimize performance of dygraph (v4)  (#42196)

* optimize performance of dygraph

* optimize performance of dygraph and elementwise_add

* optimize the trace op

* fix bug

* fix bug

* fix unittest bug

* fix code format

* fix cherry-pick problem
上级 fe4646d1
......@@ -109,8 +109,8 @@ size_t SizeOfType(proto::VarType::Type type) {
}
// Now only supports promotion of complex type
bool NeedPromoteTypes(const proto::VarType::Type a,
const proto::VarType::Type b) {
inline bool NeedPromoteTypes(const proto::VarType::Type& a,
const proto::VarType::Type& b) {
return (IsComplexType(a) || IsComplexType(b));
}
......
......@@ -200,7 +200,7 @@ inline std::ostream& operator<<(std::ostream& out,
return out;
}
extern inline bool IsComplexType(const proto::VarType::Type type) {
extern inline bool IsComplexType(const proto::VarType::Type& type) {
return (type == proto::VarType::COMPLEX64 ||
type == proto::VarType::COMPLEX128);
}
......
......@@ -21,13 +21,17 @@ namespace framework {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
const VariableNameMap& outputs, const AttributeMap& attrs,
bool attr_check) {
auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) {
info.Checker()->Check(&attrs);
auto tmp_attrs = attrs;
info.Checker()->Check(&tmp_attrs);
return std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, tmp_attrs));
}
auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op);
return std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, attrs));
}
static VariableNameMap ConvertOpDescVarsToVarNameMap(
......
......@@ -129,7 +129,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
AttributeMap attrs,
const AttributeMap& attrs,
bool attr_check = true);
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
......
......@@ -81,19 +81,21 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
phi::KernelKey TransOpKernelTypeToPhiKernelKey(
const OpKernelType& kernel_type) {
phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
backend = phi::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
switch (kernel_type.library_type_) {
case LibraryType::kCUDNN:
backend = phi::Backend::GPUDNN;
} else if (kernel_type.library_type_ == LibraryType::kKP) {
break;
case LibraryType::kMKLDNN:
backend = phi::Backend::MKLDNN;
break;
case LibraryType::kKP:
backend = phi::Backend::KPS;
} else {
// do nothing
break;
default:
break;
}
paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype =
paddle::framework::TransToPhiDataType(kernel_type.data_type_);
return phi::KernelKey(backend, layout, dtype);
return phi::KernelKey(backend, kernel_type.data_layout_,
framework::TransToPhiDataType(kernel_type.data_type_));
}
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
......
......@@ -459,7 +459,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
auto* op_kernel = static_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied(
"Only support operator with kernel in Dygraph mode."));
......
......@@ -40,6 +40,13 @@ static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;
const phi::KernelFactory& PreparedOp::phi_kernel_factory =
phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
phi::DefaultKernelSignatureMap::Instance();
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
......@@ -139,12 +146,14 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
phi_kernel_(phi_kernel) {}
template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
PreparedOp PrepareImpl(
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
const framework::AttributeMap& default_attrs,
const phi::KernelFactory& phi_kernel_factory,
const phi::OpUtilsMap& phi_op_utils_map,
const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -184,15 +193,15 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
bool has_phi_kernel = false;
const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
default_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
......@@ -228,8 +237,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
if (!phi_kernel_factory.HasKernel(pt_kernel_name, try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
......@@ -239,8 +247,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
auto& phi_kernel =
phi_kernel_factory.SelectKernel(pt_kernel_name, pt_kernel_key);
if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
......@@ -295,11 +303,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
if (has_phi_kernel) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
auto& pt_cpu_kernel =
phi_kernel_factory.SelectKernel(pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
......@@ -408,7 +416,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
phi_kernel_factory, phi_op_utils_map,
default_phi_kernel_sig_map);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
......@@ -417,8 +427,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
default_attrs);
return PrepareImpl<VariableWrapper>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
......@@ -427,8 +438,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerVariable>(ins, outs, op, place, attrs,
default_attrs);
return PrepareImpl<egr::EagerVariable>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
}
template <typename VarType>
static void PreparedOpRunImpl(
......@@ -441,7 +453,6 @@ static void PreparedOpRunImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
{
platform::RecordEvent record_event("infer_shape",
......@@ -458,8 +469,8 @@ static void PreparedOpRunImpl(
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs, default_attrs));
func(DygraphExecutionContext<VarType>(op, empty_scope, *dev_ctx, ctx, ins,
outs, attrs, default_attrs));
}
if (FLAGS_check_nan_inf) {
......@@ -503,7 +514,7 @@ static void PreparedOpRunPtImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
{
platform::RecordEvent record_event(op.Type() + "::infer_shape",
platform::RecordEvent record_event("infer_shape",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
......@@ -513,7 +524,7 @@ static void PreparedOpRunPtImpl(
}
{
platform::RecordEvent record_event(op.Type() + "::compute",
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
......
......@@ -214,6 +214,10 @@ class PreparedOp {
const phi::KernelSignature* default_kernel_signature_;
phi::KernelSignature kernel_signature_;
const phi::Kernel& phi_kernel_;
static const phi::KernelFactory& phi_kernel_factory;
static const phi::OpUtilsMap& phi_op_utils_map;
static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map;
};
const inline framework::Attribute& GetAttr(
......
......@@ -192,7 +192,7 @@ void Tracer::TraceOpImpl(const std::string& type,
paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) {
platform::RecordEvent op_type_record_event(
type + " trace_op", platform::TracerEventType::Operator, 1);
"trace_op", platform::TracerEventType::Operator, 1);
platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) {
......
......@@ -28,23 +28,24 @@ namespace phi {
Backend TransToPhiBackend(const phi::Place& place) {
auto allocation_type = place.GetType();
if (allocation_type == phi::AllocationType::CPU) {
return Backend::CPU;
} else if (allocation_type == phi::AllocationType::GPU) {
switch (allocation_type) {
case phi::AllocationType::GPU:
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::GPUPINNED) {
case AllocationType::CPU:
return Backend::CPU;
case AllocationType::GPUPINNED:
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) {
case AllocationType::XPU:
return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) {
case AllocationType::NPU:
return Backend::NPU;
} else if (allocation_type == phi::AllocationType::IPU) {
case AllocationType::IPU:
return Backend::IPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) {
case AllocationType::CUSTOM:
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else {
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
}
......
......@@ -129,7 +129,6 @@ void* DenseTensor::AllocateFrom(Allocator* allocator,
template <typename T>
const T* DenseTensor::data() const {
check_memory_size();
PADDLE_ENFORCE_EQ(
dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(),
......@@ -141,13 +140,13 @@ const T* DenseTensor::data() const {
template <typename T>
T* DenseTensor::data() {
check_memory_size();
T* ret = static_cast<T*>(data());
PADDLE_ENFORCE(
(dtype() == paddle::experimental::CppTypeToDataType<T>::Type()),
phi::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the "
"type of data currently contained in the container."));
return static_cast<T*>(data());
return ret;
}
void* DenseTensor::data() {
......
......@@ -75,7 +75,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
const tensor_type& arg = ctx->InputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
......@@ -96,7 +96,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
......@@ -117,7 +117,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
std::vector<const tensor_type*> arg = std::move( \
ctx->InputsBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \
......@@ -141,7 +141,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
paddle::optional<const std::vector<const tensor_type*>> arg = \
ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
KernelCallHelper<Tail...>:: \
......@@ -195,7 +195,7 @@ namespace phi {
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \
const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \
......@@ -212,7 +212,7 @@ namespace phi {
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \
const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
std::vector<tensor_type*> arg = std::move( \
ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \
......
......@@ -554,6 +554,7 @@ void BroadcastKernel(const KPDevice &ctx,
int axis,
Functor func) {
std::vector<int> dims_size;
dims_size.reserve(ins.size());
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag &= ins[0]->dims() == in->dims();
......
......@@ -28,7 +28,9 @@ namespace phi {
int axis, \
DenseTensor* out) { \
std::vector<const DenseTensor*> inputs; \
inputs.reserve(2); \
std::vector<DenseTensor*> outputs; \
outputs.reserve(1); \
inputs.emplace_back(&x); \
inputs.emplace_back(&y); \
outputs.emplace_back(out); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册