diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h
index b6459d9b7069567461dbc5f40f45645039a9fd30..ac1ed668f7bf5abbd3f0a9724a2921bb8a96bb41 100644
--- a/paddle/pten/core/kernel_context.h
+++ b/paddle/pten/core/kernel_context.h
@@ -52,37 +52,37 @@ class KernelContext {
   }
 
   void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
+    int index = inputs_.size();
     inputs_.emplace_back(std::move(input));
     // Record the start and end index of the input
-    int index = inputs_.size();
     input_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
   void EmplaceBackInputs(
-      paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
+      const paddle::SmallVector<std::shared_ptr<TensorBase>>& inputs) {
+    int index = inputs_.size();
     for (auto in : inputs) {
-      inputs_.emplace_back(in);
+      inputs_.emplace_back(std::move(in));
     }
     // Record the start and end index of the input
-    int index = inputs_.size();
     input_range_.emplace_back(
         std::pair<int, int>(index, index + inputs.size()));
   }
 
   void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
+    int index = outputs_.size();
     outputs_.emplace_back(std::move(output));
     // Record the start and end index of the input
-    int index = outputs_.size();
     output_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
   void EmplaceBackOutputs(
-      paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
+      const paddle::SmallVector<std::shared_ptr<TensorBase>>& outputs) {
+    int index = outputs_.size();
     for (auto out : outputs) {
-      outputs_.emplace_back(out);
+      outputs_.emplace_back(std::move(out));
     }
     // Record the start and end index of the input
-    int index = outputs_.size();
     output_range_.emplace_back(
         std::pair<int, int>(index, index + outputs.size()));
   }
@@ -96,11 +96,40 @@ class KernelContext {
     return static_cast<const TensorType&>(*(inputs_.at(idx)));
   }
 
+  template <typename TensorType>
+  std::vector<TensorType> InputBetween(size_t start, size_t end) const {
+    std::vector<TensorType> v;
+    for (size_t i = start; i < end; ++i) {
+      auto t = std::dynamic_pointer_cast<TensorType>(inputs_.at(i));
+      v.emplace_back(std::move(*t.get()));
+    }
+
+    return v;
+  }
+
+  const std::pair<int, int>& InputRangeAt(size_t idx) const {
+    return input_range_.at(idx);
+  }
+
+  const std::pair<int, int>& OutputRangeAt(size_t idx) const {
+    return output_range_.at(idx);
+  }
+
   template <typename TensorType>
   TensorType* MutableOutputAt(size_t idx) {
     return static_cast<TensorType*>(outputs_.at(idx).get());
   }
 
+  template <typename TensorType>
+  std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
+    std::vector<TensorType*> v;
+    for (size_t i = start; i < end; ++i) {
+      v.emplace_back(static_cast<TensorType*>(outputs_.at(i).get()));
+    }
+
+    return v;
+  }
+
   template <typename AttrType>
   AttrType AttrAt(size_t idx) const {
     try {
diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h
index d3422d173a3db257b05b7b101786f7c5394dd7f0..c2b97148aa5fb1941d5bac0a9e366c70bb6f1149 100644
--- a/paddle/pten/core/kernel_registry.h
+++ b/paddle/pten/core/kernel_registry.h
@@ -62,9 +62,17 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
       } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
         args_def->AppendInput(
             default_key.backend(), default_tensor_layout, default_key.dtype());
+      } else if (arg_type ==
+                 std::type_index(typeid(const std::vector<DenseTensor>&))) {
+        args_def->AppendInput(
+            default_key.backend(), default_tensor_layout, default_key.dtype());
       } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
         args_def->AppendOutput(
             default_key.backend(), default_tensor_layout, default_key.dtype());
+      } else if (arg_type ==
+                 std::type_index(typeid(std::vector<DenseTensor*>))) {
+        args_def->AppendOutput(
+            default_key.backend(), default_tensor_layout, default_key.dtype());
       } else {
         // Attribute deal with
         // TODO(chenweihang): now here allow any types of attribute, maybe
diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h
index c45a81206323e96cab2d04e2df5f639681a0ab96..450202607648dbe8dd59846c3a9abc40ff38ce03 100644
--- a/paddle/pten/core/kernel_utils.h
+++ b/paddle/pten/core/kernel_utils.h
@@ -79,7 +79,30 @@ using XPUContext = paddle::platform::XPUDeviceContext;
                     "Kernel's Input should appear before Attributes.");    \
       static_assert(out_idx == 0,                                          \
                     "Kernel's Input should appear before Outputs.");       \
-      const tensor_type& arg = ctx->InputAt<tensor_type>(in_idx);          \
+      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);         \
+      const tensor_type& arg = ctx->InputAt<tensor_type>(range.first);     \
+      KernelCallHelper<Tail...>::                                          \
+          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(    \
+              ctx, pargs..., arg);                                         \
+    }                                                                      \
+  }
+
+#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)        \
+  template <typename... Tail>                                              \
+  struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> {      \
+    template <int dev_ctx_idx,                                             \
+              int in_idx,                                                  \
+              int attr_idx,                                                \
+              int out_idx,                                                 \
+              typename... PreviousArgs>                                    \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
+      static_assert(attr_idx == 0,                                         \
+                    "Kernel's Input should appear before Attributes.");    \
+      static_assert(out_idx == 0,                                          \
+                    "Kernel's Input should appear before Outputs.");       \
+      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);         \
+      std::vector<tensor_type> arg = std::move(                            \
+          ctx->InputBetween<tensor_type>(range.first, range.second));      \
       KernelCallHelper<Tail...>::                                          \
           template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(    \
               ctx, pargs..., arg);                                         \
@@ -104,20 +127,39 @@ using XPUContext = paddle::platform::XPUDeviceContext;
     }                                                                      \
   }
 
-#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)           \
-  template <typename... Tail>                                            \
-  struct KernelCallHelper<tensor_type*, Tail...> {                       \
-    template <int dev_ctx_idx,                                           \
-              int in_idx,                                                \
-              int attr_idx,                                              \
-              int out_idx,                                               \
-              typename... PreviousArgs>                                  \
-    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {    \
-      tensor_type* arg = ctx->MutableOutputAt<tensor_type>(out_idx);     \
-      KernelCallHelper<Tail...>::                                        \
-          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(  \
-              ctx, pargs..., arg);                                       \
-    }                                                                    \
+#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)             \
+  template <typename... Tail>                                              \
+  struct KernelCallHelper<tensor_type*, Tail...> {                         \
+    template <int dev_ctx_idx,                                             \
+              int in_idx,                                                  \
+              int attr_idx,                                                \
+              int out_idx,                                                 \
+              typename... PreviousArgs>                                    \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
+      const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);       \
+      tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first);   \
+      KernelCallHelper<Tail...>::                                          \
+          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(    \
+              ctx, pargs..., arg);                                         \
+    }                                                                      \
+  }
+
+#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type)       \
+  template <typename... Tail>                                              \
+  struct KernelCallHelper<std::vector<tensor_type*>, Tail...> {            \
+    template <int dev_ctx_idx,                                             \
+              int in_idx,                                                  \
+              int attr_idx,                                                \
+              int out_idx,                                                 \
+              typename... PreviousArgs>                                    \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
+      const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);       \
+      std::vector<tensor_type*> arg = std::move(                           \
+          ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
+      KernelCallHelper<Tail...>::                                          \
+          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(    \
+              ctx, pargs..., arg);                                         \
+    }                                                                      \
   }
 
 template <typename T>
@@ -152,6 +194,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   /* Input Helpers */
 
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
   // TODO(chenweihang): adapt SelectedRows
   // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
 
@@ -168,6 +211,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   /* Output Helpers */
 
   PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
   // TODO(chenweihang): adapt SelectedRows
   // PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);
 
diff --git a/paddle/pten/hapi/lib/kernel_dispatch.h b/paddle/pten/hapi/lib/kernel_dispatch.h
index d7190076bf3f68bd2bc3bd7ab4b9ec3b4762b934..f61f3297d6d6c872b244a27319c6de7b16cabff4 100644
--- a/paddle/pten/hapi/lib/kernel_dispatch.h
+++ b/paddle/pten/hapi/lib/kernel_dispatch.h
@@ -122,6 +122,14 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
     key_set.dtype = x.type();
   }
 
+  void operator()(const std::vector<Tensor>& x) {
+    key_set.backend_set =
+        key_set.backend_set | detail::GetTensorBackendSet(x[0]);
+    // TODO(chenweihang): selecte multi layout and dtype
+    key_set.layout = x[0].layout();
+    key_set.dtype = x[0].type();
+  }
+
   // skip other type args, these args don't used in kernel selection
   template <typename T>
   void operator()(const T& x) {
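
Note: the core idea of this diff is that every kernel argument now records a (start, end) interval over the flat tensor list held by KernelContext, so one argument can map to several tensors (a std::vector<DenseTensor> input or std::vector<DenseTensor*> output) without changing how single tensors are stored. The following is a minimal, self-contained C++ sketch of that bookkeeping only; it is not Paddle code, and MiniKernelContext, Tensor, and the method names merely mirror the API above for illustration.

// Standalone illustration of the per-argument (start, end) range bookkeeping.
// All names here are hypothetical stand-ins, not the pten API.
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Tensor {  // stand-in for pten::DenseTensor
  std::string name;
};

class MiniKernelContext {
 public:
  // A single-tensor argument occupies the range [index, index + 1).
  void EmplaceBackInput(std::shared_ptr<Tensor> input) {
    int index = static_cast<int>(inputs_.size());
    inputs_.emplace_back(std::move(input));
    input_range_.emplace_back(index, index + 1);
  }

  // A multi-tensor argument occupies [index, index + inputs.size()).
  void EmplaceBackInputs(const std::vector<std::shared_ptr<Tensor>>& inputs) {
    int index = static_cast<int>(inputs_.size());
    for (auto& in : inputs) {
      inputs_.emplace_back(in);
    }
    input_range_.emplace_back(index, index + static_cast<int>(inputs.size()));
  }

  const std::pair<int, int>& InputRangeAt(size_t arg_idx) const {
    return input_range_.at(arg_idx);
  }

  // Materialize the tensors of one argument as a vector, like InputBetween.
  std::vector<Tensor> InputBetween(int start, int end) const {
    std::vector<Tensor> v;
    for (int i = start; i < end; ++i) {
      v.push_back(*inputs_.at(i));
    }
    return v;
  }

 private:
  std::vector<std::shared_ptr<Tensor>> inputs_;
  std::vector<std::pair<int, int>> input_range_;  // one entry per argument
};

int main() {
  MiniKernelContext ctx;
  ctx.EmplaceBackInput(std::make_shared<Tensor>(Tensor{"x"}));    // argument 0
  ctx.EmplaceBackInputs({std::make_shared<Tensor>(Tensor{"y0"}),  // argument 1
                         std::make_shared<Tensor>(Tensor{"y1"}),
                         std::make_shared<Tensor>(Tensor{"y2"})});

  auto range = ctx.InputRangeAt(1);  // the three-tensor argument covers [1, 4)
  assert(range.first == 1 && range.second == 4);
  for (const auto& t : ctx.InputBetween(range.first, range.second)) {
    std::cout << t.name << "\n";  // prints y0, y1, y2
  }
  return 0;
}

In this sketch argument 0 covers [0, 1) and the vector argument covers [1, 4), which is the same contract InputRangeAt/InputBetween (and OutputRangeAt/MutableOutputBetween) provide so the KernelCallHelper specializations can reassemble a std::vector argument while the old single-tensor path keeps using range.first.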