未验证 提交 8645591d 编写于 作者: G Guanghua Yu 提交者: GitHub

support fp64 in huber_loss cuda kernel (#26583)

上级 90e6819c
......@@ -16,7 +16,9 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
huber_loss,
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, float>);
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
huber_loss_grad,
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册