Unverified commit 24b8f63e, authored by fwenguang, committed by GitHub

[MLU] fix TensorAdd for mlu (#39523)

Parent 7d53a288
@@ -35,6 +35,9 @@
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
namespace paddle {
namespace imperative {
@@ -362,6 +365,35 @@ void TensorAdd(const VarType& src, VarType* dst) {
  }
#endif
#ifdef PADDLE_WITH_MLU
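  // Gradient accumulation on MLU: add src_tensor into dst_tensor in place via CNNL.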
  if (platform::is_mlu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::DeviceContext* ctx = pool.Get(place);
    auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
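    // Allocate the destination buffer; only float and float16 gradients are handled here.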
    if (data_type == framework::DataTypeTrait<float>::DataType()) {
      dst_tensor->mutable_data<float>(place);
    } else if (data_type ==
               framework::DataTypeTrait<platform::float16>::DataType()) {
      dst_tensor->mutable_data<platform::float16>(place);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
    }
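    // cnnlAssignAdd accumulates src into dst scaled by alpha/beta; with
    // alpha = beta = 1 this is dst += src.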
    static const float alpha = 1.f;
    static const float beta = 1.f;
    operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
    operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
    PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
        dev_ctx->cnnl_handle(), static_cast<const void*>(&alpha),
        src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
        static_cast<const void*>(&beta), dst_tensor_desc.get(),
        operators::GetBasePtr(dst_tensor)));
    return;
  }
#endif
PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
...
@@ -1676,7 +1676,7 @@ def cross_entropy(input,
        if label_max >= input.shape[axis]:
            raise ValueError("label should not out of bound, but got{}".
                             format(label_max))
-        if core.is_compiled_with_npu():
+        if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
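            # NPU and MLU builds take the fused softmax_with_cross_entropy path.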
            _, _, out = _C_ops.softmax_with_cross_entropy(
                input, label, 'soft_label', soft_label, 'ignore_index',
                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
...
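Note: a minimal usage sketch (not part of the commit) of what these two changes enable together, assuming a PaddlePaddle build compiled with MLU support; the 'mlu' device string and the exact accumulation path are assumptions.

import paddle
import paddle.nn.functional as F

paddle.set_device('mlu')                           # assumption: MLU device selected this way
x = paddle.randn([4, 10])
x.stop_gradient = False                            # track gradients for x
label = paddle.randint(0, 10, [4], dtype='int64')
loss = F.cross_entropy(x, label)                   # routes through softmax_with_cross_entropy on MLU
aux = (x * x).sum()                                # second path through x
(loss + aux).backward()                            # x's two gradients are combined by TensorAdd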