From 99029dc9349baf7e358e5fcf7006b857e2e70cb7 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 1 Apr 2022 21:06:01 +0800 Subject: [PATCH] update (#41245) --- paddle/phi/kernels/cpu/accuracy_kernel.cc | 5 ++++- paddle/phi/kernels/gpu/accuracy_kernel.cu | 5 ++++- paddle/phi/kernels/impl/trace_grad_kernel_impl.h | 2 +- paddle/phi/kernels/norm_grad_kernel.h | 2 +- paddle/phi/kernels/trace_grad_kernel.h | 2 +- paddle/phi/ops/compat/trace_sig.cc | 2 +- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/cpu/accuracy_kernel.cc b/paddle/phi/kernels/cpu/accuracy_kernel.cc index c57ec69b73a..6ff8a1f7558 100644 --- a/paddle/phi/kernels/cpu/accuracy_kernel.cc +++ b/paddle/phi/kernels/cpu/accuracy_kernel.cc @@ -69,4 +69,7 @@ void AccuracyRawKernel(const Context& dev_ctx, // TODO(add supported dtype.) PD_REGISTER_KERNEL( - accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {} + accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->InputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu index f08fb74e54d..5eecfce0932 100644 --- a/paddle/phi/kernels/gpu/accuracy_kernel.cu +++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu @@ -114,4 +114,7 @@ PD_REGISTER_KERNEL(accuracy, phi::AccuracyRawKernel, phi::dtype::float16, float, - double) {} + double) { + kernel->InputAt(1).SetDataType(phi::DataType::INT64); + kernel->InputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h index b0878d77946..90a2327ef3e 100644 --- a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h @@ -82,8 +82,8 @@ struct TraceGradFunctor { template void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, int offset, int axis1, int axis2, diff --git a/paddle/phi/kernels/norm_grad_kernel.h b/paddle/phi/kernels/norm_grad_kernel.h index 55714b8a4a0..a67e757ba51 100644 --- a/paddle/phi/kernels/norm_grad_kernel.h +++ b/paddle/phi/kernels/norm_grad_kernel.h @@ -21,7 +21,7 @@ namespace phi { template void NormGradKernel(const Context& ctx, const DenseTensor& x, - const DenseTensor& out, + const DenseTensor& norm, const DenseTensor& out_grad, int axis, float epsilon, diff --git a/paddle/phi/kernels/trace_grad_kernel.h b/paddle/phi/kernels/trace_grad_kernel.h index ef17986e755..4884e53b4ef 100644 --- a/paddle/phi/kernels/trace_grad_kernel.h +++ b/paddle/phi/kernels/trace_grad_kernel.h @@ -20,8 +20,8 @@ namespace phi { template void TraceGradKernel(const Context& ctx, - const DenseTensor& out_grad, const DenseTensor& x, + const DenseTensor& out_grad, int offset, int axis1, int axis2, diff --git a/paddle/phi/ops/compat/trace_sig.cc b/paddle/phi/ops/compat/trace_sig.cc index 44fd53db98a..c3f5d6d2875 100644 --- a/paddle/phi/ops/compat/trace_sig.cc +++ b/paddle/phi/ops/compat/trace_sig.cc @@ -23,7 +23,7 @@ KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("trace_grad", - {GradVarName("Out"), "Input"}, + {"Input", GradVarName("Out")}, {"offset", "axis1", "axis2"}, {GradVarName("Input")}); } -- GitLab