未验证 提交 2ea56c90 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Optimize performance of dygraph (#42196) (#42329)

* Optimize performance of dygraph (v4)  (#42196)

* optimize performance of dygraph

* optimize performance of dygraph and elementwise_add

* optimize the trace op

* fix bug

* fix bug

* fix unittest bug

* fix code format

* fix cherry-pick problem
上级 fe4646d1
...@@ -109,8 +109,8 @@ size_t SizeOfType(proto::VarType::Type type) { ...@@ -109,8 +109,8 @@ size_t SizeOfType(proto::VarType::Type type) {
} }
// Now only supports promotion of complex type // Now only supports promotion of complex type
bool NeedPromoteTypes(const proto::VarType::Type a, inline bool NeedPromoteTypes(const proto::VarType::Type& a,
const proto::VarType::Type b) { const proto::VarType::Type& b) {
return (IsComplexType(a) || IsComplexType(b)); return (IsComplexType(a) || IsComplexType(b));
} }
......
...@@ -200,7 +200,7 @@ inline std::ostream& operator<<(std::ostream& out, ...@@ -200,7 +200,7 @@ inline std::ostream& operator<<(std::ostream& out,
return out; return out;
} }
extern inline bool IsComplexType(const proto::VarType::Type type) { extern inline bool IsComplexType(const proto::VarType::Type& type) {
return (type == proto::VarType::COMPLEX64 || return (type == proto::VarType::COMPLEX64 ||
type == proto::VarType::COMPLEX128); type == proto::VarType::COMPLEX128);
} }
......
...@@ -21,13 +21,17 @@ namespace framework { ...@@ -21,13 +21,17 @@ namespace framework {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp( std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs, const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) { const VariableNameMap& outputs, const AttributeMap& attrs,
bool attr_check) {
auto& info = OpInfoMap::Instance().Get(type); auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) { if (attr_check && info.Checker() != nullptr) {
info.Checker()->Check(&attrs); auto tmp_attrs = attrs;
info.Checker()->Check(&tmp_attrs);
return std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, tmp_attrs));
} }
auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr<OperatorBase>(
return std::unique_ptr<OperatorBase>(op); info.Creator()(type, inputs, outputs, attrs));
} }
static VariableNameMap ConvertOpDescVarsToVarNameMap( static VariableNameMap ConvertOpDescVarsToVarNameMap(
......
...@@ -129,7 +129,7 @@ class OpRegistry { ...@@ -129,7 +129,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type, static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VariableNameMap& inputs, const VariableNameMap& inputs,
const VariableNameMap& outputs, const VariableNameMap& outputs,
AttributeMap attrs, const AttributeMap& attrs,
bool attr_check = true); bool attr_check = true);
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
......
...@@ -81,19 +81,21 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) { ...@@ -81,19 +81,21 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
phi::KernelKey TransOpKernelTypeToPhiKernelKey( phi::KernelKey TransOpKernelTypeToPhiKernelKey(
const OpKernelType& kernel_type) { const OpKernelType& kernel_type) {
phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_); phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) { switch (kernel_type.library_type_) {
backend = phi::Backend::MKLDNN; case LibraryType::kCUDNN:
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
backend = phi::Backend::GPUDNN; backend = phi::Backend::GPUDNN;
} else if (kernel_type.library_type_ == LibraryType::kKP) { break;
case LibraryType::kMKLDNN:
backend = phi::Backend::MKLDNN;
break;
case LibraryType::kKP:
backend = phi::Backend::KPS; backend = phi::Backend::KPS;
} else { break;
// do nothing default:
break;
} }
paddle::experimental::DataLayout layout = kernel_type.data_layout_; return phi::KernelKey(backend, kernel_type.data_layout_,
paddle::experimental::DataType dtype = framework::TransToPhiDataType(kernel_type.data_type_));
paddle::framework::TransToPhiDataType(kernel_type.data_type_);
return phi::KernelKey(backend, layout, dtype);
} }
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
......
...@@ -459,7 +459,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -459,7 +459,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op); auto* op_kernel = static_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied( op_kernel, platform::errors::PermissionDenied(
"Only support operator with kernel in Dygraph mode.")); "Only support operator with kernel in Dygraph mode."));
......
...@@ -40,6 +40,13 @@ static const phi::Kernel empty_kernel; ...@@ -40,6 +40,13 @@ static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {}); static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope; static const framework::Scope empty_scope;
const phi::KernelFactory& PreparedOp::phi_kernel_factory =
phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
phi::DefaultKernelSignatureMap::Instance();
const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) { const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar(); return var->SharedVar();
...@@ -139,12 +146,14 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -139,12 +146,14 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
phi_kernel_(phi_kernel) {} phi_kernel_(phi_kernel) {}
template <typename VarType> template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, PreparedOp PrepareImpl(
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op, const platform::Place& place,
const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs,
const phi::KernelFactory& phi_kernel_factory,
const phi::OpUtilsMap& phi_op_utils_map,
const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -184,15 +193,15 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -184,15 +193,15 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
bool has_phi_kernel = false; bool has_phi_kernel = false;
const auto* arg_map_fn = const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) { if (arg_map_fn) {
has_phi_kernel = true; has_phi_kernel = true;
kernel_signature = (*arg_map_fn)( kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else { } else {
default_kernel_signature = default_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type()); default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) { if (default_kernel_signature) {
has_phi_kernel = true; has_phi_kernel = true;
kernel_signature = *default_kernel_signature; kernel_signature = *default_kernel_signature;
...@@ -228,8 +237,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -228,8 +237,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key; << ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key = phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key); TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name, if (!phi_kernel_factory.HasKernel(pt_kernel_name, try_pt_kernel_key)) {
try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type; expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed " VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key; << expected_kernel_key;
...@@ -239,8 +247,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -239,8 +247,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel( auto& phi_kernel =
pt_kernel_name, pt_kernel_key); phi_kernel_factory.SelectKernel(pt_kernel_name, pt_kernel_key);
if (phi_kernel.IsValid() if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
...@@ -295,11 +303,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -295,11 +303,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
|| (is_xpu_unsupport && !is_xpu_kp_support) || (is_xpu_unsupport && !is_xpu_kp_support)
#endif #endif
) { ) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { if (has_phi_kernel) {
auto pt_cpu_kernel_key = auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op); FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel( auto& pt_cpu_kernel =
pt_kernel_name, pt_cpu_kernel_key); phi_kernel_factory.SelectKernel(pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) { if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
...@@ -408,7 +416,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, ...@@ -408,7 +416,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs); return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
phi_kernel_factory, phi_op_utils_map,
default_phi_kernel_sig_map);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
...@@ -417,8 +427,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, ...@@ -417,8 +427,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs, return PrepareImpl<VariableWrapper>(
default_attrs); ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
...@@ -427,8 +438,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins, ...@@ -427,8 +438,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerVariable>(ins, outs, op, place, attrs, return PrepareImpl<egr::EagerVariable>(
default_attrs); ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
} }
template <typename VarType> template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
...@@ -441,7 +453,6 @@ static void PreparedOpRunImpl( ...@@ -441,7 +453,6 @@ static void PreparedOpRunImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph // TODO(zjl): remove scope in dygraph
framework::Scope scope;
{ {
platform::RecordEvent record_event("infer_shape", platform::RecordEvent record_event("infer_shape",
...@@ -458,8 +469,8 @@ static void PreparedOpRunImpl( ...@@ -458,8 +469,8 @@ static void PreparedOpRunImpl(
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs, func(DygraphExecutionContext<VarType>(op, empty_scope, *dev_ctx, ctx, ins,
attrs, default_attrs)); outs, attrs, default_attrs));
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
...@@ -503,7 +514,7 @@ static void PreparedOpRunPtImpl( ...@@ -503,7 +514,7 @@ static void PreparedOpRunPtImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
{ {
platform::RecordEvent record_event(op.Type() + "::infer_shape", platform::RecordEvent record_event("infer_shape",
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx( DygraphInferShapeContext<VarType> infer_shape_ctx(
...@@ -513,7 +524,7 @@ static void PreparedOpRunPtImpl( ...@@ -513,7 +524,7 @@ static void PreparedOpRunPtImpl(
} }
{ {
platform::RecordEvent record_event(op.Type() + "::compute", platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp); 1, platform::EventRole::kInnerOp);
......
...@@ -214,6 +214,10 @@ class PreparedOp { ...@@ -214,6 +214,10 @@ class PreparedOp {
const phi::KernelSignature* default_kernel_signature_; const phi::KernelSignature* default_kernel_signature_;
phi::KernelSignature kernel_signature_; phi::KernelSignature kernel_signature_;
const phi::Kernel& phi_kernel_; const phi::Kernel& phi_kernel_;
static const phi::KernelFactory& phi_kernel_factory;
static const phi::OpUtilsMap& phi_op_utils_map;
static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map;
}; };
const inline framework::Attribute& GetAttr( const inline framework::Attribute& GetAttr(
......
...@@ -192,7 +192,7 @@ void Tracer::TraceOpImpl(const std::string& type, ...@@ -192,7 +192,7 @@ void Tracer::TraceOpImpl(const std::string& type,
paddle::framework::AttributeMap* passed_default_attrs_, paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) { bool use_default_attr_map) {
platform::RecordEvent op_type_record_event( platform::RecordEvent op_type_record_event(
type + " trace_op", platform::TracerEventType::Operator, 1); "trace_op", platform::TracerEventType::Operator, 1);
platform::ScopedFlushDenormal flush; platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type; VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) { if (FLAGS_use_mkldnn) {
......
...@@ -28,23 +28,24 @@ namespace phi { ...@@ -28,23 +28,24 @@ namespace phi {
Backend TransToPhiBackend(const phi::Place& place) { Backend TransToPhiBackend(const phi::Place& place) {
auto allocation_type = place.GetType(); auto allocation_type = place.GetType();
if (allocation_type == phi::AllocationType::CPU) { switch (allocation_type) {
return Backend::CPU; case phi::AllocationType::GPU:
} else if (allocation_type == phi::AllocationType::GPU) {
return Backend::GPU; return Backend::GPU;
} else if (allocation_type == phi::AllocationType::GPUPINNED) { case AllocationType::CPU:
return Backend::CPU;
case AllocationType::GPUPINNED:
return Backend::GPU; return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) { case AllocationType::XPU:
return Backend::XPU; return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) { case AllocationType::NPU:
return Backend::NPU; return Backend::NPU;
} else if (allocation_type == phi::AllocationType::IPU) { case AllocationType::IPU:
return Backend::IPU; return Backend::IPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) { case AllocationType::CUSTOM:
return static_cast<Backend>( return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) + static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType())); GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else { default:
PADDLE_THROW(phi::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place)); "Unsupported transform %s to phi Backend.", place));
} }
......
...@@ -129,7 +129,6 @@ void* DenseTensor::AllocateFrom(Allocator* allocator, ...@@ -129,7 +129,6 @@ void* DenseTensor::AllocateFrom(Allocator* allocator,
template <typename T> template <typename T>
const T* DenseTensor::data() const { const T* DenseTensor::data() const {
check_memory_size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dtype(), dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(), paddle::experimental::CppTypeToDataType<T>::Type(),
...@@ -141,13 +140,13 @@ const T* DenseTensor::data() const { ...@@ -141,13 +140,13 @@ const T* DenseTensor::data() const {
template <typename T> template <typename T>
T* DenseTensor::data() { T* DenseTensor::data() {
check_memory_size(); T* ret = static_cast<T*>(data());
PADDLE_ENFORCE( PADDLE_ENFORCE(
(dtype() == paddle::experimental::CppTypeToDataType<T>::Type()), (dtype() == paddle::experimental::CppTypeToDataType<T>::Type()),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the " "The type of data we are trying to retrieve does not match the "
"type of data currently contained in the container.")); "type of data currently contained in the container."));
return static_cast<T*>(data()); return ret;
} }
void* DenseTensor::data() { void* DenseTensor::data() {
......
...@@ -75,7 +75,7 @@ namespace phi { ...@@ -75,7 +75,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \ "Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \ static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \ "Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \ const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
const tensor_type& arg = ctx->InputAt<tensor_type>(range.first); \ const tensor_type& arg = ctx->InputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \ template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
...@@ -96,7 +96,7 @@ namespace phi { ...@@ -96,7 +96,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \ "Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \ static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \ "Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \ const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \ auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \ template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
...@@ -117,7 +117,7 @@ namespace phi { ...@@ -117,7 +117,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \ "Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \ static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \ "Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \ const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
std::vector<const tensor_type*> arg = std::move( \ std::vector<const tensor_type*> arg = std::move( \
ctx->InputsBetween<tensor_type>(range.first, range.second)); \ ctx->InputsBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
...@@ -141,7 +141,7 @@ namespace phi { ...@@ -141,7 +141,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \ "Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \ static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \ "Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \ const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
paddle::optional<const std::vector<const tensor_type*>> arg = \ paddle::optional<const std::vector<const tensor_type*>> arg = \
ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \ ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
...@@ -195,7 +195,7 @@ namespace phi { ...@@ -195,7 +195,7 @@ namespace phi {
int out_idx, \ int out_idx, \
typename... PreviousArgs> \ typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \ const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first); \ tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \ template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \
...@@ -212,7 +212,7 @@ namespace phi { ...@@ -212,7 +212,7 @@ namespace phi {
int out_idx, \ int out_idx, \
typename... PreviousArgs> \ typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \ const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
std::vector<tensor_type*> arg = std::move( \ std::vector<tensor_type*> arg = std::move( \
ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \ ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \ KernelCallHelper<Tail...>:: \
......
...@@ -554,6 +554,7 @@ void BroadcastKernel(const KPDevice &ctx, ...@@ -554,6 +554,7 @@ void BroadcastKernel(const KPDevice &ctx,
int axis, int axis,
Functor func) { Functor func) {
std::vector<int> dims_size; std::vector<int> dims_size;
dims_size.reserve(ins.size());
bool no_broadcast_flag = true; bool no_broadcast_flag = true;
for (auto *in : ins) { for (auto *in : ins) {
no_broadcast_flag &= ins[0]->dims() == in->dims(); no_broadcast_flag &= ins[0]->dims() == in->dims();
......
...@@ -28,7 +28,9 @@ namespace phi { ...@@ -28,7 +28,9 @@ namespace phi {
int axis, \ int axis, \
DenseTensor* out) { \ DenseTensor* out) { \
std::vector<const DenseTensor*> inputs; \ std::vector<const DenseTensor*> inputs; \
inputs.reserve(2); \
std::vector<DenseTensor*> outputs; \ std::vector<DenseTensor*> outputs; \
outputs.reserve(1); \
inputs.emplace_back(&x); \ inputs.emplace_back(&x); \
inputs.emplace_back(&y); \ inputs.emplace_back(&y); \
outputs.emplace_back(out); \ outputs.emplace_back(out); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册