未验证 提交 6cbeafb6 编写于 作者: Z Zhong Hui 提交者: GitHub

add zero norm, inf norm support for p_norm op (#26364)

* add zero norm, inf norm support for p_norm op

* fix the invalid argument check, fix the dtype problem in test case.
上级 6cd67a81
......@@ -25,34 +25,49 @@ class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "(Tensor) A tensor of rank >= axis.");
AddAttr<float>("porder",
"The porder is the p order vector norm to calculate.")
"(float, default 2) The porder is the p order vector norm "
"to calculate. Available for porder=0, inf, -inf and any "
"real number.")
.SetDefault(2.0f);
AddAttr<int>("axis",
"The axis on which to apply normalization. If axis < 0, "
"The axis on which to apply norm operation. If axis < 0, "
"the dimension to pnorm is rank(X) + axis. -1 is "
"the last dimension.")
.SetDefault(-1);
AddAttr<float>("epsilon",
"(float, default 1e-10) The epsilon value is used "
"(float, default 1e-12) The epsilon value is used "
"to avoid division by zero.")
.SetDefault(1.0e-12f);
AddAttr<bool>(
"keepdim",
"(bool, default false) Whether to keep the dimensions as the input")
"(bool, default false) Whether to keep the dimensions as the input.")
.SetDefault(false);
AddOutput(
"Out",
"(Tensor) Output tensor for the `(sum(x.pow(p)) + epsion).pow(1/p)`");
AddOutput("Out", "(Tensor) Output result tensor of p-norm");
AddComment(R"DOC(
Pnorm Operator.
Given a tensor X, compute Lp-norm of X.
Given a tensor, apply 2-normalization along the provided axis.
When p = 0, defining $0^0 = 0$, the zero-norm of X is simply the number of non-zero elements of X.
$$
||X||_{0} = \lim_{p \rightarrow 0} \sum_i |x_i|^p
$$
When p = inf, the inf-norm of X is the maximum element of X.
$$
||X||_\infty = \max_i |x_i|
$$
When p = -inf, the negative-inf-norm of X is the minimum element of X.
$$
||X||_{-\infty} = \min_i |x_i|
$$
Otherwise, the p-norm of X follows the formula,
$$
pnorm = \(\sum_i {abs\(x_i\)^p} \)^{1/p}
||X||_{p} = (\sum_i |x_i|^p)^{1/p}
$$
where, $\sum_i $ is calculated along the `axis` dimension.
where, $\sum_i{x_i^p}$ is calculated along the `axis` dimension.
)DOC");
}
};
......@@ -63,31 +78,33 @@ class PnormOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm");
auto porder = ctx->Attrs().Get<float>("porder");
PADDLE_ENFORCE_NE(porder, INFINITY,
platform::errors::Unimplemented(
"The input porder of p_norm is not support for "
"porder == 0, INFINITY, -INFINITY now."));
PADDLE_ENFORCE_NE(porder, -INFINITY,
platform::errors::Unimplemented(
"The input porder of p_norm is not support for "
"porder == 0, INFINITY, -INFINITY now."));
PADDLE_ENFORCE_GT(porder, 0.0f,
platform::errors::InvalidArgument(
"The input porder of p_norm is not support for "
"porder <= 0, But received porder=%f.",
porder));
auto xdim = ctx->GetInputDim("X");
auto x_dim = ctx->GetInputDim("X");
auto x_rank = x_dim.size();
int axis = ctx->Attrs().Get<int>("axis");
bool keepdim = ctx->Attrs().Get<bool>("keepdim");
if (axis < 0) axis = xdim.size() + axis;
PADDLE_ENFORCE_GE(axis, -x_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis, x_rank, x_dim));
PADDLE_ENFORCE_LT(axis, x_rank,
platform::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis, x_rank, x_dim));
if (axis < 0) axis = x_dim.size() + axis;
std::vector<int> reduce_dims;
for (int i = 0; i < xdim.size(); ++i) {
if (i != axis) reduce_dims.emplace_back(xdim[i]);
for (int i = 0; i < x_dim.size(); ++i) {
if (i != axis) reduce_dims.emplace_back(x_dim[i]);
}
xdim[axis] = 1;
x_dim[axis] = 1;
if (keepdim) {
ctx->SetOutputDim("Out", xdim);
ctx->SetOutputDim("Out", x_dim);
} else {
ctx->SetOutputDim("Out", framework::make_ddim(reduce_dims));
}
......
......@@ -49,20 +49,70 @@ __global__ void Pnorm(const T* x, const int pre,
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T sum = 0.0;
__shared__ T norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const T x_ij = x[base + j * post];
sum += inline_pow(inline_abs(x_ij), porder_t);
}
T reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) out_norm[i] = inline_pow(reduce_result, porder_inv);
}
}
if (threadIdx.x == 0) {
norm = inline_pow(reduce_result, porder_inv);
out_norm[i] = norm;
template <typename T, int BlockDim>
__global__ void ZeorNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T sum = 0.0;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const T x_ij = x[base + j * post];
sum += static_cast<T>(x_ij != 0);
}
__syncthreads();
T reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) out_norm[i] = reduce_result;
}
}
template <typename T, int BlockDim>
__global__ void InfNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T cur_max = inline_abs(x[base]);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
T x_ij_abs = inline_abs(x[base + j * post]);
if (cur_max < x_ij_abs) cur_max = x_ij_abs;
}
T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
if (threadIdx.x == 0) out_norm[i] = reduce_result;
}
}
template <typename T, int BlockDim>
__global__ void NegInfNorm(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T cur_min = inline_abs(x[base]);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
T x_ij_abs = inline_abs(x[base + j * post]);
if (cur_min > x_ij_abs) cur_min = x_ij_abs;
}
T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
if (threadIdx.x == 0) out_norm[i] = reduce_result;
}
}
......@@ -89,8 +139,19 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
porder, norm);
if (porder == 0) {
ZeorNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
norm);
} else if (porder == INFINITY) {
InfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
norm);
} else if (porder == -INFINITY) {
NegInfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n,
post, norm);
} else {
Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
porder, norm);
}
}
};
......@@ -112,7 +173,6 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
pnorm_i = x_norm[i];
yout_i = y_grad[i];
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
......@@ -125,6 +185,33 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
}
}
template <typename T, int BlockDim>
__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
const int pre, const int axis_n, const int post,
T* x_grad) {
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
__shared__ T pnorm_i;
__shared__ T yout_i;
auto base = (i / post) * post * axis_n + (i % post);
if (threadIdx.x == 0) {
pnorm_i = x_norm[i];
yout_i = y_grad[i];
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const T x_ij = inline_abs(x[index]);
if (x_ij == pnorm_i) {
x_grad[index] = inline_sign(x[index]) * yout_i;
} else {
x_grad[index] = static_cast<T>(0);
}
}
}
}
template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -153,8 +240,17 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
if (porder == 0) {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, out_dx, static_cast<T>(0));
} else if (porder == INFINITY || porder == -INFINITY) {
InfNormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, norm_dy, pre, n, post, dx);
} else {
PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
}
}
};
......
......@@ -58,10 +58,20 @@ class PnormKernel : public framework::OpKernel<T> {
auto x = x_e.reshape(shape);
auto norm = norm_e.reshape(norm_shape);
// p=0 means number of non-zero elements of (x)
// p=inf means the maximum of |x|
// p=-inf means the minimum of |x|
// otherwise, Lp-norm = pow(sum(pow(|x|, p)), 1/p)
Eigen::DSizes<int, 1> rdim(1);
auto xp = (x.abs()).pow(porder);
auto sum = xp.sum(rdim);
norm.device(*place) = sum.pow(1.0f / porder);
if (porder == 0) {
norm.device(*place) = (x != x.constant(0)).template cast<T>().sum(rdim);
} else if (porder == INFINITY) {
norm.device(*place) = x.abs().maximum(rdim);
} else if (porder == -INFINITY) {
norm.device(*place) = x.abs().minimum(rdim);
} else {
norm.device(*place) = x.abs().pow(porder).sum(rdim).pow(1.0f / porder);
}
}
};
......@@ -102,10 +112,20 @@ class PnormGradKernel : public framework::OpKernel<T> {
Eigen::DSizes<int, 1> rdim(1);
Eigen::DSizes<int, 3> bcast(1, n, 1);
dx.device(*place) = (x.abs()).pow(porder - 1.0f);
dx.device(*place) =
dx / ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps));
dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign();
if (porder == 0) {
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, out_dx, static_cast<T>(0));
} else if (porder == INFINITY || porder == -INFINITY) {
dx.device(*place) =
(x.abs() == norm.broadcast(bcast)).template cast<T>() * x.sign() *
norm_dy.broadcast(bcast);
} else {
dx.device(*place) =
(x.abs()).pow(porder - 1.0f) /
((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps));
dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign();
}
}
};
} // namespace operators
......
......@@ -23,16 +23,16 @@ import paddle.fluid as fluid
def p_norm(x, axis, porder, keepdims=False):
if axis is None: axis = -1
xp = np.power(np.abs(x), porder)
s = np.sum(xp, axis=axis, keepdims=keepdims)
r = np.power(s, 1.0 / porder)
r = np.linalg.norm(
x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
return r
def frobenius_norm(x, axis=None, keepdims=False):
if isinstance(axis, list): axis = tuple(axis)
if axis is None: axis = (-2, -1)
r = np.linalg.norm(x, ord='fro', axis=axis, keepdims=keepdims)
r = np.linalg.norm(
x, ord='fro', axis=axis, keepdims=keepdims).astype(x.dtype)
return r
......@@ -89,6 +89,7 @@ class TestPnormOp(OpTest):
'porder': float(self.porder)
}
self.outputs = {'Out': norm}
self.gradient = self.calc_gradient()
def test_check_output(self):
self.check_output()
......@@ -104,6 +105,34 @@ class TestPnormOp(OpTest):
self.keepdim = False
self.dtype = "float64"
def calc_gradient(self):
self.attrs = {
'epsilon': self.epsilon,
'axis': self.axis,
'keepdim': self.keepdim,
'porder': float(self.porder)
}
x = self.inputs["X"]
porder = self.attrs["porder"]
axis = self.attrs["axis"]
if porder == 0:
grad = np.zeros(x.shape).astype(x.dtype)
elif porder in [float("inf"), float("-inf")]:
norm = p_norm(x, axis=axis, porder=porder, keepdims=True)
x_abs = np.abs(x)
grad = np.sign(x)
grad[x_abs != norm] = 0.0
else:
norm = p_norm(x, axis=axis, porder=porder, keepdims=True)
grad = np.power(norm, 1 - porder) * np.power(
np.abs(x), porder - 1) * np.sign(x)
numel = 1
for s in x.shape:
numel *= s
numel /= x.shape[axis]
return [grad.astype(x.dtype) * 1 / numel]
class TestPnormOp2(TestPnormOp):
def init_test_case(self):
......@@ -118,6 +147,45 @@ class TestPnormOp2(TestPnormOp):
self.check_grad(['X'], 'Out')
class TestPnormOp3(TestPnormOp):
def init_test_case(self):
self.shape = [3, 20, 3]
self.axis = 2
self.epsilon = 1e-12
self.porder = np.inf
self.keepdim = True
self.dtype = "float32"
def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
class TestPnormOp4(TestPnormOp):
def init_test_case(self):
self.shape = [3, 20, 3]
self.axis = 2
self.epsilon = 1e-12
self.porder = -np.inf
self.keepdim = True
self.dtype = "float32"
def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
class TestPnormOp5(TestPnormOp):
def init_test_case(self):
self.shape = [3, 20, 3]
self.axis = 2
self.epsilon = 1e-12
self.porder = 0
self.keepdim = True
self.dtype = "float32"
def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
def run_out(self, p, axis, shape_x, shape_y, dtype):
with fluid.program_guard(fluid.Program()):
data1 = fluid.data(name="X", shape=shape_x, dtype=dtype)
......@@ -170,6 +238,9 @@ class API_NormTest(unittest.TestCase):
run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64")
run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=np.inf, axis=1, shape_x=[3, 4], dtype="float32")
run_pnorm(self, p=-np.inf, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64")
def test_name(self):
with fluid.program_guard(fluid.Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册