提交 63d322f0 编写于 作者: D dengkaipeng 提交者: ceci3

fix attr dim calc. test=develop

上级 ca1502c7
...@@ -33,19 +33,34 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -33,19 +33,34 @@ class SpectralNormOp : public framework::OperatorWithKernel {
"Output(Out) of SpectralNormOp should not be null."); "Output(Out) of SpectralNormOp should not be null.");
auto dim_weight = ctx->GetInputDim("Weight"); auto dim_weight = ctx->GetInputDim("Weight");
auto weight_dimsize = dim_weight.size(); auto rank_weight = dim_weight.size();
PADDLE_ENFORCE(weight_dimsize >= 2 && weight_dimsize <= 5, PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
"The size of dims of Input(Weights) can only be 2, 3," "The rank of Input(Weights) can only be 2, 3,"
"4, 5 for fc, conv1d, conv2d, conv3d layers."); "4, 5 for fc, conv1d, conv2d, conv3d layers.");
int dim = ctx->Attrs().Get<int>("dim"); int dim = ctx->Attrs().Get<int>("dim");
int power_iters = ctx->Attrs().Get<int>("power_iters"); int power_iters = ctx->Attrs().Get<int>("power_iters");
PADDLE_ENFORCE(dim >= 0 && dim < weight_dimsize - 1, PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1");
"Attr(dim) should be larger equal 0 and less then the"
"size of dims of Input(Weights) - 1,");
PADDLE_ENFORCE(power_iters >= 0, PADDLE_ENFORCE(power_iters >= 0,
"Attr(power_iters) should be larger equal then 0"); "Attr(power_iters) should be larger equal then 0");
int h = dim_weight[dim];
int w = 1;
for (int i = 0; i < rank_weight; i++) {
if (i != dim) {
w *= dim_weight[i];
}
}
auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V");
PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]");
PADDLE_ENFORCE_EQ(
dim_v[0], w,
"Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]");
ctx->SetOutputDim("Out", dim_weight); ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out"); ctx->ShareLoD("Weight", /*->*/ "Out");
} }
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>; ...@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>; using Array2 = Eigen::DSizes<int64_t, 2>;
using IndexPair = Eigen::IndexPair<int>; using IndexPair = Eigen::IndexPair<int>;
static inline void CalcMatrixShape(const Tensor& weight, const int dim, int* h, template <typename DeviceContext, typename T>
int* w) { static inline void TransCompute(const int rank, const Tensor& in, Tensor* out,
auto weight_dims = weight.dims(); const std::vector<int>& perm,
*h = 1; const DeviceContext& dev_ctx) {
*w = 1; if (rank <= 1 || rank > 5) {
for (int i = 0; i < weight_dims.size(); i++) { PADDLE_THROW("Invalid weight rank.");
if (i <= dim) {
*h *= weight_dims[i];
} else {
*w *= weight_dims[i];
} }
switch (rank) {
case 2:
math::Transpose<DeviceContext, T, 2> trans2;
trans2(dev_ctx, in, out, perm);
break;
case 3:
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, in, out, perm);
break;
case 4:
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, in, out, perm);
break;
case 5:
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, in, out, perm);
break;
default:
break;
} }
} }
...@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T> ...@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T>
class SpectralNormKernel : public framework::OpKernel<T> { class SpectralNormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto weight = ctx.Input<Tensor>("Weight"); auto weight = ctx.Input<Tensor>("Weight");
auto u = ctx.Input<Tensor>("U"); auto u = ctx.Input<Tensor>("U");
auto v = ctx.Input<Tensor>("V"); auto v = ctx.Input<Tensor>("V");
...@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> { ...@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> {
int power_iters = ctx.Attr<int>("power_iters"); int power_iters = ctx.Attr<int>("power_iters");
float eps = ctx.Attr<float>("eps"); float eps = ctx.Attr<float>("eps");
const int h = u->dims()[0];
const int w = v->dims()[0];
Tensor weight_mat; Tensor weight_mat;
int h, w; auto dims = weight->dims();
CalcMatrixShape(*weight, dim, &h, &w); const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
}
weight_mat = weight_mat.Resize({h, w}); weight_mat = weight_mat.Resize({h, w});
Tensor sigma; Tensor sigma;
...@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> { ...@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> {
CalcMatrixSigmaAndNormWeight<DeviceContext, T>( CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
&sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
power_iters, eps, ctx); power_iters, eps, ctx);
TensorCopySync(weight_mat.Resize(out->dims()), ctx.GetPlace(), out);
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
out->mutable_data<T>(dims, ctx.GetPlace());
TransCompute<DeviceContext, T>(
rank, weight_mat.Resize(framework::make_ddim(real_dims)), out, perm,
dev_ctx);
} else {
TensorCopySync(weight_mat.Resize(dims), ctx.GetPlace(), out);
}
} }
}; };
...@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> { ...@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto weight = ctx.Input<Tensor>("Weight"); auto weight = ctx.Input<Tensor>("Weight");
auto u = ctx.Input<Tensor>("U"); auto u = ctx.Input<Tensor>("U");
...@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> { ...@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
int power_iters = ctx.Attr<int>("power_iters"); int power_iters = ctx.Attr<int>("power_iters");
float eps = ctx.Attr<float>("eps"); float eps = ctx.Attr<float>("eps");
const int h = u->dims()[0];
const int w = v->dims()[0];
Tensor weight_mat, out_grad_mat; Tensor weight_mat, out_grad_mat;
int h, w; auto dims = weight->dims();
CalcMatrixShape(*weight, dim, &h, &w); const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
ctx.GetPlace());
out_grad_mat.mutable_data<T>(framework::make_ddim(real_dims),
ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
TransCompute<DeviceContext, T>(rank, *out_grad, &out_grad_mat, perm,
dev_ctx);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat); TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
}
weight_mat = weight_mat.Resize({h, w}); weight_mat = weight_mat.Resize({h, w});
out_grad_mat = out_grad_mat.Resize({h, w}); out_grad_mat = out_grad_mat.Resize({h, w});
...@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> { ...@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv,
T(0)); T(0));
Tensor weight_grad_mat, ones; Tensor weight_grad_mat;
weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace()); weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
ones.mutable_data<T>({h, w}, ctx.GetPlace());
auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat); auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat); auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat); auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
auto sigma_t = EigenTensor<T, 2>::From(sigma); auto sigma_t = EigenTensor<T, 2>::From(sigma);
auto uv_t = EigenTensor<T, 2>::From(uv); auto uv_t = EigenTensor<T, 2>::From(uv);
auto ones_t = EigenTensor<T, 2>::From(ones).setConstant((T)1);
weight_mat_t.device(place) = weight_mat_t.device(place) =
weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w)); weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
weight_grad_mat_t.device(place) = weight_grad_mat_t.device(place) =
out_grad_mat_t * (ones_t - uv_t * weight_mat_t) / sigma_t; out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
TensorCopySync(weight_grad_mat.Resize(weight_grad->dims()), ctx.GetPlace(), sigma_t;
weight_grad);
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
weight_grad->mutable_data<T>(dims, ctx.GetPlace());
TransCompute<DeviceContext, T>(
rank, weight_grad_mat.Resize(framework::make_ddim(real_dims)),
weight_grad, perm, dev_ctx);
} else {
TensorCopySync(weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
}
} }
}; };
......
...@@ -94,6 +94,7 @@ __all__ = [ ...@@ -94,6 +94,7 @@ __all__ = [
'multiplex', 'multiplex',
'layer_norm', 'layer_norm',
'group_norm', 'group_norm',
'spectral_norm',
'softmax_with_cross_entropy', 'softmax_with_cross_entropy',
'smooth_l1', 'smooth_l1',
'one_hot', 'one_hot',
...@@ -3347,6 +3348,80 @@ def group_norm(input, ...@@ -3347,6 +3348,80 @@ def group_norm(input,
return helper.append_activation(group_norm_out) return helper.append_activation(group_norm_out)
@templatedoc()
def spectral_norm(weight,
dim=0,
power_iters=1,
eps=1e-12,
u_attr=None,
v_attr=None,
name=None):
"""
**Spectral Normalization Layer**
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(${weight_type}): ${weight_comment}
dim(${dim_type}): ${dim_comment}
eps(${eps_type}): ${eps_comment}
u_attr(ParamAttr|None): The parameter attribute for vector u in
spectral calculatings, set None to use default attribute, which
generates random values in normal distribution N(0, 1). Default: None.
v_attr(ParamAttr|None): The parameter attribute for vector v in
spectral calculatings, set None to use default attribute, which
generates random values in normal distribution N(0, 1). Default: None.
name (str): The name of this layer. It is optional.
Returns:
Variable: A tensor variable of weight after spetral normalization.
Examples:
>>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.spectral_norm(weight=data, dim=1, power_iters=2)
"""
helper = LayerHelper('spectral_norm', **locals())
dtype = helper.input_dtype()
# create intput and parameters
inputs = {'Weight': weight}
input_shape = input.shape
if data_layout != 'NCHW':
raise ValueError("unsupported data layout:" + data_layout)
param_shape = [input_shape[1]]
if param_attr:
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
inputs['Scale'] = scale
if bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
inputs['Bias'] = bias
# create output
mean_out = helper.create_variable(dtype=dtype, stop_gradient=True)
variance_out = helper.create_variable(dtype=dtype, stop_gradient=True)
group_norm_out = helper.create_variable(dtype=dtype)
helper.append_op(
type="group_norm",
inputs=inputs,
outputs={
"Y": group_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon,
"groups": groups})
return helper.append_activation(group_norm_out)
def conv2d_transpose(input, def conv2d_transpose(input,
num_filters, num_filters,
output_size=None, output_size=None,
......
...@@ -22,13 +22,17 @@ from paddle.fluid import core ...@@ -22,13 +22,17 @@ from paddle.fluid import core
def spectral_norm(weight, u, v, dim, power_iters, eps): def spectral_norm(weight, u, v, dim, power_iters, eps):
h = w = 1 shape = weight.shape
for i, d in enumerate(weight.shape): weight_mat = weight.copy()
if i <= dim: h = shape[dim]
h *= d w = np.prod(shape) // h
if dim != 0:
perm = [dim] + [d for d in range(len(shape)) if d != dim]
weight_mat = weight_mat.transpose(perm)
real_shape = weight_mat.shape
else: else:
w *= d real_shape = shape
weight_mat = weight.reshape((h, w)) weight_mat = weight_mat.reshape((h, w))
u = u.reshape((h, 1)) u = u.reshape((h, 1))
v = v.reshape((w, 1)) v = v.reshape((w, 1))
...@@ -41,7 +45,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps): ...@@ -41,7 +45,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps):
u = u / (u_norm + eps) u = u / (u_norm + eps)
sigma = (u * np.matmul(weight_mat, v)).sum() sigma = (u * np.matmul(weight_mat, v)).sum()
return (weight_mat / sigma).reshape(weight.shape) return weight / sigma
class TestSpectralNormOpNoGrad(OpTest): class TestSpectralNormOpNoGrad(OpTest):
...@@ -83,8 +87,8 @@ class TestSpectralNormOpNoGrad(OpTest): ...@@ -83,8 +87,8 @@ class TestSpectralNormOpNoGrad(OpTest):
class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad): class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad):
def initTestCase(self): def initTestCase(self):
self.weight_shape = (2, 3, 3, 3) self.weight_shape = (2, 3, 3, 3)
self.u_shape = (6, ) self.u_shape = (3, )
self.v_shape = (9, ) self.v_shape = (18, )
self.dim = 1 self.dim = 1
self.power_iters = 10 self.power_iters = 10
self.eps = 1e-12 self.eps = 1e-12
...@@ -110,8 +114,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad): ...@@ -110,8 +114,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
class TestSpectralNormOp2(TestSpectralNormOp): class TestSpectralNormOp2(TestSpectralNormOp):
def initTestCase(self): def initTestCase(self):
self.weight_shape = (2, 3, 3, 3) self.weight_shape = (2, 3, 3, 3)
self.u_shape = (6, ) self.u_shape = (3, )
self.v_shape = (9, ) self.v_shape = (18, )
self.dim = 1 self.dim = 1
self.power_iters = 0 self.power_iters = 0
self.eps = 1e-12 self.eps = 1e-12
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册