未验证 提交 b73708d2 编写于 作者: C chengduo 提交者: GitHub

add int and int64 dtype for gather_op (#14175)

test=develop
上级 62a0fe08
......@@ -102,7 +102,9 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<int>, ops::GatherOpKernel<double>);
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<double>);
ops::GatherGradientOpKernel<int64_t>);
......@@ -61,5 +61,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册