Unverified commit 64e2f10c authored by zhangkaihuo, committed by GitHub

Conv3D support bias (#43458)

Parent 81abaaf5
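In short, this commit makes `paddle.incubate.sparse.nn.Conv3D` create a real bias parameter instead of leaving `self.bias = None`, and adds a test comparing the sparse layer against its dense counterpart. A minimal usage sketch, assuming the `paddle.incubate.sparse` API of this release (on some builds eager mode must be enabled first, as the test does with `_test_eager_guard`):

```python
import paddle

paddle.seed(0)
dense_x = paddle.randn([1, 4, 4, 4, 3])  # NDHWC input
sp_x = dense_x.to_sparse_coo(4)          # sparse COO over the N, D, H, W dims

# After this commit, sp_conv3d.bias is a learnable [out_channels] parameter
# rather than None.
sp_conv3d = paddle.incubate.sparse.nn.Conv3D(in_channels=3,
                                             out_channels=2,
                                             kernel_size=3,
                                             data_format='NDHWC')
sp_out = sp_conv3d(sp_x)
print(sp_conv3d.bias.shape)     # [2]
print(sp_out.to_dense().shape)  # [1, 2, 2, 2, 2]
```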
@@ -117,3 +117,44 @@ class TestSparseConv(unittest.TestCase):
        # Currently, only data_format='NDHWC' is supported
        conv3d = paddle.incubate.sparse.nn.SubmConv3D(
            1, 1, (1, 3, 3), data_format='NCDHW')
    def test_Conv3D_bias(self):
        with _test_eager_guard():
            paddle.seed(0)
            shape = [1, 4, 4, 4, 3]
            x = paddle.randn(shape)
            sp_x = x.to_sparse_coo(4)

            # Build a dense Conv3D and a sparse Conv3D, then copy the dense
            # parameters over (the sparse weight layout is [kD, kH, kW, in, out]).
            conv3d = paddle.nn.Conv3D(3, 2, 3, data_format='NDHWC')
            sp_conv3d = paddle.incubate.sparse.nn.Conv3D(3,
                                                         2,
                                                         3,
                                                         data_format='NDHWC')
            sp_conv3d.weight.set_value(
                paddle.to_tensor(conv3d.weight.numpy().transpose(2, 3, 4, 1, 0)))
            sp_conv3d.bias.set_value(paddle.to_tensor(conv3d.bias.numpy()))

            # Forward and backward through the dense layer.
            x.stop_gradient = False
            out = conv3d(x)
            loss = out.mean()
            loss.backward()

            # Forward and backward through the sparse layer.
            sp_x.stop_gradient = False
            sp_out = sp_conv3d(sp_x)
            dense_out = sp_out.to_dense()
            sp_loss = dense_out.mean()
            sp_loss.backward()

            # Outputs and gradients (weight and bias) should match the dense layer.
            assert np.allclose(out.numpy(),
                               dense_out.numpy(),
                               atol=1e-3,
                               rtol=1e-3)
            assert np.allclose(conv3d.weight.grad.numpy().transpose(2, 3, 4, 1, 0),
                               sp_conv3d.weight.grad.numpy(),
                               atol=1e-3,
                               rtol=1e-3)
            assert np.allclose(conv3d.bias.grad.numpy(),
                               sp_conv3d.bias.grad.numpy(),
                               atol=1e-5,
                               rtol=1e-5)
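The `transpose(2, 3, 4, 1, 0)` calls above bridge the two weight layouts: dense `paddle.nn.Conv3D` stores its kernel as `[out_channels, in_channels, kD, kH, kW]`, while the sparse layer (see `filter_shape` in the hunk below) uses `[kD, kH, kW, in_channels, out_channels]`. A small sketch of the mapping, with shapes taken from the test:

```python
import numpy as np

dense_w = np.empty((2, 3, 3, 3, 3))          # dense layout: [out, in, kD, kH, kW]
sparse_w = dense_w.transpose(2, 3, 4, 1, 0)  # sparse layout: [kD, kH, kW, in, out]
print(sparse_w.shape)                        # (3, 3, 3, 3, 2)
```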
@@ -82,9 +82,9 @@ class _Conv3D(Layer):
            shape=filter_shape,
            attr=self._param_attr,
            default_initializer=_get_default_param_initializer())
-        #self.bias = self.create_parameter(
-        #     attr=self._bias_attr, shape=[self._out_channels], is_bias=True)
-        self.bias = None
+        self.bias = self.create_parameter(attr=self._bias_attr,
+                                          shape=[self._out_channels],
+                                          is_bias=True)
    def forward(self, x):
        out = F.conv._conv3d(x,
...
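With this change, `_Conv3D` registers a learnable per-output-channel bias through `create_parameter(..., is_bias=True)` instead of hard-coding `self.bias = None`. Conceptually the bias adds a per-channel offset to the convolution output; a toy illustration of the assumed semantics (not the actual kernel implementation):

```python
import numpy as np

# Each active output site of the sparse conv gets a per-channel offset.
values = np.random.randn(5, 2)    # 5 non-zero sites, out_channels = 2
bias = np.array([0.1, -0.2])      # shape [out_channels]
values_plus_bias = values + bias  # broadcasts over the channel axis
print(values_plus_bias.shape)     # (5, 2)
```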