Unverified commit 99029dc9 authored by hong, committed by GitHub

update (#41245)

Parent 6c285c37
@@ -69,4 +69,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
 // TODO(add supported dtype.)
 PD_REGISTER_KERNEL(
-    accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {}
+    accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {
+  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
+  kernel->InputAt(2).SetDataType(phi::DataType::INT64);
+}
@@ -114,4 +114,7 @@ PD_REGISTER_KERNEL(accuracy,
                    phi::AccuracyRawKernel,
                    phi::dtype::float16,
                    float,
-                   double) {}
+                   double) {
+  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
+  kernel->InputAt(2).SetDataType(phi::DataType::INT64);
+}
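
Both accuracy registrations above replace the empty customization block `{}` with one that pins the dtype of inputs 1 and 2 to INT64. A minimal sketch of that registration pattern follows; the kernel name and include paths are assumptions for illustration, not taken from this commit:

// Hypothetical sketch, not part of this commit: overriding the expected dtype
// of individual inputs inside a PD_REGISTER_KERNEL customization block.
#include "paddle/phi/core/dense_tensor.h"    // assumed include paths
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// A float/double kernel whose second and third tensor inputs carry integer
// indices/labels and therefore must stay INT64 regardless of T.
template <typename T, typename Context>
void MyAccuracyLikeKernel(const Context& dev_ctx,
                          const DenseTensor& pred,
                          const DenseTensor& indices,
                          const DenseTensor& label,
                          DenseTensor* out) {
  // body elided for the sketch
}

}  // namespace phi

PD_REGISTER_KERNEL(my_accuracy_like,
                   CPU,
                   ALL_LAYOUT,
                   phi::MyAccuracyLikeKernel,
                   float,
                   double) {
  // Declare inputs 1 and 2 as INT64 instead of the registered float/double
  // dtype, mirroring what the accuracy kernels above do for Indices and Label.
  kernel->InputAt(1).SetDataType(phi::DataType::INT64);
  kernel->InputAt(2).SetDataType(phi::DataType::INT64);
}
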
@@ -82,8 +82,8 @@ struct TraceGradFunctor {
 template <typename T, typename Context>
 void TraceGradKernel(const Context& ctx,
-                     const DenseTensor& out_grad,
                      const DenseTensor& x,
+                     const DenseTensor& out_grad,
                      int offset,
                      int axis1,
                      int axis2,
...
@@ -21,7 +21,7 @@ namespace phi {
 template <typename T, typename Context>
 void NormGradKernel(const Context& ctx,
                     const DenseTensor& x,
-                    const DenseTensor& out,
+                    const DenseTensor& norm,
                     const DenseTensor& out_grad,
                     int axis,
                     float epsilon,
...
@@ -20,8 +20,8 @@ namespace phi {
 template <typename T, typename Context>
 void TraceGradKernel(const Context& ctx,
-                     const DenseTensor& out_grad,
                      const DenseTensor& x,
+                     const DenseTensor& out_grad,
                      int offset,
                      int axis1,
                      int axis2,
...
@@ -23,7 +23,7 @@ KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) {
 KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("trace_grad",
-                         {GradVarName("Out"), "Input"},
+                         {"Input", GradVarName("Out")},
                          {"offset", "axis1", "axis2"},
                          {GradVarName("Input")});
 }
...
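
The last three hunks are one coordinated change: TraceGradKernel now takes x before out_grad in both the implementation and the header, and the trace_grad argument mapping lists the op inputs in that same order. A schematic sketch of how the mapping order lines up with the kernel's tensor parameters; the my_trace_grad names are illustrative assumptions, not from this commit:

// Hypothetical sketch, not part of this commit: the inputs listed in a
// KernelSignature are bound to the phi kernel's tensor parameters in order.

// Kernel side: tensor parameters in the order (x, out_grad).
template <typename T, typename Context>
void MyTraceGradKernel(const Context& ctx,
                       const DenseTensor& x,         // receives "Input"
                       const DenseTensor& out_grad,  // receives GradVarName("Out")
                       int offset,
                       int axis1,
                       int axis2,
                       DenseTensor* in_grad);

// Mapping side: op inputs, then attributes, then outputs.
KernelSignature MyTraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("my_trace_grad",
                         {"Input", GradVarName("Out")},  // -> (x, out_grad)
                         {"offset", "axis1", "axis2"},
                         {GradVarName("Input")});
}
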