From 24b8f63efb61f0ddb0d40fea8d1f474d600b9ce1 Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Wed, 16 Feb 2022 17:41:03 +0800
Subject: [PATCH] [MLU] fix TensorAdd for mlu (#39523)

---
 .../fluid/imperative/gradient_accumulator.cc | 32 +++++++++++++++++++
 python/paddle/nn/functional/loss.py          |  2 +-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index dc8b3982ba..17ab1f1f7c 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -35,6 +35,9 @@
 #ifdef PADDLE_WITH_ASCEND_CL
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #endif
+#ifdef PADDLE_WITH_MLU
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+#endif
 
 namespace paddle {
 namespace imperative {
@@ -362,6 +365,35 @@ void TensorAdd(const VarType& src, VarType* dst) {
   }
 #endif
 
+#ifdef PADDLE_WITH_MLU
+  if (platform::is_mlu_place(place)) {
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    platform::DeviceContext* ctx = pool.Get(place);
+    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
+    if (data_type == framework::DataTypeTrait<float>::DataType()) {
+      dst_tensor->mutable_data<float>(place);
+    } else if (data_type ==
+               framework::DataTypeTrait<platform::float16>::DataType()) {
+      dst_tensor->mutable_data<platform::float16>(place);
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Gradient accumulation of data type (%s) on place (%s) is not "
+          "supported in imperative mode",
+          framework::DataTypeToString(data_type), place));
+    }
+    static const float alpha = 1.f;
+    static const float beta = 1.f;
+    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
+    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
+    PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
+        dev_ctx->cnnl_handle(), static_cast<const void*>(&alpha),
+        src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
+        static_cast<const void*>(&beta), dst_tensor_desc.get(),
+        operators::GetBasePtr(dst_tensor)));
+    return;
+  }
+#endif
+
   PADDLE_TENSOR_ADD(float);
 
 #ifndef PADDLE_WITH_XPU
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 711fd1e94c..8dc0403259 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1676,7 +1676,7 @@ def cross_entropy(input,
             if label_max >= input.shape[axis]:
                 raise ValueError("label should not out of bound, but got{}".
                                  format(label_max))
-        if core.is_compiled_with_npu():
+        if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
            _, _, out = _C_ops.softmax_with_cross_entropy(
                input, label, 'soft_label', soft_label, 'ignore_index',
                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
--
GitLab
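
For context on what the new MLU branch delegates to: cnnlAssignAdd performs an
in-place scaled accumulation, roughly dst = alpha * src + beta * dst, so with
alpha = beta = 1.f the call reduces to plain gradient accumulation, dst += src.
Below is a minimal NumPy sketch of those semantics (an illustration only; the
helper name assign_add is ours, not Paddle's or CNNL's, and the real kernel
runs on the MLU device):

import numpy as np

def assign_add(src, dst, alpha=1.0, beta=1.0):
    # Sketch of the semantics delegated to cnnlAssignAdd in the patch:
    # dst = alpha * src + beta * dst, computed in place on dst.
    dst[...] = alpha * src + beta * dst
    return dst

grad_sum = np.zeros((2, 3), dtype=np.float32)
assign_add(np.ones((2, 3), dtype=np.float32), grad_sum)  # first gradient
assign_add(np.ones((2, 3), dtype=np.float32), grad_sum)  # accumulate second
assert (grad_sum == 2.0).all()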
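
On the Python side, the loss.py change makes MLU builds take the same branch
already used on NPU: hard-label cross entropy is routed through the fused
softmax_with_cross_entropy op. A usage sketch, assuming a Paddle wheel
compiled with MLU support (core.is_compiled_with_mlu() returning True):

import paddle
import paddle.nn.functional as F

# On an MLU build this call now dispatches to the fused
# softmax_with_cross_entropy kernel; other builds are unaffected.
logits = paddle.rand([4, 10])
labels = paddle.randint(0, 10, [4], dtype='int64')
loss = F.cross_entropy(logits, labels)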