Unverified · Commit 35d2b71a · authored by Chen Weihang · committed by GitHub

[PTen] Remove cached kernel context (#38953)

* remove cached kernel context

* revert dataloader format change
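The change applies one mechanical pattern across the static-graph interpreter, the operator runtime, and the dynamic-graph tracer: instead of reusing a KernelContext cached on the operator (a member that had to be cleared with ClearData() after every launch and that was not thread-safe, as the removed DataLoader workaround below shows), every run now builds a short-lived KernelContext on the stack. A minimal sketch of the before/after, using simplified stand-in types (KernelContext, Kernel, and the two Op structs here are illustrative placeholders, not the real Paddle classes):

#include <memory>

struct KernelContext {};                  // stand-in for pten::KernelContext
using Kernel = void (*)(KernelContext*);  // stand-in for pten::Kernel

// Before: one mutable KernelContext cached across runs. It must be cleared
// after every launch, and two threads running the same op race on it.
struct OpWithCachedContext {
  Kernel kernel;
  mutable std::unique_ptr<KernelContext> cached_ctx;
  void Run() const {
    if (!cached_ctx) cached_ctx.reset(new KernelContext());
    // ... fill cached_ctx with inputs/outputs/attrs ...
    kernel(cached_ctx.get());
    // A ClearData() call had to follow here, or tensors leaked into the
    // next run and kept variable buffers alive.
  }
};

// After: a stack-local KernelContext per run. Nothing is shared between
// runs or threads, and cleanup happens automatically when the context goes
// out of scope, so no ClearData() call is needed.
struct OpWithLocalContext {
  Kernel kernel;
  void Run() const {
    KernelContext ctx;  // lives only for this launch
    // ... fill ctx with inputs/outputs/attrs ...
    kernel(&ctx);
  }
};

int main() {
  OpWithLocalContext op{[](KernelContext*) { /* kernel body */ }};
  op.Run();
}

The accepted trade-off is rebuilding the context on every run; per the removed comments, the cache existed only as a temporary compatibility-phase optimization.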
Parent 1053b1d5
@@ -418,15 +418,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
     VLOG(4) << "Run pten kernel: " << op->Type();
     VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
             << &instr_node.DeviceContext();
+    pten::KernelContext pt_kernel_context;
     op_with_kernel->BuildPtenKernelContext(
         *instr_node.InnerRuntimeContext().get(),
-        const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
+        const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
+        &pt_kernel_context);
 
-    (*instr_node.PtenKernel())(instr_node.PtenKernelContext());
+    (*instr_node.PtenKernel())(&pt_kernel_context);
 
     op_with_kernel->WriteBackToOutputs(
-        instr_node.InnerRuntimeContext().get());
-    instr_node.PtenKernelContext()->ClearData();
+        instr_node.InnerRuntimeContext().get(), &pt_kernel_context);
   } else {
     instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
   }
@@ -425,13 +425,14 @@ void build_op_func_list(const platform::Place& place,
     }
     if (run_pten_kernel) {
-      op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
+      pten::KernelContext pt_kernel_context;
+      op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
+                                             &pt_kernel_context);
       op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
-      op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
-      (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
-      op_with_kernel->WriteBackToOutputs(&runtime_context);
-      op_func_node.pt_kernel_context_->ClearData();
+      (*op_func_node.pt_kernel_)(&pt_kernel_context);
+      op_with_kernel->WriteBackToOutputs(&runtime_context,
+                                         &pt_kernel_context);
     } else {
       op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
       op_func_node.kernel_func_(exec_ctx);
@@ -688,10 +688,6 @@ pten::Kernel* Instruction::PtenKernel() const {
   return op_func_node_.pt_kernel_;
 }
 
-pten::KernelContext* Instruction::PtenKernelContext() const {
-  return op_func_node_.pt_kernel_context_;
-}
-
 OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
 
 OperatorBase* Instruction::OpBase() const {
@@ -299,8 +299,7 @@ struct OpFuncNode {
   platform::DeviceContext* dev_ctx_;  // not owned
 
   // fit for pten kernel
-  pten::Kernel* pt_kernel_{nullptr};                  // not owned
-  pten::KernelContext* pt_kernel_context_{nullptr};  // not owned
+  pten::Kernel* pt_kernel_{nullptr};  // not owned
 
   OpFuncType type_;
 };
@@ -322,8 +321,6 @@ class Instruction {
   pten::Kernel* PtenKernel() const;
 
-  pten::KernelContext* PtenKernelContext() const;
-
   OpFuncType KernelType() const;
 
   OperatorBase* OpBase() const;
@@ -1192,13 +1192,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("compute",
                                        platform::EventRole::kInnerOp);
     if (run_pten_kernel_) {
-      if (pt_kernel_context_ == nullptr) {
-        pt_kernel_context_.reset(new pten::KernelContext());
-      }
-      BuildPtenKernelContext(*runtime_ctx, dev_ctx);
-      (*pt_kernel_)(pt_kernel_context_.get());
-      WriteBackToOutputs(runtime_ctx);
-      pt_kernel_context_->ClearData();
+      pten::KernelContext pt_kernel_context;
+      BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
+      (*pt_kernel_)(&pt_kernel_context);
+      WriteBackToOutputs(runtime_ctx, &pt_kernel_context);
     } else {
       (*kernel_func_)(
           ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -1791,18 +1788,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
 }
 
 void OperatorWithKernel::BuildPtenKernelContext(
-    const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
-  if (pt_kernel_context_ == nullptr) {
-    pt_kernel_context_.reset(new pten::KernelContext());
-  }
-
-  // TODO(chenweihang): this currently only works for very simple cases;
-  // many cases still need to be dealt with later:
-  // 1. inputs and outputs that are not tensors
-  // 2. dispensable, duplicable inputs and outputs
-  // 3. removal of needless attributes
-  // 4. using pten Tensor directly
-  // 5. kernel inputs that are not DenseTensor
-  pt_kernel_context_->SetDeviceContext(dev_ctx);
+    const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
+    pten::KernelContext* pt_kernel_context) const {
+  pt_kernel_context->SetDeviceContext(dev_ctx);
 
   auto& input_names = std::get<0>(pt_kernel_signature_->args);
   auto& attr_names = std::get<1>(pt_kernel_signature_->args);
@@ -1836,33 +1824,14 @@ void OperatorWithKernel::BuildPtenKernelContext(
     // calculate the start and end index of the input tensors
     size_t start_idx =
-        (i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second);
+        (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
     size_t end_idx = start_idx + ins_vector.size();
-    auto current_vector_size = pt_kernel_context_->InputsSize();
 
-    // If the memory needed is less than the current memory allocated, we will
-    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
-    // Otherwise, we will create new storage.
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
-      if (current_vector_size > start_idx + offset) {
-        auto& input_ptr =
-            pt_kernel_context_->MutableInputPtrAt(start_idx + offset);
-        if (input_ptr == nullptr) {
-          input_ptr = experimental::MakePtenTensorBaseFromVar(
-              *ins_vector[offset], in_def);
-        } else {
-          experimental::ReMakePtenDenseTensorFromVar(
-              *ins_vector[offset], in_def,
-              pt_kernel_context_->MutableInputAt<pten::DenseTensor>(start_idx +
-                                                                    offset));
-        }
-      } else {
-        pt_kernel_context_->EmplaceBackInputWithoutSetRange(
-            experimental::MakePtenTensorBaseFromVar(*ins_vector[offset],
-                                                    in_def));
-      }
+      pt_kernel_context->EmplaceBackInputWithoutSetRange(
+          experimental::MakePtenTensorBaseFromVar(*ins_vector[offset], in_def));
     }
-    pt_kernel_context_->AssignInputRange(std::make_pair(start_idx, end_idx), i);
+    pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
 
   for (size_t i = 0; i < output_names.size(); ++i) {
@@ -1870,43 +1839,24 @@ void OperatorWithKernel::BuildPtenKernelContext(
     auto& outs_vector = ctx.outputs.at(output_names[i]);
 
     size_t start_idx =
-        (i == 0 ? 0 : pt_kernel_context_->OutputRangeAt(i - 1).second);
+        (i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
     size_t end_idx = start_idx + outs_vector.size();
-    auto current_vector_size = pt_kernel_context_->OutputsSize();
 
-    // If the memory needed is less than the current memory allocated, we will
-    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
-    // Otherwise, we will create new storage.
     for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
-      if (current_vector_size > start_idx + offset) {
-        auto* buffer_tensor =
-            pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
-                                                                   offset);
-        if (buffer_tensor) {
-          experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
-                                                     out_def, buffer_tensor);
-        }
-      } else {
-        pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
-            experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
-                                                    out_def));
-      }
+      pt_kernel_context->EmplaceBackOutputWithoutSetRange(
+          experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
+                                                  out_def));
     }
 
     // Deal with the case that some outputs are NULL when running the kernel.
     // For example: the outputs of matmul_grad are dx and dy,
     // sometimes dx or dy may be NULL.
     if (outs_vector.empty()) {
-      if (current_vector_size > start_idx) {
-        pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
-      } else {
-        pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
-      }
+      pt_kernel_context->EmplaceBackOutputWithoutSetRange({nullptr});
       end_idx = start_idx + 1;
     }
-    pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
-                                          i);
+    pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
   }
 
   for (size_t i = 0; i < attr_names.size(); ++i) {
@@ -1915,11 +1865,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
       if (attr_iter != Attrs().end()) {  // shape is in the attribute
         if (std::type_index(attr_iter->second.type()) ==
             std::type_index(typeid(std::vector<int64_t>))) {
-          pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
+          pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
               BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
         } else if (std::type_index(attr_iter->second.type()) ==
                    std::type_index(typeid(std::vector<int32_t>))) {
-          pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
+          pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
               BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
         } else {
           PADDLE_THROW(platform::errors::Unimplemented(
@@ -1930,10 +1880,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
       } else {  // shape is in the input
         auto& ins_vector = ctx.inputs.at(attr_names[i]);
         if (ins_vector.size() == 1) {  // ShapeTensor
-          pt_kernel_context_->EmplaceBackAttr(std::move(
+          pt_kernel_context->EmplaceBackAttr(std::move(
              experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
         } else {  // ShapeTensorList
-          pt_kernel_context_->EmplaceBackAttr(std::move(
+          pt_kernel_context->EmplaceBackAttr(std::move(
              experimental::MakePtenScalarArrayFromVarList(ins_vector)));
         }
       }
@@ -1946,11 +1896,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
       if (attr_iter != Attrs().end()) {  // scalar is in the attribute
         auto& attr = Attrs().at(attr_names[i]);
         if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-          pt_kernel_context_->EmplaceBackAttr(
+          pt_kernel_context->EmplaceBackAttr(
              std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
         } else if (std::type_index(attr.type()) ==
                    std::type_index(typeid(std::string))) {
-          pt_kernel_context_->EmplaceBackAttr(
+          pt_kernel_context->EmplaceBackAttr(
              std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
         } else {
           PADDLE_THROW(platform::errors::Unimplemented(
@@ -1960,7 +1910,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
       }
     } else {
       auto& ins_vector = ctx.inputs.at(attr_names[i]);
-      pt_kernel_context_->EmplaceBackAttr(std::move(
+      pt_kernel_context->EmplaceBackAttr(std::move(
           experimental::MakePtenScalarFromVar(*ins_vector.front())));
     }
@@ -1968,17 +1918,17 @@ void OperatorWithKernel::BuildPtenKernelContext(
       // TODO(chenweihang): support other attrs later
       auto& attr = Attrs().at(attr_names[i]);
       if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
+        pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+        pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+        pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else if (attr_defs[i].type_index ==
                  std::type_index(typeid(pten::DataType))) {
         auto data_type = pten::TransToPtenDataType(
             static_cast<framework::proto::VarType::Type>(
                 BOOST_GET_CONST(int, attr)));
-        pt_kernel_context_->EmplaceBackAttr(data_type);
+        pt_kernel_context->EmplaceBackAttr(data_type);
       } else if (attr_defs[i].type_index ==
                  std::type_index(typeid(std::vector<int64_t>))) {
         if (std::type_index(attr.type()) ==
@@ -1987,7 +1937,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
           const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
           const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
                                                        vector_int_attr.end());
-          pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
+          pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
         }
         // TODO(YuanRisheng): need to support vector<int64_t> attr
@@ -2001,20 +1951,16 @@ void OperatorWithKernel::BuildPtenKernelContext(
   }
 }
 
-void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
-  // auto& input_names = std::get<0>(pt_kernel_signature_->args);
-  // auto& attr_names = std::get<1>(pt_kernel_signature_->args);
+void OperatorWithKernel::WriteBackToOutputs(
+    RuntimeContext* ctx, pten::KernelContext* pt_kernel_context) const {
   auto& output_names = std::get<2>(pt_kernel_signature_->args);
 
-  // pt_kernel_context_
-
   for (size_t i = 0; i < output_names.size(); ++i) {
     auto& outs_vector = ctx->outputs.at(output_names[i]);
 
-    auto& range_pair = pt_kernel_context_->OutputRangeAt(i);
-    auto pten_outs =
-        pt_kernel_context_->MutableOutputBetween<pten::DenseTensor>(
-            range_pair.first, range_pair.second);
+    auto& range_pair = pt_kernel_context->OutputRangeAt(i);
+    auto pten_outs = pt_kernel_context->MutableOutputBetween<pten::DenseTensor>(
+        range_pair.first, range_pair.second);
 
     for (size_t j = 0; j < pten_outs.size(); ++j) {
       if (pten_outs[j]) {
@@ -589,16 +589,14 @@ class OperatorWithKernel : public OperatorBase {
   void ChoosePtenKernel(const ExecutionContext& ctx) const;
 
   void BuildPtenKernelContext(const RuntimeContext& ctx,
-                              platform::DeviceContext* dev_ctx) const;
+                              platform::DeviceContext* dev_ctx,
+                              pten::KernelContext* pt_kernel_context) const;
 
-  void WriteBackToOutputs(RuntimeContext* ctx) const;
+  void WriteBackToOutputs(RuntimeContext* ctx,
+                          pten::KernelContext* pt_kernel_context) const;
 
   pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
 
-  pten::KernelContext* PtenKernelContext() const {
-    return pt_kernel_context_.get();
-  }
-
   const OpKernelType* kernel_type() const { return kernel_type_.get(); }
 
  private:
@@ -657,9 +655,6 @@ class OperatorWithKernel : public OperatorBase {
   mutable bool run_pten_kernel_ = false;
 
   mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
   mutable std::unique_ptr<pten::Kernel> pt_kernel_;
-
-  // In order to reduce the compatibility phase
-  // performance overhead, temporarily cache KernelContext
-  mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
@@ -409,8 +409,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
   }
 }
 
-pten::KernelContext OpBase::pt_kernel_context_;
-
 void OpBase::SetType(const std::string& type) {
   op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
 }
@@ -426,8 +424,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
                           const NameVarMap<VarType>& outs,
                           const framework::AttributeMap& attrs,
                           const framework::AttributeMap& default_attrs,
-                          const platform::Place& place,
-                          pten::KernelContext* pt_kernel_context) {
+                          const platform::Place& place) {
   auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
   PADDLE_ENFORCE_NOT_NULL(
       op_kernel, platform::errors::PermissionDenied(
@@ -468,8 +465,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
    * after the execution of op, but the original input is directly
    * overwritten in the previous dynamic graph implementation.
    */
-  auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs,
-                                         default_attrs, pt_kernel_context);
+  auto prepared_op =
+      PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
   auto tmp_ins_ptr =
       PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
   if (tmp_ins_ptr == nullptr) {
@@ -497,8 +494,7 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place,
-                         &pt_kernel_context_);
+  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
 }
 
 void OpBase::Run(const framework::OperatorBase& op,
@@ -507,8 +503,7 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place,
-                                 &pt_kernel_context_);
+  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
 }
 
 void ClearNoNeedBufferInputs(OpBase* op) {
@@ -183,8 +183,6 @@ class OpBase {
                   const framework::AttributeMap& default_attrs,
                   const platform::Place& place);
 
-  static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }
-
   bool HasVoidFunctionPostHook() const {
     return !void_function_post_hooks_.empty();
   }
@@ -212,9 +210,6 @@ class OpBase {
   std::unique_ptr<framework::OperatorBase> op_;
   platform::Place place_;
   size_t id_{-1UL};
-  // In order to reduce the compatibility phase
-  // performance overhead, temporarily cache KernelContext
-  static pten::KernelContext pt_kernel_context_;
   std::vector<std::shared_ptr<std::function<void()>>> void_function_post_hooks_;
 };
@@ -117,7 +117,6 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
                        const framework::OpKernelType& kernel_type,
                        const framework::KernelSignature& kernel_signature,
                        const pten::Kernel& pt_kernel,
-                       pten::KernelContext* pt_kernel_context,
                        platform::DeviceContext* dev_ctx)
     : op_(op),
       ctx_(ctx),
@@ -126,8 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
       dev_ctx_(dev_ctx),
       run_pten_kernel_(true),
       pt_kernel_signature_(kernel_signature),
-      pt_kernel_(pt_kernel),
-      pt_kernel_context_(pt_kernel_context) {}
+      pt_kernel_(pt_kernel) {}
 
 template <typename VarType>
 PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
@@ -135,8 +133,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
                        const framework::OperatorWithKernel& op,
                        const platform::Place& place,
                        const framework::AttributeMap& attrs,
-                       const framework::AttributeMap& default_attrs,
-                       pten::KernelContext* pt_kernel_context) {
+                       const framework::AttributeMap& default_attrs) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
@@ -178,7 +175,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
       // TODO(chenweihang): use CPUKernel when the device kernel is missing
       return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
-                        pt_kernel, pt_kernel_context, dev_ctx);
+                        pt_kernel, dev_ctx);
     } else {
       VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
               << "` not found.";
@@ -247,10 +244,8 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
                                const framework::AttributeMap& attrs,
-                               const framework::AttributeMap& default_attrs,
-                               pten::KernelContext* pt_kernel_context) {
-  return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
-                              pt_kernel_context);
+                               const framework::AttributeMap& default_attrs) {
+  return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
 }
 
 PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
@@ -258,10 +253,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
                                const framework::AttributeMap& attrs,
-                               const framework::AttributeMap& default_attrs,
-                               pten::KernelContext* pt_kernel_context) {
+                               const framework::AttributeMap& default_attrs) {
   return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
-                                      default_attrs, pt_kernel_context);
+                                      default_attrs);
 }
 
 template <typename VarType>
@@ -271,13 +265,6 @@ static void BuildDygraphPtenKernelContext(
     const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs,
     platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) {
-  // TODO(chenweihang): this currently only works for very simple cases;
-  // many cases still need to be dealt with later:
-  // 1. inputs and outputs that are not tensors
-  // 2. dispensable, duplicable inputs and outputs
-  // 3. removal of needless attributes
-  // 4. using pten Tensor directly
-  // 5. kernel inputs that are not DenseTensor
   kernel_ctx->SetDeviceContext(dev_ctx);
 
   auto& input_names = std::get<0>(pt_kernel_signature.args);
@@ -312,26 +299,11 @@ static void BuildDygraphPtenKernelContext(
     size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
     size_t end_idx = start_idx + ins_vector.size();
-    auto current_vector_size = kernel_ctx->InputsSize();
 
-    // If the memory needed is less than the current memory allocated, we will
-    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
-    // Otherwise, we will create new storage.
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
       const auto& variable = ins_vector[offset]->Var();
-      if (current_vector_size > start_idx + offset) {
-        auto& input_ptr = kernel_ctx->MutableInputPtrAt(start_idx + offset);
-        if (input_ptr == nullptr) {
-          input_ptr = experimental::MakePtenTensorBaseFromVar(variable, in_def);
-        } else {
-          experimental::ReMakePtenDenseTensorFromVar(
-              variable, in_def, kernel_ctx->MutableInputAt<pten::DenseTensor>(
-                                    start_idx + offset));
-        }
-      } else {
-        kernel_ctx->EmplaceBackInputWithoutSetRange(
-            experimental::MakePtenTensorBaseFromVar(variable, in_def));
-      }
+      kernel_ctx->EmplaceBackInputWithoutSetRange(
+          paddle::experimental::MakePtenTensorBaseFromVar(variable, in_def));
     }
     kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
@@ -340,15 +312,10 @@ static void BuildDygraphPtenKernelContext(
     auto& out_def = output_defs.at(i);
 
     size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
-    auto current_vector_size = kernel_ctx->OutputsSize();
 
     auto iter = outs.find(output_names[i]);
     if (iter == outs.end()) {
-      if (current_vector_size > start_idx) {
-        kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
-      } else {
-        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
-      }
+      kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
       kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
                                     i);
       continue;
@@ -357,27 +324,10 @@ static void BuildDygraphPtenKernelContext(
     auto& outs_vector = iter->second;
     size_t end_idx = start_idx + outs_vector.size();
 
-    // If the memory needed is less than the current memory allocated, we will
-    // reuse the current memory by using ReMakePtenDenseTensorFromVar.
-    // Otherwise, we will create new storage.
     for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
-      if (current_vector_size > start_idx + offset) {
-        auto* buffer_tensor =
-            kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
-        if (buffer_tensor) {
-          experimental::ReMakePtenDenseTensorFromVar(
-              outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
-        } else {
-          kernel_ctx->SetOutputWithoutSetRange(
-              start_idx + offset,
-              experimental::MakePtenTensorBaseFromVar(
-                  outs_vector[offset]->MutableVar(), out_def));
-        }
-      } else {
-        kernel_ctx->EmplaceBackOutputWithoutSetRange(
-            experimental::MakePtenTensorBaseFromVar(
-                outs_vector[offset]->MutableVar(), out_def));
-      }
+      kernel_ctx->EmplaceBackOutputWithoutSetRange(
+          paddle::experimental::MakePtenTensorBaseFromVar(
+              outs_vector[offset]->MutableVar(), out_def));
     }
     kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
   }
@@ -556,19 +506,20 @@ static void PreparedOpRunPtImpl(
     const framework::OperatorBase& op,
     const framework::OpKernelType& kernel_type,
     const framework::KernelSignature& pt_kernel_signature,
-    const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
-    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
-    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
+    const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx,
+    const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
+    const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs) {
   DygraphInferShapeContext<VarType> infer_shape_ctx(
       &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);
 
+  pten::KernelContext pt_kernel_context;
   BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
                                          outs, attrs, default_attrs, dev_ctx,
-                                         pt_kernel_context);
+                                         &pt_kernel_context);
 
-  pt_kernel(pt_kernel_context);
+  pt_kernel(&pt_kernel_context);
 
   if (FLAGS_benchmark) {
     dev_ctx->Wait();
@@ -578,10 +529,7 @@ static void PreparedOpRunPtImpl(
 #endif
   }
 
-  WriteBackToOutputs<VarType>(pt_kernel_signature, outs, pt_kernel_context);
-
-  // Ensure that it does not affect the VarBase life cycle management
-  pt_kernel_context->ClearData();
+  WriteBackToOutputs<VarType>(pt_kernel_signature, outs, &pt_kernel_context);
 
   // TODO(chenweihang): add debug flags later
   if (framework::IsComplexType(kernel_type.data_type_)) {
@@ -595,8 +543,8 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
     PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
-                                 pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
-                                 outs, attrs, default_attrs);
+                                 pt_kernel_, dev_ctx_, ins, outs, attrs,
+                                 default_attrs);
   } else {
     PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                                outs, attrs, default_attrs);
@@ -609,8 +557,8 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                      const framework::AttributeMap& default_attrs) {
   if (run_pten_kernel_) {
     PreparedOpRunPtImpl<VariableWrapper>(
-        op_, kernel_type_, pt_kernel_signature_, pt_kernel_,
-        pt_kernel_context_, dev_ctx_, ins, outs, attrs, default_attrs);
+        op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
+        outs, attrs, default_attrs);
   } else {
     PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                        ins, outs, attrs, default_attrs);
@@ -153,25 +153,21 @@ class PreparedOp {
              const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             const framework::KernelSignature& kernel_signature,
-            const pten::Kernel& pt_kernel,
-            pten::KernelContext* pt_kernel_context,
-            platform::DeviceContext* dev_ctx);
+            const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
 
   static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                             const NameVarMap<VarBase>& outs,
                             const framework::OperatorWithKernel& op,
                             const platform::Place& place,
                             const framework::AttributeMap& attrs,
-                            const framework::AttributeMap& default_attrs,
-                            pten::KernelContext* pt_kernel_context = nullptr);
+                            const framework::AttributeMap& default_attrs);
 
   static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
                             const NameVarMap<VariableWrapper>& outs,
                             const framework::OperatorWithKernel& op,
                             const platform::Place& place,
                             const framework::AttributeMap& attrs,
-                            const framework::AttributeMap& default_attrs,
-                            pten::KernelContext* pt_kernel_context = nullptr);
+                            const framework::AttributeMap& default_attrs);
 
   void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
            const framework::AttributeMap& attrs,
@@ -196,9 +192,6 @@ class PreparedOp {
   bool run_pten_kernel_{false};
   framework::KernelSignature pt_kernel_signature_;
   pten::Kernel pt_kernel_;
-  // In order to reduce the compatibility phase
-  // performance overhead, temporarily cache KernelContext
-  pten::KernelContext* pt_kernel_context_;
 };
 
 }  // namespace imperative
@@ -231,8 +231,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
       OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
     } catch (platform::EnforceNotMet& exception) {
       framework::AppendErrorOpHint(type, &exception);
-      // Compatible impl: clear pten kernel context data when an error is thrown
-      OpBase::GetKernelContext()->ClearData();
       throw std::move(exception);
     } catch (std::exception& ex) {
       PADDLE_THROW(platform::errors::Fatal(
@@ -202,22 +202,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
             # APIs in this thread.
             _set_expected_place(legacy_expected_place)
 
-            # NOTE(chenweihang): [ Why do we disable the pten kernel here? ]
-            # To keep the execution performance of dynamic graph mode from
-            # declining significantly in the pten-compatible state, we cache
-            # one KernelContext globally for the dynamic graph tracer, which
-            # reduces the construction and destruction overhead of the
-            # compatibility-state data interfaces such as DenseTensor. In
-            # static graph mode each op caches its own KernelContext, but a
-            # dynamic graph op is constructed and destroyed on every round of
-            # execution, so a per-op cache is impossible there. However, a
-            # single global kernel context is not thread-safe in dynamic graph
-            # mode: if a pten op is used in the DataLoader thread, it may cause
-            # access errors. We therefore temporarily do not execute pten
-            # kernels in this scenario; we will find a better solution later
-            # and remove this setting.
-            set_flags({'FLAGS_run_pten_kernel': False})
-
             while not self._thread_done_event.is_set():
                 try:
                     indices = next(self._sampler_iter)
@@ -519,9 +503,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             # APIs in this thread.
             _set_expected_place(legacy_expected_place)
 
-            # NOTE(chenweihang): See Note [ Why do we disable the pten kernel here? ]
-            set_flags({'FLAGS_run_pten_kernel': False})
-
             while not self._thread_done_event.is_set():
                 batch = self._get_data()
                 if not self._thread_done_event.is_set():