未验证 提交 79095918 编写于 作者: T taixiurong 提交者: GitHub

fix build in xpu (#37699)

上级 feda7c1d
...@@ -193,13 +193,14 @@ void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) { ...@@ -193,13 +193,14 @@ void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support NPU here // TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float // NOTE(phlrain): xpu only support float
#ifndef PADDLE_WITH_XPU
PADDLE_TENSOR_ADD(double); PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated, // NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future // support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>); PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>); PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD #undef PADDLE_TENSOR_ADD
if (data_type == paddle::framework::proto::VarType::FP16) { if (data_type == paddle::framework::proto::VarType::FP16) {
...@@ -268,13 +269,14 @@ void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) { ...@@ -268,13 +269,14 @@ void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support NPU here // TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float // NOTE(phlrain): xpu only support float
#ifndef PADDLE_WITH_XPU
PADDLE_TENSOR_ADD(double); PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated, // NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future // support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>); PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>); PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD #undef PADDLE_TENSOR_ADD
if (data_type == paddle::framework::proto::VarType::FP16) { if (data_type == paddle::framework::proto::VarType::FP16) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册