未验证 提交 648ec795 编写于 作者: P PuQing 提交者: GitHub

[PHI]fix momentum dtype infer (#51353)

* fix momentum dtype infer

* fix momentum datatype

* fix on cpu

* add momentum
上级 e18f5339
...@@ -2172,10 +2172,15 @@ void MomentumInferMeta(const MetaTensor& param, ...@@ -2172,10 +2172,15 @@ void MomentumInferMeta(const MetaTensor& param,
auto param_dim = param.dims(); auto param_dim = param.dims();
param_out->set_dims(param_dim); param_out->set_dims(param_dim);
auto MPType = (param.dtype() == phi::DataType::FLOAT16 ||
param.dtype() == phi::DataType::BFLOAT16)
? phi::DataType::FLOAT32
: param.dtype();
velocity_out->set_dims(param_dim); velocity_out->set_dims(param_dim);
velocity_out->set_dtype(MPType);
if (master_param_out) { if (master_param_out) {
master_param_out->set_dims(param_dim); master_param_out->set_dims(param_dim);
master_param_out->set_dtype(MPType);
} }
} }
......
...@@ -24,7 +24,10 @@ PD_REGISTER_KERNEL(momentum, ...@@ -24,7 +24,10 @@ PD_REGISTER_KERNEL(momentum,
phi::MomentumDenseKernel, phi::MomentumDenseKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad,
GPU, GPU,
...@@ -32,4 +35,7 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, ...@@ -32,4 +35,7 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad,
phi::MomentumSparseKernel, phi::MomentumSparseKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -69,4 +69,7 @@ PD_REGISTER_KERNEL(momentum, ...@@ -69,4 +69,7 @@ PD_REGISTER_KERNEL(momentum,
ALL_LAYOUT, ALL_LAYOUT,
phi::MomentumDenseKernel, phi::MomentumDenseKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册