未验证 提交 9e42fe9a 编写于 作者: N Noel 提交者: GitHub

[pnorm] fix bug in pnorm (#38215)

上级 59be8e0e
...@@ -22,6 +22,7 @@ namespace cub = hipcub; ...@@ -22,6 +22,7 @@ namespace cub = hipcub;
#endif #endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/operators/p_norm_op.h" #include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
...@@ -106,9 +107,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> { ...@@ -106,9 +107,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
auto xdim = in_x->dims(); auto xdim = in_x->dims();
auto ndim = out_norm->dims(); auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder"); float porder = ctx.Attr<float>("porder");
bool asvector = ctx.Attr<bool>("asvector");
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
std::vector<int> reduce_axis = {axis}; std::vector<int> reduce_axis = {axis};
reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector);
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
...@@ -195,7 +197,7 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> { ...@@ -195,7 +197,7 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
auto xdim = in_x->dims(); auto xdim = in_x->dims();
float porder = ctx.Attr<float>("porder"); float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
bool reduce_all = ((axis < 0) || (in_norm->numel() == 1)); bool reduce_all = (in_norm->numel() == 1);
if (axis < 0) axis = xdim.size() + axis; if (axis < 0) axis = xdim.size() + axis;
const std::vector<int> dims = {axis}; const std::vector<int> dims = {axis};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册