diff --git a/paddle/pten/core/kernel_context.cc b/paddle/pten/core/kernel_context.cc index 443990c07247dc8796f6c845b55ef5699f33de4f..b2c84807951a52cb71adb0ef2f98a158a541ac84 100644 --- a/paddle/pten/core/kernel_context.cc +++ b/paddle/pten/core/kernel_context.cc @@ -14,4 +14,114 @@ #include "paddle/pten/core/kernel_context.h" -namespace pten {} // namespace pten +namespace pten { + +void KernelContext::EmplaceBackInput(std::shared_ptr input) { + int index = inputs_.size(); + inputs_.emplace_back(std::move(input)); + // Record the start and end index of the input + input_range_.emplace_back(std::pair(index, index + 1)); +} + +void KernelContext::EmplaceBackInputWithoutSetRange( + std::shared_ptr input) { + inputs_.emplace_back(std::move(input)); +} + +void KernelContext::EmplaceBackInputs( + paddle::SmallVector> inputs) { + int index = inputs_.size(); + // Record the start and end index of the input + input_range_.emplace_back(std::pair(index, index + inputs.size())); + inputs_.insert(inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} + +void KernelContext::EmplaceBackOutput(std::shared_ptr output) { + int index = outputs_.size(); + outputs_.emplace_back(std::move(output)); + // Record the start and end index of the input + output_range_.emplace_back(std::pair(index, index + 1)); +} + +void KernelContext::EmplaceBackOutputWithoutSetRange( + std::shared_ptr output) { + outputs_.emplace_back(std::move(output)); +} + +void KernelContext::EmplaceBackOutputs( + paddle::SmallVector> outputs) { + int index = outputs_.size(); + // Record the start and end index of the input + output_range_.emplace_back( + std::pair(index, index + outputs.size())); + outputs_.insert(outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + +void KernelContext::EmplaceBackAttr(paddle::any attr) { + attrs_.emplace_back(std::move(attr)); +} + +void KernelContext::AssignInputRange(std::pair&& range, size_t idx) { + if (idx < input_range_.size()) { + input_range_[idx] = range; + } else if (idx == input_range_.size()) { + input_range_.emplace_back(range); + } else { + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Invalid idx when trying to set InputRange, " + "index is `%d`, it is greater than the size(%d) of InputRange.", + idx, + input_range_.size())); + } +} + +void KernelContext::AssignOutputRange(std::pair&& range, size_t idx) { + if (idx < output_range_.size()) { + output_range_[idx] = range; + } else if (idx == output_range_.size()) { + output_range_.emplace_back(range); + } else { + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Invalid idx when trying to set InputRange, " + "index is `%d`, it is greater than the size(%d) of InputRange.", + idx, + output_range_.size())); + } +} + +const std::pair& KernelContext::InputRangeAt(size_t idx) const { + return input_range_.at(idx); +} + +const std::pair& KernelContext::OutputRangeAt(size_t idx) const { + return output_range_.at(idx); +} + +std::pair& KernelContext::MutableInputRangeAt(size_t idx) { + return input_range_[idx]; +} + +std::pair& KernelContext::MutableOutputRangeAt(size_t idx) { + return output_range_[idx]; +} + +// Temporary method: For compatible with fluid Tensor and improve performance +// Only deal with DenseTensor now +void KernelContext::ClearData() { + for (auto& in : inputs_) { + if (in) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(in.get())); + } + } + for (auto& out : outputs_) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(out.get())); + } + attrs_.clear(); +} +} // namespace pten diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h index 8a87a5b735e99eb13114e9bd60777aff7e18ac7f..6c695987096cb7635aac2bdc46ddf229fedce19c 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -51,53 +51,29 @@ class KernelContext { return static_cast(*dev_ctx_); } - void EmplaceBackInput(std::shared_ptr input) { - int index = inputs_.size(); - inputs_.emplace_back(std::move(input)); - // Record the start and end index of the input - input_range_.emplace_back(std::pair(index, index + 1)); - } + void EmplaceBackInput(std::shared_ptr input); - void EmplaceBackInputWithoutSetRange(std::shared_ptr input) { - inputs_.emplace_back(std::move(input)); - } + void EmplaceBackInputWithoutSetRange(std::shared_ptr input); void EmplaceBackInputs( - paddle::SmallVector> inputs) { - int index = inputs_.size(); - // Record the start and end index of the input - input_range_.emplace_back( - std::pair(index, index + inputs.size())); - inputs_.insert(inputs_.end(), - std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); - } + paddle::SmallVector> inputs); - void EmplaceBackOutput(std::shared_ptr output) { - int index = outputs_.size(); - outputs_.emplace_back(std::move(output)); - // Record the start and end index of the input - output_range_.emplace_back(std::pair(index, index + 1)); - } + void EmplaceBackOutput(std::shared_ptr output); - void EmplaceBackOutputWithoutSetRange(std::shared_ptr output) { - outputs_.emplace_back(std::move(output)); - } + void EmplaceBackOutputWithoutSetRange(std::shared_ptr output); void EmplaceBackOutputs( - paddle::SmallVector> outputs) { - int index = outputs_.size(); - // Record the start and end index of the input - output_range_.emplace_back( - std::pair(index, index + outputs.size())); - outputs_.insert(outputs_.end(), - std::make_move_iterator(outputs.begin()), - std::make_move_iterator(outputs.end())); - } + paddle::SmallVector> outputs); - void EmplaceBackAttr(paddle::any attr) { - attrs_.emplace_back(std::move(attr)); - } + void EmplaceBackAttr(paddle::any attr); + + const std::pair& InputRangeAt(size_t idx) const; + + const std::pair& OutputRangeAt(size_t idx) const; + + std::pair& MutableInputRangeAt(size_t idx); + + std::pair& MutableOutputRangeAt(size_t idx); template const TensorType& InputAt(size_t idx) const { @@ -119,41 +95,9 @@ class KernelContext { return v; } - const std::pair& InputRangeAt(size_t idx) const { - return input_range_.at(idx); - } - - const std::pair& OutputRangeAt(size_t idx) const { - return output_range_.at(idx); - } + void AssignInputRange(std::pair&& range, size_t idx); - void AssignInputRange(std::pair&& range, size_t idx) { - if (idx < input_range_.size()) { - input_range_[idx] = range; - } else if (idx == input_range_.size()) { - input_range_.emplace_back(range); - } else { - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( - "Invalid idx when trying to set InputRange, " - "index is `%d`, it is greater than the size(%d) of InputRange.", - idx, - input_range_.size())); - } - } - - void AssignOutputRange(std::pair&& range, size_t idx) { - if (idx < output_range_.size()) { - output_range_[idx] = range; - } else if (idx == output_range_.size()) { - output_range_.emplace_back(range); - } else { - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( - "Invalid idx when trying to set InputRange, " - "index is `%d`, it is greater than the size(%d) of InputRange.", - idx, - output_range_.size())); - } - } + void AssignOutputRange(std::pair&& range, size_t idx); template TensorType* MutableInputAt(size_t idx) { @@ -187,19 +131,7 @@ class KernelContext { // Temporary method: For compatible with fluid Tensor and improve performance // Only deal with DenseTensor now - void ClearData() { - for (auto& in : inputs_) { - if (in) { - CompatibleDenseTensorUtils::ClearStorage( - static_cast(in.get())); - } - } - for (auto& out : outputs_) { - CompatibleDenseTensorUtils::ClearStorage( - static_cast(out.get())); - } - attrs_.clear(); - } + void ClearData(); size_t InputsSize() const { return inputs_.size(); } size_t OutputsSize() const { return outputs_.size(); } diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index 3dce249c54532c4364c47ccc73eb63f65cefbd32..36f9aaa85aa5e3b9b8cc62f843d046d0ee3824e8 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index bb724beb2e34b95c4e62832489d4b0e387a2a842..c0f2d8a11414e6665ae7e1e74a78d1e106603e94 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -18,6 +18,7 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h"