未验证 提交 35d2b71a 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Remove cached kernel context (#38953)

* remove cached kernel context

* revert dataloader format change
上级 1053b1d5
...@@ -418,15 +418,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -418,15 +418,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(4) << "Run pten kernel: " << op->Type(); VLOG(4) << "Run pten kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " " VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext(); << &instr_node.DeviceContext();
pten::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext( op_with_kernel->BuildPtenKernelContext(
*instr_node.InnerRuntimeContext().get(), *instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext())); const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&pt_kernel_context);
(*instr_node.PtenKernel())(instr_node.PtenKernelContext()); (*instr_node.PtenKernel())(&pt_kernel_context);
op_with_kernel->WriteBackToOutputs( op_with_kernel->WriteBackToOutputs(
instr_node.InnerRuntimeContext().get()); instr_node.InnerRuntimeContext().get(), &pt_kernel_context);
instr_node.PtenKernelContext()->ClearData();
} else { } else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
} }
......
...@@ -425,13 +425,14 @@ void build_op_func_list(const platform::Place& place, ...@@ -425,13 +425,14 @@ void build_op_func_list(const platform::Place& place,
} }
if (run_pten_kernel) { if (run_pten_kernel) {
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx); pten::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel(); op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
(*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_); (*op_func_node.pt_kernel_)(&pt_kernel_context);
op_with_kernel->WriteBackToOutputs(&runtime_context); op_with_kernel->WriteBackToOutputs(&runtime_context,
op_func_node.pt_kernel_context_->ClearData(); &pt_kernel_context);
} else { } else {
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx); op_func_node.kernel_func_(exec_ctx);
......
...@@ -688,10 +688,6 @@ pten::Kernel* Instruction::PtenKernel() const { ...@@ -688,10 +688,6 @@ pten::Kernel* Instruction::PtenKernel() const {
return op_func_node_.pt_kernel_; return op_func_node_.pt_kernel_;
} }
pten::KernelContext* Instruction::PtenKernelContext() const {
return op_func_node_.pt_kernel_context_;
}
OpFuncType Instruction::KernelType() const { return op_func_node_.type_; } OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
OperatorBase* Instruction::OpBase() const { OperatorBase* Instruction::OpBase() const {
......
...@@ -299,8 +299,7 @@ struct OpFuncNode { ...@@ -299,8 +299,7 @@ struct OpFuncNode {
platform::DeviceContext* dev_ctx_; // not owned platform::DeviceContext* dev_ctx_; // not owned
// fit for pten kernel // fit for pten kernel
pten::Kernel* pt_kernel_{nullptr}; // not owned pten::Kernel* pt_kernel_{nullptr}; // not owned
pten::KernelContext* pt_kernel_context_{nullptr}; // not onwed
OpFuncType type_; OpFuncType type_;
}; };
...@@ -322,8 +321,6 @@ class Instruction { ...@@ -322,8 +321,6 @@ class Instruction {
pten::Kernel* PtenKernel() const; pten::Kernel* PtenKernel() const;
pten::KernelContext* PtenKernelContext() const;
OpFuncType KernelType() const; OpFuncType KernelType() const;
OperatorBase* OpBase() const; OperatorBase* OpBase() const;
......
...@@ -1192,13 +1192,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1192,13 +1192,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute", platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (run_pten_kernel_) { if (run_pten_kernel_) {
if (pt_kernel_context_ == nullptr) { pten::KernelContext pt_kernel_context;
pt_kernel_context_.reset(new pten::KernelContext()); BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
} (*pt_kernel_)(&pt_kernel_context);
BuildPtenKernelContext(*runtime_ctx, dev_ctx); WriteBackToOutputs(runtime_ctx, &pt_kernel_context);
(*pt_kernel_)(pt_kernel_context_.get());
WriteBackToOutputs(runtime_ctx);
pt_kernel_context_->ClearData();
} else { } else {
(*kernel_func_)( (*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
...@@ -1791,18 +1788,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( ...@@ -1791,18 +1788,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
} }
void OperatorWithKernel::BuildPtenKernelContext( void OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const { const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
if (pt_kernel_context_ == nullptr) { pten::KernelContext* pt_kernel_context) const {
pt_kernel_context_.reset(new pten::KernelContext()); pt_kernel_context->SetDeviceContext(dev_ctx);
}
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 2. the dispensbale, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
pt_kernel_context_->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature_->args); auto& input_names = std::get<0>(pt_kernel_signature_->args);
auto& attr_names = std::get<1>(pt_kernel_signature_->args); auto& attr_names = std::get<1>(pt_kernel_signature_->args);
...@@ -1836,33 +1824,14 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1836,33 +1824,14 @@ void OperatorWithKernel::BuildPtenKernelContext(
// calcute the start and end index of the input tensors // calcute the start and end index of the input tensors
size_t start_idx = size_t start_idx =
(i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second); (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size(); size_t end_idx = start_idx + ins_vector.size();
auto current_vector_size = pt_kernel_context_->InputsSize();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) { pt_kernel_context->EmplaceBackInputWithoutSetRange(
auto& input_ptr = experimental::MakePtenTensorBaseFromVar(*ins_vector[offset], in_def));
pt_kernel_context_->MutableInputPtrAt(start_idx + offset);
if (input_ptr == nullptr) {
input_ptr = experimental::MakePtenTensorBaseFromVar(
*ins_vector[offset], in_def);
} else {
experimental::ReMakePtenDenseTensorFromVar(
*ins_vector[offset], in_def,
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(start_idx +
offset));
}
} else {
pt_kernel_context_->EmplaceBackInputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(*ins_vector[offset],
in_def));
}
} }
pt_kernel_context_->AssignInputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
...@@ -1870,43 +1839,24 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1870,43 +1839,24 @@ void OperatorWithKernel::BuildPtenKernelContext(
auto& outs_vector = ctx.outputs.at(output_names[i]); auto& outs_vector = ctx.outputs.at(output_names[i]);
size_t start_idx = size_t start_idx =
(i == 0 ? 0 : pt_kernel_context_->OutputRangeAt(i - 1).second); (i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
size_t end_idx = start_idx + outs_vector.size(); size_t end_idx = start_idx + outs_vector.size();
auto current_vector_size = pt_kernel_context_->OutputsSize();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) { pt_kernel_context->EmplaceBackOutputWithoutSetRange(
auto* buffer_tensor = experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx + out_def));
offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
out_def, buffer_tensor);
}
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
} }
// Deal with the case that some outputs are NULL when run the kernel. // Deal with the case that some outputs are NULL when run the kernel.
// For example : the outputs of matmul_grad are dx and dy, // For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL. // sometimes dx or dy may be NULL.
if (outs_vector.empty()) { if (outs_vector.empty()) {
if (current_vector_size > start_idx) { pt_kernel_context->EmplaceBackOutputWithoutSetRange({nullptr});
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
end_idx = start_idx + 1; end_idx = start_idx + 1;
} }
pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx), pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
i);
} }
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
...@@ -1915,11 +1865,11 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1915,11 +1865,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
if (attr_iter != Attrs().end()) { // shape is in the attribute if (attr_iter != Attrs().end()) { // shape is in the attribute
if (std::type_index(attr_iter->second.type()) == if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second)))); BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) == } else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int32_t>))) { std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second)))); BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
...@@ -1930,10 +1880,10 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1930,10 +1880,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
} else { // shape is in the input } else { // shape is in the input
auto& ins_vector = ctx.inputs.at(attr_names[i]); auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context_->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(*ins_vector.front()))); experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList } else { // ShapeTensorList
pt_kernel_context_->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(ins_vector))); experimental::MakePtenScalarArrayFromVarList(ins_vector)));
} }
} }
...@@ -1946,11 +1896,11 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1946,11 +1896,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
if (attr_iter != Attrs().end()) { // scalar is in the attribute if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = Attrs().at(attr_names[i]); auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) == } else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) { std::type_index(typeid(std::string))) {
pt_kernel_context_->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
...@@ -1960,7 +1910,7 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1960,7 +1910,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
} }
} else { } else {
auto& ins_vector = ctx.inputs.at(attr_names[i]); auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context_->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(*ins_vector.front()))); experimental::MakePtenScalarFromVar(*ins_vector.front())));
} }
...@@ -1968,17 +1918,17 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1968,17 +1918,17 @@ void OperatorWithKernel::BuildPtenKernelContext(
// TODO(chenweihang): support other attrs later // TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]); auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) { } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) { std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType( auto data_type = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
pt_kernel_context_->EmplaceBackAttr(data_type); pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) { std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) == if (std::type_index(attr.type()) ==
...@@ -1987,7 +1937,7 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -1987,7 +1937,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end()); vector_int_attr.end());
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr); pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
} }
// TODO(YuanRisheng) Need support vector<int64_t> attr // TODO(YuanRisheng) Need support vector<int64_t> attr
...@@ -2001,20 +1951,16 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2001,20 +1951,16 @@ void OperatorWithKernel::BuildPtenKernelContext(
} }
} }
void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { void OperatorWithKernel::WriteBackToOutputs(
// auto& input_names = std::get<0>(pt_kernel_signature_->args); RuntimeContext* ctx, pten::KernelContext* pt_kernel_context) const {
// auto& attr_names = std::get<1>(pt_kernel_signature_->args);
auto& output_names = std::get<2>(pt_kernel_signature_->args); auto& output_names = std::get<2>(pt_kernel_signature_->args);
// pt_kernel_context_
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
auto& outs_vector = ctx->outputs.at(output_names[i]); auto& outs_vector = ctx->outputs.at(output_names[i]);
auto& range_pair = pt_kernel_context_->OutputRangeAt(i); auto& range_pair = pt_kernel_context->OutputRangeAt(i);
auto pten_outs = auto pten_outs = pt_kernel_context->MutableOutputBetween<pten::DenseTensor>(
pt_kernel_context_->MutableOutputBetween<pten::DenseTensor>( range_pair.first, range_pair.second);
range_pair.first, range_pair.second);
for (size_t j = 0; j < pten_outs.size(); ++j) { for (size_t j = 0; j < pten_outs.size(); ++j) {
if (pten_outs[j]) { if (pten_outs[j]) {
......
...@@ -589,16 +589,14 @@ class OperatorWithKernel : public OperatorBase { ...@@ -589,16 +589,14 @@ class OperatorWithKernel : public OperatorBase {
void ChoosePtenKernel(const ExecutionContext& ctx) const; void ChoosePtenKernel(const ExecutionContext& ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx, void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const; platform::DeviceContext* dev_ctx,
pten::KernelContext* pt_kernel_context) const;
void WriteBackToOutputs(RuntimeContext* ctx) const; void WriteBackToOutputs(RuntimeContext* ctx,
pten::KernelContext* pt_kernel_context) const;
pten::Kernel* PtenKernel() const { return pt_kernel_.get(); } pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
pten::KernelContext* PtenKernelContext() const {
return pt_kernel_context_.get();
}
const OpKernelType* kernel_type() const { return kernel_type_.get(); } const OpKernelType* kernel_type() const { return kernel_type_.get(); }
private: private:
...@@ -657,9 +655,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -657,9 +655,6 @@ class OperatorWithKernel : public OperatorBase {
mutable bool run_pten_kernel_ = false; mutable bool run_pten_kernel_ = false;
mutable std::unique_ptr<KernelSignature> pt_kernel_signature_; mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<pten::Kernel> pt_kernel_; mutable std::unique_ptr<pten::Kernel> pt_kernel_;
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
...@@ -409,8 +409,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) { ...@@ -409,8 +409,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
} }
} }
pten::KernelContext OpBase::pt_kernel_context_;
void OpBase::SetType(const std::string& type) { void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false); op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
} }
...@@ -426,8 +424,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -426,8 +424,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place, const platform::Place& place) {
pten::KernelContext* pt_kernel_context) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op); auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied( op_kernel, platform::errors::PermissionDenied(
...@@ -468,8 +465,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -468,8 +465,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
* after the execution of op, but the original input is directly * after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implemention. * overwritten in the previous dynamic graph implemention.
*/ */
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, auto prepared_op =
default_attrs, pt_kernel_context); PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr = auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type()); PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
...@@ -497,8 +494,7 @@ void OpBase::Run(const framework::OperatorBase& op, ...@@ -497,8 +494,7 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place, OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
&pt_kernel_context_);
} }
void OpBase::Run(const framework::OperatorBase& op, void OpBase::Run(const framework::OperatorBase& op,
...@@ -507,8 +503,7 @@ void OpBase::Run(const framework::OperatorBase& op, ...@@ -507,8 +503,7 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place, OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
&pt_kernel_context_);
} }
void ClearNoNeedBufferInputs(OpBase* op) { void ClearNoNeedBufferInputs(OpBase* op) {
......
...@@ -183,8 +183,6 @@ class OpBase { ...@@ -183,8 +183,6 @@ class OpBase {
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }
bool HasVoidFunctionPostHook() const { bool HasVoidFunctionPostHook() const {
return !void_function_post_hooks_.empty(); return !void_function_post_hooks_.empty();
} }
...@@ -212,9 +210,6 @@ class OpBase { ...@@ -212,9 +210,6 @@ class OpBase {
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_; platform::Place place_;
size_t id_{-1UL}; size_t id_{-1UL};
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
static pten::KernelContext pt_kernel_context_;
std::vector<std::shared_ptr<std::function<void()>>> void_function_post_hooks_; std::vector<std::shared_ptr<std::function<void()>>> void_function_post_hooks_;
}; };
......
...@@ -117,7 +117,6 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -117,7 +117,6 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature, const framework::KernelSignature& kernel_signature,
const pten::Kernel& pt_kernel, const pten::Kernel& pt_kernel,
pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx) platform::DeviceContext* dev_ctx)
: op_(op), : op_(op),
ctx_(ctx), ctx_(ctx),
...@@ -126,8 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -126,8 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
run_pten_kernel_(true), run_pten_kernel_(true),
pt_kernel_signature_(kernel_signature), pt_kernel_signature_(kernel_signature),
pt_kernel_(pt_kernel), pt_kernel_(pt_kernel) {}
pt_kernel_context_(pt_kernel_context) {}
template <typename VarType> template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
...@@ -135,8 +133,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -135,8 +133,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs) {
pten::KernelContext* pt_kernel_context) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -178,7 +175,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -178,7 +175,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
// TODO(chenweihang): using CPUKernel when miss device kernel case // TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, pt_kernel_context, dev_ctx); pt_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
...@@ -247,10 +244,8 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, ...@@ -247,10 +244,8 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs) {
pten::KernelContext* pt_kernel_context) { return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
pt_kernel_context);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
...@@ -258,10 +253,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, ...@@ -258,10 +253,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs) {
pten::KernelContext* pt_kernel_context) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs, return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
default_attrs, pt_kernel_context); default_attrs);
} }
template <typename VarType> template <typename VarType>
...@@ -271,13 +265,6 @@ static void BuildDygraphPtenKernelContext( ...@@ -271,13 +265,6 @@ static void BuildDygraphPtenKernelContext(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) { platform::DeviceContext* dev_ctx, pten::KernelContext* kernel_ctx) {
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 2. the dispensbale, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
kernel_ctx->SetDeviceContext(dev_ctx); kernel_ctx->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature.args); auto& input_names = std::get<0>(pt_kernel_signature.args);
...@@ -312,26 +299,11 @@ static void BuildDygraphPtenKernelContext( ...@@ -312,26 +299,11 @@ static void BuildDygraphPtenKernelContext(
size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second); size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size(); size_t end_idx = start_idx + ins_vector.size();
auto current_vector_size = kernel_ctx->InputsSize();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const auto& variable = ins_vector[offset]->Var(); const auto& variable = ins_vector[offset]->Var();
if (current_vector_size > start_idx + offset) { kernel_ctx->EmplaceBackInputWithoutSetRange(
auto& input_ptr = kernel_ctx->MutableInputPtrAt(start_idx + offset); paddle::experimental::MakePtenTensorBaseFromVar(variable, in_def));
if (input_ptr == nullptr) {
input_ptr = experimental::MakePtenTensorBaseFromVar(variable, in_def);
} else {
experimental::ReMakePtenDenseTensorFromVar(
variable, in_def, kernel_ctx->MutableInputAt<pten::DenseTensor>(
start_idx + offset));
}
} else {
kernel_ctx->EmplaceBackInputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(variable, in_def));
}
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
...@@ -340,15 +312,10 @@ static void BuildDygraphPtenKernelContext( ...@@ -340,15 +312,10 @@ static void BuildDygraphPtenKernelContext(
auto& out_def = output_defs.at(i); auto& out_def = output_defs.at(i);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
auto current_vector_size = kernel_ctx->OutputsSize();
auto iter = outs.find(output_names[i]); auto iter = outs.find(output_names[i]);
if (iter == outs.end()) { if (iter == outs.end()) {
if (current_vector_size > start_idx) { kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i); i);
continue; continue;
...@@ -357,27 +324,10 @@ static void BuildDygraphPtenKernelContext( ...@@ -357,27 +324,10 @@ static void BuildDygraphPtenKernelContext(
auto& outs_vector = iter->second; auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size(); size_t end_idx = start_idx + outs_vector.size();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) { kernel_ctx->EmplaceBackOutputWithoutSetRange(
auto* buffer_tensor = paddle::experimental::MakePtenTensorBaseFromVar(
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset); outs_vector[offset]->MutableVar(), out_def));
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
} else {
kernel_ctx->SetOutputWithoutSetRange(
start_idx + offset,
experimental::MakePtenTensorBaseFromVar(
outs_vector[offset]->MutableVar(), out_def));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(
outs_vector[offset]->MutableVar(), out_def));
}
} }
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
...@@ -556,19 +506,20 @@ static void PreparedOpRunPtImpl( ...@@ -556,19 +506,20 @@ static void PreparedOpRunPtImpl(
const framework::OperatorBase& op, const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature, const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context, const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
DygraphInferShapeContext<VarType> infer_shape_ctx( DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type); &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
pten::KernelContext pt_kernel_context;
BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins, BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx, outs, attrs, default_attrs, dev_ctx,
pt_kernel_context); &pt_kernel_context);
pt_kernel(pt_kernel_context); pt_kernel(&pt_kernel_context);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
...@@ -578,10 +529,7 @@ static void PreparedOpRunPtImpl( ...@@ -578,10 +529,7 @@ static void PreparedOpRunPtImpl(
#endif #endif
} }
WriteBackToOutputs<VarType>(pt_kernel_signature, outs, pt_kernel_context); WriteBackToOutputs<VarType>(pt_kernel_signature, outs, &pt_kernel_context);
// Ensure that it does not affect the VarBase life cycle management
pt_kernel_context->ClearData();
// TODO(chenweihang): add debug flags later // TODO(chenweihang): add debug flags later
if (framework::IsComplexType(kernel_type.data_type_)) { if (framework::IsComplexType(kernel_type.data_type_)) {
...@@ -595,8 +543,8 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins, ...@@ -595,8 +543,8 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_pten_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_, PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, pt_kernel_context_, dev_ctx_, ins, pt_kernel_, dev_ctx_, ins, outs, attrs,
outs, attrs, default_attrs); default_attrs);
} else { } else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs); outs, attrs, default_attrs);
...@@ -609,8 +557,8 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -609,8 +557,8 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_pten_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>( PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_, op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
dev_ctx_, ins, outs, attrs, default_attrs); outs, attrs, default_attrs);
} else { } else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_, PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs); ins, outs, attrs, default_attrs);
......
...@@ -153,25 +153,21 @@ class PreparedOp { ...@@ -153,25 +153,21 @@ class PreparedOp {
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature, const framework::KernelSignature& kernel_signature,
const pten::Kernel& pt_kernel, const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins, static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs);
pten::KernelContext* pt_kernel_context = nullptr);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins, static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs);
pten::KernelContext* pt_kernel_context = nullptr);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out, void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
...@@ -196,9 +192,6 @@ class PreparedOp { ...@@ -196,9 +192,6 @@ class PreparedOp {
bool run_pten_kernel_{false}; bool run_pten_kernel_{false};
framework::KernelSignature pt_kernel_signature_; framework::KernelSignature pt_kernel_signature_;
pten::Kernel pt_kernel_; pten::Kernel pt_kernel_;
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
pten::KernelContext* pt_kernel_context_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -231,8 +231,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -231,8 +231,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place); OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception); framework::AppendErrorOpHint(type, &exception);
// Compatible impl: clear pten kernel context data when throw error
OpBase::GetKernelContext()->ClearData();
throw std::move(exception); throw std::move(exception);
} catch (std::exception& ex) { } catch (std::exception& ex) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
......
...@@ -202,22 +202,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -202,22 +202,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# APIs in this thread. # APIs in this thread.
_set_expected_place(legacy_expected_place) _set_expected_place(legacy_expected_place)
# NOTE(chenweihang): [ Why need to set not to execute pten kernel here? ]
# Now, in order to ensure that the execution performance of the dynamic
# graph mode in pten compatible state does not decline significantly,
# we have adopted the approach of caching a KernelContext globally for
# the dynamic graph tracer to reduce the construction and deconstruction
# overhead of data interfaces such as the compatible state DenseTensor.
# The static graph is each op caches a KernelContext, but the op of
# the dynamic graph will be constructed and destroyed every round of
# execution, so it is impossible to cache KernelContext for each op.
# However, it is not thread-safe if using only one global kernel context in
# dynamic graph. If the pten op of paddle is used in the DataLoader thread,
# it may cause access errors. We temporarily do not execute pten kernel
# in this scenario and will find a better solution later and remove
# this setting.
set_flags({'FLAGS_run_pten_kernel': False})
while not self._thread_done_event.is_set(): while not self._thread_done_event.is_set():
try: try:
indices = next(self._sampler_iter) indices = next(self._sampler_iter)
...@@ -519,9 +503,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -519,9 +503,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# APIs in this thread. # APIs in this thread.
_set_expected_place(legacy_expected_place) _set_expected_place(legacy_expected_place)
# NOTE(chenweihang): See Note [ Why need to set not to execute pten kernel here? ]
set_flags({'FLAGS_run_pten_kernel': False})
while not self._thread_done_event.is_set(): while not self._thread_done_event.is_set():
batch = self._get_data() batch = self._get_data()
if not self._thread_done_event.is_set(): if not self._thread_done_event.is_set():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册