diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index 11a4928fe1c39cf75686cebd32451fdaec4d425b..1db6f6e51746282f89e9c160a0d370b72eecd605 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -22,6 +22,7 @@ namespace cub = hipcub; #endif #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/operators/p_norm_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" @@ -106,9 +107,10 @@ class PnormCUDAKernel : public framework::OpKernel { auto xdim = in_x->dims(); auto ndim = out_norm->dims(); float porder = ctx.Attr("porder"); + bool asvector = ctx.Attr("asvector"); int axis = ctx.Attr("axis"); - if (axis < 0) axis = xdim.size() + axis; std::vector reduce_axis = {axis}; + reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector); auto stream = ctx.cuda_device_context().stream(); @@ -195,7 +197,7 @@ class PnormGradCUDAKernel : public framework::OpKernel { auto xdim = in_x->dims(); float porder = ctx.Attr("porder"); int axis = ctx.Attr("axis"); - bool reduce_all = ((axis < 0) || (in_norm->numel() == 1)); + bool reduce_all = (in_norm->numel() == 1); if (axis < 0) axis = xdim.size() + axis; const std::vector dims = {axis};