未验证 提交 12260bdb 编写于 作者: Y Yuang Liu 提交者: GitHub

[cherry-pick][hybrid performance] optim npu coalesce set constant (#35105) (#35302)

上级 e69cc215
......@@ -30,8 +30,14 @@ namespace operators {
template <typename DeviceContext>
struct FillConstantVisitor {
FillConstantVisitor(const DeviceContext &dev_ctx,
framework::LoDTensor *tensor, const float value)
: dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {}
framework::LoDTensor *tensor, const float value,
framework::proto::VarType::Type dtype,
const framework::ExecutionContext &context)
: dev_ctx_(dev_ctx),
tensor_(tensor),
value_(value),
dtype_(dtype),
context_(context) {}
template <typename T>
void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
......@@ -47,7 +53,17 @@ struct FillConstantVisitor {
* = nullptr) const {
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(dev_ctx_.GetPlace())) {
FillNpuTensorWithConstant<T>(tensor_, static_cast<T>(value_));
Tensor tensor_tmp(dtype_);
tensor_tmp.mutable_data<T>({1}, context_.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value_));
const auto &runner =
NpuOpRunner("FillD", {tensor_tmp}, {*tensor_},
{{"dims", framework::vectorize(tensor_->dims())}});
auto stream =
context_.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
} else {
math::SetConstant<DeviceContext, T> set_constant;
set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
......@@ -61,6 +77,8 @@ struct FillConstantVisitor {
const DeviceContext &dev_ctx_;
framework::LoDTensor *tensor_;
float value_;
framework::proto::VarType::Type dtype_;
const framework::ExecutionContext &context_;
};
template <typename DeviceContext, typename T>
......@@ -165,7 +183,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
} else if (context.Attr<bool>("set_constant")) {
framework::VisitDataType(
dtype, FillConstantVisitor<DeviceContext>(
dev_ctx, fused_tensor, context.Attr<float>("constant")));
dev_ctx, fused_tensor, context.Attr<float>("constant"),
dtype, context));
} else if (context.Attr<bool>("persist_output")) {
for (size_t i = 0; i < out_var_names.size(); ++i) {
size_t len = static_cast<size_t>(out_tensors[i]->numel());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册