pytorch代码转paddle后batch_size维度对不上
Created by: cuicheng01
pytorch代码:
class SpatialGroupEnhance(nn.Module):
def __init__(self, groups = 64):
super(SpatialGroupEnhance, self).__init__()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.weight = Parameter(torch.zeros(1, groups, 1, 1))
self.bias = Parameter(torch.ones(1, groups, 1, 1))
self.sig = nn.Sigmoid()
def forward(self, x): # (b, c, h, w)
b, c, h, w = x.size()
x = x.view(b * self.groups, -1, h, w)
xn = x * self.avg_pool(x)
xn = xn.sum(dim=1, keepdim=True)
t = xn.view(b * self.groups, -1)
t = t - t.mean(dim=1, keepdim=True)
std = t.std(dim=1, keepdim=True) + 1e-5
t = t / std
t = t.view(b, self.groups, h, w)
t = t * self.weight + self.bias
t = t.view(b * self.groups, 1, h, w)
x = x * self.sig(t)
x = x.view(b, c, h, w)
return x
paddle代码:
def sge_block(self, x, groups=64, name=None):
weight = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32',
default_initializer=fluid.initializer.Constant(value=1.0))
bias = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32',
default_initializer=fluid.initializer.Constant(value=0.0))
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
channels_per_group = num_channels // groups
x = fluid.layers.reshape(x=x, shape=[-1, channels_per_group, height, width])
xn = x * fluid.layers.pool2d(input=x, pool_type='avg', global_pooling=True, use_cudnn=False)
xn = fluid.layers.reduce_sum(input=xn, dim=1, keep_dim=True)
t = fluid.layers.reshape(x=xn, shape=[-1, height * width])
t = fluid.layers.reshape(x=t, shape=[-1, groups, height, width])
t = t*weight + bias
t = fluid.layers.reshape(x=t, shape=[-1, 1, height, width])
x = x * fluid.layers.sigmoid(t)
x = fluid.layers.reshape(x=x, shape=[-1, num_channels, height, width])
return x
单卡batch_size = 32,
报错信息:
Enforce failed. Expected x_dims[i + axis] == y_dims[i], but received x_dims[i + axis]:32 != y_dims[i]:1. Broadcast dimension mismatch. at [/paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:63]