未验证 提交 bfa0d7f3 编写于 作者: Y YuanRisheng 提交者: GitHub

[Pten]Move func from kernel_context.h into kernel_context.cc (#37804)

* add inplace op adaptation

* optimize inplace logic and fix bugs when run kernel that has args of vector<DenseTensor>

* move func in kernel_context.h into kernel_context.cc

* refactor logic that transform variable to densetensor

* fix bugs when compile

* update func name

* fix bugs when run windows-ci
上级 b3185296
......@@ -14,4 +14,114 @@
#include "paddle/pten/core/kernel_context.h"
namespace pten {} // namespace pten
namespace pten {
void KernelContext::EmplaceBackInput(std::shared_ptr<TensorBase> input) {
int index = inputs_.size();
inputs_.emplace_back(std::move(input));
// Record the start and end index of the input
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void KernelContext::EmplaceBackInputWithoutSetRange(
std::shared_ptr<TensorBase> input) {
inputs_.emplace_back(std::move(input));
}
void KernelContext::EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
int index = inputs_.size();
// Record the start and end index of the input
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
void KernelContext::EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
int index = outputs_.size();
outputs_.emplace_back(std::move(output));
// Record the start and end index of the input
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void KernelContext::EmplaceBackOutputWithoutSetRange(
std::shared_ptr<TensorBase> output) {
outputs_.emplace_back(std::move(output));
}
void KernelContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
int index = outputs_.size();
// Record the start and end index of the input
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
outputs_.insert(outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
void KernelContext::EmplaceBackAttr(paddle::any attr) {
attrs_.emplace_back(std::move(attr));
}
void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < input_range_.size()) {
input_range_[idx] = range;
} else if (idx == input_range_.size()) {
input_range_.emplace_back(range);
} else {
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Invalid idx when trying to set InputRange, "
"index is `%d`, it is greater than the size(%d) of InputRange.",
idx,
input_range_.size()));
}
}
void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < output_range_.size()) {
output_range_[idx] = range;
} else if (idx == output_range_.size()) {
output_range_.emplace_back(range);
} else {
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Invalid idx when trying to set InputRange, "
"index is `%d`, it is greater than the size(%d) of InputRange.",
idx,
output_range_.size()));
}
}
const std::pair<int, int>& KernelContext::InputRangeAt(size_t idx) const {
return input_range_.at(idx);
}
const std::pair<int, int>& KernelContext::OutputRangeAt(size_t idx) const {
return output_range_.at(idx);
}
std::pair<int, int>& KernelContext::MutableInputRangeAt(size_t idx) {
return input_range_[idx];
}
std::pair<int, int>& KernelContext::MutableOutputRangeAt(size_t idx) {
return output_range_[idx];
}
// Temporary method: For compatible with fluid Tensor and improve performance
// Only deal with DenseTensor now
void KernelContext::ClearData() {
for (auto& in : inputs_) {
if (in) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(in.get()));
}
}
for (auto& out : outputs_) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(out.get()));
}
attrs_.clear();
}
} // namespace pten
......@@ -51,53 +51,29 @@ class KernelContext {
return static_cast<const CtxType&>(*dev_ctx_);
}
void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
int index = inputs_.size();
inputs_.emplace_back(std::move(input));
// Record the start and end index of the input
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void EmplaceBackInput(std::shared_ptr<TensorBase> input);
void EmplaceBackInputWithoutSetRange(std::shared_ptr<TensorBase> input) {
inputs_.emplace_back(std::move(input));
}
void EmplaceBackInputWithoutSetRange(std::shared_ptr<TensorBase> input);
void EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> inputs) {
int index = inputs_.size();
// Record the start and end index of the input
input_range_.emplace_back(
std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
paddle::SmallVector<std::shared_ptr<TensorBase>> inputs);
void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
int index = outputs_.size();
outputs_.emplace_back(std::move(output));
// Record the start and end index of the input
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void EmplaceBackOutput(std::shared_ptr<TensorBase> output);
void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output) {
outputs_.emplace_back(std::move(output));
}
void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output);
void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
int index = outputs_.size();
// Record the start and end index of the input
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
outputs_.insert(outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs);
void EmplaceBackAttr(paddle::any attr) {
attrs_.emplace_back(std::move(attr));
}
void EmplaceBackAttr(paddle::any attr);
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
std::pair<int, int>& MutableInputRangeAt(size_t idx);
std::pair<int, int>& MutableOutputRangeAt(size_t idx);
template <typename TensorType>
const TensorType& InputAt(size_t idx) const {
......@@ -119,41 +95,9 @@ class KernelContext {
return v;
}
const std::pair<int, int>& InputRangeAt(size_t idx) const {
return input_range_.at(idx);
}
const std::pair<int, int>& OutputRangeAt(size_t idx) const {
return output_range_.at(idx);
}
void AssignInputRange(std::pair<int, int>&& range, size_t idx);
void AssignInputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < input_range_.size()) {
input_range_[idx] = range;
} else if (idx == input_range_.size()) {
input_range_.emplace_back(range);
} else {
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Invalid idx when trying to set InputRange, "
"index is `%d`, it is greater than the size(%d) of InputRange.",
idx,
input_range_.size()));
}
}
void AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < output_range_.size()) {
output_range_[idx] = range;
} else if (idx == output_range_.size()) {
output_range_.emplace_back(range);
} else {
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Invalid idx when trying to set InputRange, "
"index is `%d`, it is greater than the size(%d) of InputRange.",
idx,
output_range_.size()));
}
}
void AssignOutputRange(std::pair<int, int>&& range, size_t idx);
template <typename TensorType>
TensorType* MutableInputAt(size_t idx) {
......@@ -187,19 +131,7 @@ class KernelContext {
// Temporary method: For compatible with fluid Tensor and improve performance
// Only deal with DenseTensor now
void ClearData() {
for (auto& in : inputs_) {
if (in) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(in.get()));
}
}
for (auto& out : outputs_) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(out.get()));
}
attrs_.clear();
}
void ClearData();
size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); }
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
......
......@@ -18,6 +18,7 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册