未验证 提交 22f06e64 编写于 作者: Z Zhong Hui 提交者: GitHub

fix compile problem on windows and some invalid argument check, cherry pick...

fix compile problem on windows and some invalid argument check, cherry pick from #23831, test=release/2.0 (#23870)

Fix the compile problem on windows, cherry-pick from develop branch 
上级 35b50b7f
......@@ -64,18 +64,19 @@ class PnormOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm");
auto porder = ctx->Attrs().Get<float>("porder");
PADDLE_ENFORCE_NE(porder, 0,
platform::errors::InvalidArgument(
"The input porder of p_norm is not support for "
"porder == 0, INFINITY, -INFINITY now."));
PADDLE_ENFORCE_NE(porder, INFINITY,
platform::errors::InvalidArgument(
platform::errors::Unimplemented(
"The input porder of p_norm is not support for "
"porder == 0, INFINITY, -INFINITY now."));
PADDLE_ENFORCE_NE(porder, -INFINITY,
platform::errors::InvalidArgument(
platform::errors::Unimplemented(
"The input porder of p_norm is not support for "
"porder == 0, INFINITY, -INFINITY now."));
PADDLE_ENFORCE_GT(porder, 0.0f,
platform::errors::InvalidArgument(
"The input porder of p_norm is not support for "
"porder <= 0, But received porder=%f.",
porder));
auto xdim = ctx->GetInputDim("X");
int axis = ctx->Attrs().Get<int>("axis");
bool keepdim = ctx->Attrs().Get<bool>("keepdim");
......
......@@ -44,6 +44,9 @@ __global__ void Pnorm(const T* x, const int pre,
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
auto porder_t = static_cast<T>(porder);
auto porder_inv = static_cast<T>(1.0 / porder);
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
......@@ -51,12 +54,12 @@ __global__ void Pnorm(const T* x, const int pre,
__shared__ T norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const T x_ij = x[base + j * post];
sum += inline_pow(inline_abs(x_ij), porder);
sum += inline_pow(inline_abs(x_ij), porder_t);
}
T reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) {
norm = inline_pow(reduce_result, 1.0f / porder);
norm = inline_pow(reduce_result, porder_inv);
out_norm[i] = norm;
}
__syncthreads();
......@@ -100,6 +103,7 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int num = pre * post;
auto porder_grad = static_cast<T>(porder - 1.0f);
for (int i = blockIdx.x; i < num; i += gridDim.x) {
T sum = 0.0;
__shared__ T row_sum;
......@@ -128,8 +132,8 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
int index = base + j * post;
const T x_ij = inline_abs(x[index]);
const T dy_ij = y_grad[index];
x_grad[index] = inline_pow(x_ij, porder - 1.0f) /
(inline_pow(pnorm_i, porder - 1.0f) + eps) * yout_i *
x_grad[index] = inline_pow(x_ij, porder_grad) /
(inline_pow(pnorm_i, porder_grad) + eps) * yout_i *
inline_sign(x[index]);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册