Unverified Commit 3ef2922b authored by zyfncg, committed by GitHub

【Pten】Remove WriteBackOutput in tensor_utils (#39291)

* remove remake densetensor

* fix eager test error

* fix bug in eager

* implement AllocateFrom

* remove WriteBackOutput

* fix problem of eager
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Parent 90f44c6f
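The recurring change in this commit replaces the explicit SharesStorage() helper with plain DenseTensor copy-assignment (*dst = *src). That works because a DenseTensor copy is shallow: the allocation is held through a shared pointer, so assignment aliases the same storage instead of copying bytes. Below is a minimal standalone sketch of that behavior; MiniDenseTensor is a hypothetical stand-in, not the real pten::DenseTensor.

    #include <cassert>
    #include <cstdint>
    #include <memory>
    #include <vector>

    // Hypothetical stand-in for pten::DenseTensor: the data buffer is held
    // through a shared_ptr, so copy-assignment is shallow (it shares the
    // allocation rather than copying bytes).
    struct MiniDenseTensor {
      std::shared_ptr<std::vector<float>> holder;  // shared allocation
      std::vector<int64_t> dims;                   // metadata copied by value
    };

    int main() {
      MiniDenseTensor src;
      src.holder = std::make_shared<std::vector<float>>(8, 1.0f);
      src.dims = {2, 4};

      MiniDenseTensor dst;
      dst = src;  // the new pattern: *framework_tensor = *tensor_dense;

      // Both tensors now view the same storage, which is what the removed
      // SharesStorage(src, dst) helper did by hand (Resize +
      // ResetHolderWithType + set_offset).
      assert(dst.holder.get() == src.holder.get());
      (*dst.holder)[0] = 3.0f;
      assert((*src.holder)[0] == 3.0f);  // writes visible through both views
      return 0;
    }

Because sharing is now expressed by the assignment itself, the separate write-back pass deleted below has nothing left to do.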
@@ -245,8 +245,7 @@ class EagerTensor final {
     auto tensor_dense =
         std::dynamic_pointer_cast<pten::DenseTensor>(tensor_->impl());
     if (tensor_dense && tensor_dense.get()) {
-      paddle::experimental::SharesStorage(tensor_dense.get(),
-                                          framework_tensor);
+      *framework_tensor = *tensor_dense;
     } else {
       PADDLE_THROW(paddle::platform::errors::Fatal(
           "Unrecognized egr::EagerTensor type, only "
......
@@ -207,17 +207,13 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
                 "Tensors.",
                 vec_true_outs.size(), outs.size()));
         for (size_t j = 0; j < vec_true_outs.size(); ++j) {
-          experimental::SharesStorage(
-              std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl())
-                  .get(),
-              vec_true_outs.at(j));
+          *vec_true_outs.at(j) =
+              *std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl());
         }
       } else {
         auto* true_out = ctx.Output<Tensor>(out_name);
-        experimental::SharesStorage(
-            std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl())
-                .get(),
-            true_out);
+        *true_out =
+            *std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl());
       }
     }
   } catch (platform::EnforceNotMet& exception) {
......
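A caveat about the new assignments in RunKernelFunc: std::dynamic_pointer_cast returns an empty shared_ptr when impl() is not actually a pten::DenseTensor, and dereferencing it is then undefined behavior. The EagerTensor hunk above guards the cast; these call sites rely on custom-operator outputs always being dense. A hedged sketch of the guarded variant, using hypothetical TensorBase/DenseTensorLike stand-ins rather than Paddle types:

    #include <iostream>
    #include <memory>

    struct TensorBase { virtual ~TensorBase() = default; };
    struct DenseTensorLike : TensorBase { float value = 0.0f; };

    // Hypothetical helper mirroring the call-site pattern: cast, verify,
    // then shallow copy-assign, as in the diff above.
    void AssignDense(const std::shared_ptr<TensorBase>& impl,
                     DenseTensorLike* true_out) {
      auto dense = std::dynamic_pointer_cast<DenseTensorLike>(impl);
      if (dense) {
        *true_out = *dense;
      } else {
        std::cerr << "output impl is not a dense tensor\n";
      }
    }

    int main() {
      std::shared_ptr<TensorBase> impl = std::make_shared<DenseTensorLike>();
      DenseTensorLike out;
      AssignDense(impl, &out);
      return 0;
    }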
@@ -2105,24 +2105,5 @@ void OperatorWithKernel::BuildPtenKernelContext(
   }
 }
 
-void OperatorWithKernel::WriteBackToOutputs(
-    RuntimeContext* ctx, pten::KernelContext* pt_kernel_context) const {
-  auto& output_names = std::get<2>(pt_kernel_signature_->args);
-
-  for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& outs_vector = ctx->outputs.at(output_names[i]);
-
-    auto& range_pair = pt_kernel_context->OutputRangeAt(i);
-    auto pten_outs = pt_kernel_context->MutableOutputBetween<pten::DenseTensor>(
-        range_pair.first, range_pair.second);
-
-    for (size_t j = 0; j < pten_outs.size(); ++j) {
-      if (pten_outs[j]) {
-        experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
-      }
-    }
-  }
-}
-
 }  // namespace framework
 }  // namespace paddle
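The deleted WriteBackToOutputs pass copied each pten kernel output back into the matching framework Variable after the kernel ran. Per the "implement AllocateFrom" bullet in the commit message, outputs are now allocated directly from their destination, so there is nothing left to write back; the matching declaration is removed from the header in the next hunk.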
@@ -616,9 +616,6 @@ class OperatorWithKernel : public OperatorBase {
                              platform::DeviceContext* dev_ctx,
                              pten::KernelContext* pt_kernel_context) const;
 
-  void WriteBackToOutputs(RuntimeContext* ctx,
-                          pten::KernelContext* pt_kernel_context) const;
-
   pten::KernelSignature* PtenKernelSignature() const {
     return pt_kernel_signature_.get();
   }
......
@@ -198,68 +198,6 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
   return {vector_data};
 }
 
-void SharesStorageBase(pten::DenseTensor* src, paddle::framework::Tensor* dst) {
-  PADDLE_ENFORCE_NOT_NULL(
-      src,
-      platform::errors::InvalidArgument(
-          "The source DenseTensor is nullptr when move allocation."));
-  PADDLE_ENFORCE_NOT_NULL(
-      dst,
-      platform::errors::InvalidArgument(
-          "The destination Tensor is nullptr when move allocation."));
-  dst->Resize(src->dims());
-  dst->ResetHolderWithType(src->Holder(),
-                           pten::TransToProtoVarType(src->dtype()));
-  dst->set_offset(src->meta().offset);
-}
-
-void SharesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst) {
-  SharesStorageBase(src, static_cast<paddle::framework::Tensor*>(dst));
-  SetLoD(dst->mutable_lod(), src->lod());
-}
-
-static bool IsSameAllocation(const std::shared_ptr<memory::Allocation>& a,
-                             const std::shared_ptr<memory::Allocation>& b) {
-  return a->ptr() == b->ptr() && a->size() == b->size() &&
-         platform::is_same_place(a->place(), b->place());
-}
-
-void MakeVariableFromPtenTensor(pten::DenseTensor* src,
-                                framework::Variable* variable) {
-  if (variable->IsType<framework::LoDTensor>()) {
-    auto* tensor = variable->GetMutable<framework::LoDTensor>();
-
-    auto dtype = pten::TransToProtoVarType(src->dtype());
-    tensor->Resize(src->dims());
-    SetLoD(tensor->mutable_lod(), src->lod());
-
-    if (!tensor->IsInitialized() ||
-        (tensor->IsInitialized() &&
-         !IsSameAllocation(tensor->Holder(), src->Holder()))) {
-      tensor->ResetHolderWithType(std::move(src->Holder()), dtype);
-    } else {
-      // Even the pten tensor and Variable have the same Alloctation (both have
-      // the same pointer address, same size and same place)
-      // but there is possible that they do not have the same data_type.
-      // so, here we set the variable's type with the pten tensor dtype.
-      tensor->set_type(dtype);
-    }
-  } else if (variable->IsType<pten::SelectedRows>()) {
-    auto* tensor = variable->GetMutable<pten::SelectedRows>();
-    auto dtype = pten::TransToProtoVarType(src->dtype());
-
-    if (!tensor->value().IsInitialized()) {
-      tensor->mutable_value()->ResetHolderWithType(std::move(src->Holder()),
-                                                   dtype);
-    }
-  } else {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Unsupported shared input `%s` type now when call pt kernel.",
-        framework::ToTypeName(variable->Type())));
-  }
-}
-
 void ResetTensorByArgDef(pten::DenseTensor* dst,
                          const pten::TensorArgDef& arg_def) {
   VLOG(5) << "ResetTensor by TensorArgDef.";
......
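With every call site switched to direct assignment, SharesStorage (with its SharesStorageBase helper) and MakeVariableFromPtenTensor have no remaining users, so their definitions are deleted above and their declarations go in the header hunk below.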
@@ -45,11 +45,6 @@ pten::ScalarArray MakePtenScalarArrayFromVar(
 pten::ScalarArray MakePtenScalarArrayFromVarList(
     const std::vector<framework::Variable*>& variable_list);
 
-void SharesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);
-
-void MakeVariableFromPtenTensor(pten::DenseTensor* src,
-                                framework::Variable* variable);
-
 void ResetTensorByArgDef(pten::DenseTensor* dst,
                          const pten::TensorArgDef& arg_def);
......