未验证 提交 4bfd0445 编写于 作者: Y Yuang Liu 提交者: GitHub

[hybrid performance] optim npu coalesce set constant (#35105)

上级 d710c3a0
...@@ -30,8 +30,14 @@ namespace operators { ...@@ -30,8 +30,14 @@ namespace operators {
template <typename DeviceContext> template <typename DeviceContext>
struct FillConstantVisitor { struct FillConstantVisitor {
FillConstantVisitor(const DeviceContext &dev_ctx, FillConstantVisitor(const DeviceContext &dev_ctx,
framework::LoDTensor *tensor, const float value) framework::LoDTensor *tensor, const float value,
: dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {} framework::proto::VarType::Type dtype,
const framework::ExecutionContext &context)
: dev_ctx_(dev_ctx),
tensor_(tensor),
value_(value),
dtype_(dtype),
context_(context) {}
template <typename T> template <typename T>
void apply(typename std::enable_if<std::is_same<T, int8_t>::value || void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
...@@ -47,7 +53,17 @@ struct FillConstantVisitor { ...@@ -47,7 +53,17 @@ struct FillConstantVisitor {
* = nullptr) const { * = nullptr) const {
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(dev_ctx_.GetPlace())) { if (platform::is_npu_place(dev_ctx_.GetPlace())) {
FillNpuTensorWithConstant<T>(tensor_, static_cast<T>(value_)); Tensor tensor_tmp(dtype_);
tensor_tmp.mutable_data<T>({1}, context_.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value_));
const auto &runner =
NpuOpRunner("FillD", {tensor_tmp}, {*tensor_},
{{"dims", framework::vectorize(tensor_->dims())}});
auto stream =
context_.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
} else { } else {
math::SetConstant<DeviceContext, T> set_constant; math::SetConstant<DeviceContext, T> set_constant;
set_constant(dev_ctx_, tensor_, static_cast<T>(value_)); set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
...@@ -61,6 +77,8 @@ struct FillConstantVisitor { ...@@ -61,6 +77,8 @@ struct FillConstantVisitor {
const DeviceContext &dev_ctx_; const DeviceContext &dev_ctx_;
framework::LoDTensor *tensor_; framework::LoDTensor *tensor_;
float value_; float value_;
framework::proto::VarType::Type dtype_;
const framework::ExecutionContext &context_;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -165,7 +183,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -165,7 +183,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
} else if (context.Attr<bool>("set_constant")) { } else if (context.Attr<bool>("set_constant")) {
framework::VisitDataType( framework::VisitDataType(
dtype, FillConstantVisitor<DeviceContext>( dtype, FillConstantVisitor<DeviceContext>(
dev_ctx, fused_tensor, context.Attr<float>("constant"))); dev_ctx, fused_tensor, context.Attr<float>("constant"),
dtype, context));
} else if (context.Attr<bool>("persist_output")) { } else if (context.Attr<bool>("persist_output")) {
for (size_t i = 0; i < out_var_names.size(); ++i) { for (size_t i = 0; i < out_var_names.size(); ++i) {
size_t len = static_cast<size_t>(out_tensors[i]->numel()); size_t len = static_cast<size_t>(out_tensors[i]->numel());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册