未验证 提交 9e42fe9a 编写于 作者: N Noel 提交者: GitHub

[pnorm] fix bug in pnorm (#38215)

上级 59be8e0e
......@@ -22,6 +22,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
......@@ -106,9 +107,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder");
bool asvector = ctx.Attr<bool>("asvector");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
std::vector<int> reduce_axis = {axis};
reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector);
auto stream = ctx.cuda_device_context().stream();
......@@ -195,7 +197,7 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
auto xdim = in_x->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool reduce_all = ((axis < 0) || (in_norm->numel() == 1));
bool reduce_all = (in_norm->numel() == 1);
if (axis < 0) axis = xdim.size() + axis;
const std::vector<int> dims = {axis};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册