未验证 提交 2b5771c4 编写于 作者: H huzhiqiang 提交者: GitHub

op cache supports un-persistable attributes (#43221)

上级 9b7126d0
......@@ -1529,8 +1529,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) {
impl_ =
new CacheImpl(new phi::KernelContext(),
new RuntimeInferShapeContext(*this, *runtime_ctx));
BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
(*pt_kernel_)(impl_->getKernelContext());
} else {
phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context);
}
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
......@@ -2386,7 +2398,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
// calcute the start and end index of the input tensors
size_t start_idx =
(i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
// deal with optional here
if ((it == ctx.inputs.end() || it->second.size() == 0) &&
(input_defs[i].type_index ==
......@@ -2400,6 +2411,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
i);
continue;
}
auto ins_vector = it->second;
......@@ -2414,6 +2426,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
tensor_in = &(var->Get<phi::SelectedRows>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
need_prepare_phi_data_ = true;
paddle::small_vector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var->Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
......@@ -2514,6 +2527,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
attr_names[i]));
}
} else { // scalar is in the input
need_prepare_phi_data_ = true;
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePhiScalarFromVar(*ins_vector.front())));
......@@ -2545,6 +2559,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
attr_names[i]));
}
} else { // shape is in the input
need_prepare_phi_data_ = true;
auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context->EmplaceBackAttr(std::move(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册