From 4bfd0445021e8d899fb7b06e3f8ed4f703e8c329 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Wed, 25 Aug 2021 10:23:26 +0800 Subject: [PATCH] [hybrid performance] optim npu coalesce set constant (#35105) --- paddle/fluid/operators/coalesce_tensor_op.cc | 27 +++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 4c5f3a2a47..c9cc01b8b1 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -30,8 +30,14 @@ namespace operators { template struct FillConstantVisitor { FillConstantVisitor(const DeviceContext &dev_ctx, - framework::LoDTensor *tensor, const float value) - : dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {} + framework::LoDTensor *tensor, const float value, + framework::proto::VarType::Type dtype, + const framework::ExecutionContext &context) + : dev_ctx_(dev_ctx), + tensor_(tensor), + value_(value), + dtype_(dtype), + context_(context) {} template void apply(typename std::enable_if::value || @@ -47,7 +53,17 @@ struct FillConstantVisitor { * = nullptr) const { #ifdef PADDLE_WITH_ASCEND_CL if (platform::is_npu_place(dev_ctx_.GetPlace())) { - FillNpuTensorWithConstant(tensor_, static_cast(value_)); + Tensor tensor_tmp(dtype_); + tensor_tmp.mutable_data({1}, context_.GetPlace()); + FillNpuTensorWithConstant(&tensor_tmp, static_cast(value_)); + + const auto &runner = + NpuOpRunner("FillD", {tensor_tmp}, {*tensor_}, + {{"dims", framework::vectorize(tensor_->dims())}}); + auto stream = + context_.template device_context() + .stream(); + runner.Run(stream); } else { math::SetConstant set_constant; set_constant(dev_ctx_, tensor_, static_cast(value_)); @@ -61,6 +77,8 @@ struct FillConstantVisitor { const DeviceContext &dev_ctx_; framework::LoDTensor *tensor_; float value_; + framework::proto::VarType::Type dtype_; + const framework::ExecutionContext &context_; }; template @@ -165,7 +183,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel { } else if (context.Attr("set_constant")) { framework::VisitDataType( dtype, FillConstantVisitor( - dev_ctx, fused_tensor, context.Attr("constant"))); + dev_ctx, fused_tensor, context.Attr("constant"), + dtype, context)); } else if (context.Attr("persist_output")) { for (size_t i = 0; i < out_var_names.size(); ++i) { size_t len = static_cast(out_tensors[i]->numel()); -- GitLab