未验证 提交 9c1eb98a 编写于 作者: H houj04 提交者: GitHub

[XPU] c_sync_calc_stream support more types (#53389)

上级 18968e7e
...@@ -17,5 +17,12 @@ limitations under the License. */ ...@@ -17,5 +17,12 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL( PD_REGISTER_STRUCT_KERNEL(c_sync_calc_stream,
c_sync_calc_stream, XPU, ALL_LAYOUT, ops::CSyncCalcStreamKernel, float) {} XPU,
ALL_LAYOUT,
ops::CSyncCalcStreamKernel,
float,
double,
int,
int64_t,
plat::float16) {}
...@@ -115,7 +115,12 @@ XPUOpMap& get_kl2_ops() { ...@@ -115,7 +115,12 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT16, XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32, phi::DataType::FLOAT32,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_sync_calc_stream",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
{"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})},
{"cast", {"cast",
XPUKernelSet({phi::DataType::FLOAT32, XPUKernelSet({phi::DataType::FLOAT32,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册