未验证 提交 99029dc9 编写于 作者: H hong 提交者: GitHub

update (#41245)

上级 6c285c37
......@@ -69,4 +69,7 @@ void AccuracyRawKernel(const Context& dev_ctx,
// TODO(add supported dtype.)
PD_REGISTER_KERNEL(
accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {}
accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
}
......@@ -114,4 +114,7 @@ PD_REGISTER_KERNEL(accuracy,
phi::AccuracyRawKernel,
phi::dtype::float16,
float,
double) {}
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
}
......@@ -82,8 +82,8 @@ struct TraceGradFunctor {
template <typename T, typename Context>
void TraceGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
int axis1,
int axis2,
......
......@@ -21,7 +21,7 @@ namespace phi {
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& norm,
const DenseTensor& out_grad,
int axis,
float epsilon,
......
......@@ -20,8 +20,8 @@ namespace phi {
template <typename T, typename Context>
void TraceGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
int axis1,
int axis2,
......
......@@ -23,7 +23,7 @@ KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("trace_grad",
{GradVarName("Out"), "Input"},
{"Input", GradVarName("Out")},
{"offset", "axis1", "axis2"},
{GradVarName("Input")});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册