未验证 提交 c0f84b8f 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

Add output defs for sgd kernel (#51332)

* Add output defs for sgd kernel

* add datatype infer for sgd

* add infer logic
上级 4474e085
...@@ -96,7 +96,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -96,7 +96,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"select", "select",
"send_recv", "send_recv",
"send_ue_recv", "send_ue_recv",
"sgd",
"svd", "svd",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
......
...@@ -2484,6 +2484,15 @@ void SgdInferMeta(const MetaTensor& param, ...@@ -2484,6 +2484,15 @@ void SgdInferMeta(const MetaTensor& param,
param_out->set_dims(param.dims()); param_out->set_dims(param.dims());
param_out->set_dtype(param.dtype()); param_out->set_dtype(param.dtype());
if (multi_precision) {
master_param_out->set_dims(master_param.dims());
if (DataType::FLOAT16 == master_param.dtype() ||
DataType::BFLOAT16 == master_param.dtype()) {
master_param_out->set_dtype(DataType::FLOAT32);
} else {
master_param_out->set_dtype(master_param.dtype());
}
}
} }
void SendUERecvInferMeta(const MetaTensor& x, void SendUERecvInferMeta(const MetaTensor& x,
......
...@@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd, ...@@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd,
phi::SGDDenseKernel, phi::SGDDenseKernel,
phi::dtype::float16, phi::dtype::float16,
float, float,
double) {} double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
GPU, GPU,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册