diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc index c57ec69b73a230df48411f4074935e2bb4bce461..6ff8a1f7558973965f51f42bdd0984757f285b47 100644 --- a/paddle/phi/kernels/cpu/accuracy_kernel.cc +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -69,4 +69,7 @@ void AccuracyRawKernel(const Context& dev_ctx, // TODO(add supported dtype.) PD_REGISTER_KERNEL( - accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {} + accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->InputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu index f08fb74e54d8c86f7b54d21c762e30cebedfe967..5eecfce09324857b53cfa462d8a65b60c27efb7d 100644 --- a/paddle/phi/kernels/gpu/accuracy_kernel.cu +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -114,4 +114,7 @@ PD_REGISTER_KERNEL(accuracy, phi::AccuracyRawKernel, phi::dtype::float16, float, - double) {} + double) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->InputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h index b0878d779462a9c351caa038af2ac017bbf4a14f..90a2327ef3e204fe0cda3cc281407926e0a61ba3 100644 --- a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h @@ -82,8 +82,8 @@ struct TraceGradFunctor { template void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, int offset, int axis1, int axis2, diff --git a/paddle/phi/kernels/norm_grad_kernel.h b/paddle/phi/kernels/norm_grad_kernel.h index 55714b8a4a091f6d64cbb9a03eb9043d4c2dbf22..a67e757ba510f03f211cf383cc68b38e3099ae3c 100644 --- a/paddle/phi/kernels/norm_grad_kernel.h +++ b/paddle/phi/kernels/norm_grad_kernel.h @@ -21,7 +21,7 @@ namespace phi { template void NormGradKernel(const Context& ctx, const DenseTensor& x, - const DenseTensor& out, + const DenseTensor& norm, const DenseTensor& out_grad, int axis, float epsilon, diff --git a/paddle/phi/kernels/trace_grad_kernel.h b/paddle/phi/kernels/trace_grad_kernel.h index ef17986e75593cbff21b11d2371488d77bd56205..4884e53b4efe50e1cb805ea616ebb332c976e0a7 100644 --- a/paddle/phi/kernels/trace_grad_kernel.h +++ b/paddle/phi/kernels/trace_grad_kernel.h @@ -20,8 +20,8 @@ namespace phi { template void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, int offset, int axis1, int axis2, diff --git a/paddle/phi/ops/compat/trace_sig.cc b/paddle/phi/ops/compat/trace_sig.cc index 44fd53db98a3cf12098a676d1a2abf0bc629bb70..c3f5d6d287551e0c8732f3c6a7fca9cfcf3276bb 100644 --- a/paddle/phi/ops/compat/trace_sig.cc +++ b/paddle/phi/ops/compat/trace_sig.cc @@ -23,7 +23,7 @@ KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("trace_grad", - {GradVarName("Out"), "Input"}, + {"Input", GradVarName("Out")}, {"offset", "axis1", "axis2"}, {GradVarName("Input")}); }