Unverified commit 1a30fe54, authored by umiswing and committed by GitHub

[Sparse] Add Spconv2d static mode support. (#54371)

Parent 4ebb4764
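Before this commit, the sparse 2D convolution functional API was usable only in dynamic (eager) mode; the static-graph path raised "Only support dynamic_mode now.". For orientation, here is a minimal dynamic-mode sketch of the API whose static counterpart this patch enables. The data deliberately mirrors the 2D unit test added further down; the helper values (all-ones weight, bias of 1) are illustrative choices, not taken from the patch itself.

import paddle
from paddle import sparse

# A 1x3x4x1 NHWC sparse input with four non-zero entries; indices hold the
# [N, H, W] coordinates, values carry one channel per entry.
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1.0], [2.0], [3.0], [4.0]]
sp_x = sparse.sparse_coo_tensor(indices, values, [1, 3, 4, 1])

weight = paddle.ones([3, 3, 1, 1], dtype='float32')  # HWCM kernel layout
bias = paddle.ones([1], dtype='float32')

out = sparse.nn.functional.conv2d(
    sp_x, weight, bias, stride=1, padding=0, dilation=1, groups=1,
    data_format="NHWC",
)
print(out.to_dense().numpy())  # should print [[[[5.], [11.]]]] for this data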
@@ -194,21 +194,38 @@ def _conv2d(
subm,
key if key is not None else "",
)
x = reshape(x, [n, h, w, -1])
weight = paddle.reshape(
    weight, [h_filter, w_filter, c_filter, m_filter]
)
else:
inputs = {'x': x, 'kernel': weight}
attrs = {
'paddings': padding,
'dilations': dilation,
'strides': stride,
'groups': groups,
'subm': subm,
'key': key,
}
op_type = 'sparse_conv3d'
helper = LayerHelper(op_type, **locals())
rulebook = helper.create_variable_for_type_inference(
dtype='int32', stop_gradient=True
)
n_out = pre_bias.shape[0]
h_out = pre_bias.shape[2]
w_out = pre_bias.shape[3]
channels_out = pre_bias.shape[4]
pre_bias = reshape(pre_bias, [n_out, h_out, w_out, channels_out])
if bias is not None:
return add(pre_bias, bias)
else:
return pre_bias
counter = helper.create_variable_for_type_inference(
dtype='int32', stop_gradient=True
)
pre_bias = helper.create_sparse_variable_for_type_inference(x.dtype)
outputs = {"out": pre_bias, "rulebook": rulebook, "counter": counter}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs
)
n_out = pre_bias.shape[0]
h_out = pre_bias.shape[2]
w_out = pre_bias.shape[3]
channels_out = pre_bias.shape[4]
pre_bias = reshape(pre_bias, [n_out, h_out, w_out, channels_out])
if bias is not None:
return add(pre_bias, bias)
else:
raise ValueError("Only support dynamic_mode now.")
return pre_bias
def conv3d(
......
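With the new else branch above, building a static graph no longer raises "Only support dynamic_mode now."; instead the call goes through LayerHelper and appends a sparse_conv3d op (the 2D case reuses the 3D kernel, moving between the 4-D and 5-D layouts with the sparse reshape, which is why the same commit also gives reshape a static branch). The following sketch is not part of the patch; it only shows that graph construction now succeeds and which op type gets appended. The exact set of auxiliary op types in the block is an assumption and may differ between Paddle versions.

import paddle
from paddle import sparse

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    indices = paddle.static.data(name='indices', shape=[3, 4], dtype='int32')
    values = paddle.static.data(name='values', shape=[4, 1], dtype='float32')
    sp_x = sparse.sparse_coo_tensor(indices, values, [1, 3, 4, 1])
    weight = paddle.static.data(
        name='weight', shape=[3, 3, 1, 1], dtype='float32'
    )
    # Graph construction appends a 'sparse_conv3d' op instead of raising.
    out = sparse.nn.functional.conv2d(sp_x, weight, data_format="NHWC")

print([op.type for op in main.global_block().ops])
paddle.disable_static()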
@@ -734,7 +734,6 @@ def expm1(x, name=None):
return _C_ops.sparse_expm1(x)
@dygraph_only
def reshape(x, shape, name=None):
"""
Changes the shape of ``x`` without changing its value, requiring x to be a SparseCooTensor or SparseCsrTensor.
@@ -788,7 +787,38 @@ def reshape(x, shape, name=None):
# the shape of sp_out is [1, 2, 2, 3, 3]
"""
return _C_ops.sparse_reshape(x, shape)
if in_dynamic_mode():
return _C_ops.sparse_reshape(x, shape)
else:
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'bool',
'uint16',
],
'reshape',
)
check_type(shape, 'shape', (list, tuple), 'reshape')
inputs = {"x": x}
attrs = {"shape": shape}
helper = LayerHelper('sparse_reshape')
out = helper.create_sparse_variable_for_type_inference(x.dtype)
helper.append_op(
type='sparse_reshape',
inputs=inputs,
outputs={'out': out},
attrs=attrs,
)
return out
def isnan(x, name=None):
......
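The removal of @dygraph_only together with the new LayerHelper branch means paddle.sparse.reshape can now take part in static graph construction as well. A minimal, hedged sketch of that (not taken from the patch; it only builds the program and inspects the appended ops, the target shape is an illustrative choice):

import paddle
from paddle import sparse

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    indices = paddle.static.data(name='indices', shape=[3, 4], dtype='int32')
    values = paddle.static.data(name='values', shape=[4, 1], dtype='float32')
    sp_x = sparse.sparse_coo_tensor(indices, values, [1, 3, 4, 1])
    # Previously rejected under @dygraph_only; now appends a 'sparse_reshape'
    # op to the program (the same lift _conv2d uses to reach the 3D kernel).
    sp_y = sparse.reshape(sp_x, [1, 1, 3, 4, 1])

print([op.type for op in main.global_block().ops])
paddle.disable_static()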
@@ -312,63 +312,133 @@ class TestSparseConv(unittest.TestCase):
class TestStatic(unittest.TestCase):
def test(self):
paddle.enable_static()
indices = paddle.static.data(
name='indices', shape=[4, 4], dtype='int32'
)
values = paddle.static.data(
name='values', shape=[4, 1], dtype='float32'
)
dense_shape = [1, 1, 3, 4, 1]
sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape)
main = paddle.static.Program()
with paddle.static.program_guard(main):
indices = paddle.static.data(
name='indices', shape=[4, 4], dtype='int32'
)
values = paddle.static.data(
name='values', shape=[4, 1], dtype='float32'
)
dense_shape = [1, 1, 3, 4, 1]
sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape)
weight_shape = [1, 3, 3, 1, 1]
weight = paddle.static.data(
name='weight', shape=weight_shape, dtype='float32'
)
bias_shape = [1]
bias = paddle.static.data(
name='bias', shape=bias_shape, dtype='float32'
)
out = sparse.nn.functional.conv3d(
sp_x,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NDHWC",
)
sp_out = sparse.nn.functional.relu(out)
out_indices = sp_out.indices()
out_values = sp_out.values()
out = sp_out.to_dense()
exe = paddle.static.Executor()
indices_data = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values_data = [[1.0], [2.0], [3.0], [4.0]]
weight_data = np.array(
[[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]]
).astype('float32')
weight_data = weight_data.reshape(weight_shape)
bias_data = np.array([1]).astype('float32')
fetch = exe.run(
feed={
'indices': indices_data,
'values': values_data,
'weight': weight_data,
'bias': bias_data,
},
fetch_list=[out, out_indices, out_values],
return_numpy=True,
)
correct_out = np.array([[[[[5.0], [11.0]]]]]).astype('float64')
correct_out_values = [[5.0], [11.0]]
assert np.array_equal(correct_out, fetch[0])
assert np.array_equal(correct_out_values, fetch[2])
assert out_indices.dtype == paddle.int32
weight_shape = [1, 3, 3, 1, 1]
weight = paddle.static.data(
name='weight', shape=weight_shape, dtype='float32'
)
bias_shape = [1]
bias = paddle.static.data(
name='bias', shape=bias_shape, dtype='float32'
)
out = sparse.nn.functional.conv3d(
sp_x,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NDHWC",
)
sp_out = sparse.nn.functional.relu(out)
out_indices = sp_out.indices()
out_values = sp_out.values()
out = sp_out.to_dense()
exe = paddle.static.Executor()
indices_data = [
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 2],
[1, 3, 2, 3],
]
values_data = [[1.0], [2.0], [3.0], [4.0]]
weight_data = np.array(
[[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]]
).astype('float32')
weight_data = weight_data.reshape(weight_shape)
bias_data = np.array([1]).astype('float32')
fetch = exe.run(
feed={
'indices': indices_data,
'values': values_data,
'weight': weight_data,
'bias': bias_data,
},
fetch_list=[out, out_indices, out_values],
return_numpy=True,
)
correct_out = np.array([[[[[5.0], [11.0]]]]]).astype('float64')
correct_out_values = [[5.0], [11.0]]
np.testing.assert_array_equal(correct_out, fetch[0])
np.testing.assert_array_equal(correct_out_values, fetch[2])
self.assertTrue(out_indices.dtype == paddle.int32)
paddle.disable_static()
def test2D(self):
paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
indices = paddle.static.data(
name='indices', shape=[3, 4], dtype='int32'
)
values = paddle.static.data(
name='values', shape=[4, 1], dtype='float32'
)
dense_shape = [1, 3, 4, 1]
sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape)
weight_shape = [3, 3, 1, 1]
weight = paddle.static.data(
name='weight', shape=weight_shape, dtype='float32'
)
bias_shape = [1]
bias = paddle.static.data(
name='bias', shape=bias_shape, dtype='float32'
)
out = sparse.nn.functional.conv2d(
sp_x,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NHWC",
)
sp_out = sparse.nn.functional.relu(out)
out_indices = sp_out.indices()
out_values = sp_out.values()
out = sp_out.to_dense()
exe = paddle.static.Executor()
indices_data = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values_data = [[1.0], [2.0], [3.0], [4.0]]
weight_data = np.array(
[[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]]
).astype('float32')
weight_data = weight_data.reshape(weight_shape)
bias_data = np.array([1]).astype('float32')
fetch = exe.run(
feed={
'indices': indices_data,
'values': values_data,
'weight': weight_data,
'bias': bias_data,
},
fetch_list=[out, out_indices, out_values],
return_numpy=True,
)
correct_out = np.array([[[[5.0], [11.0]]]]).astype('float64')
correct_out_values = [[5.0], [11.0]]
np.testing.assert_array_equal(correct_out, fetch[0])
np.testing.assert_array_equal(correct_out_values, fetch[2])
self.assertTrue(out_indices.dtype == paddle.int32)
paddle.disable_static()
......
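The expected results in both tests follow directly from the data: with the all-ones 3x3 (or 1x3x3) kernel, no padding, and a bias of 1, the first output position covers the non-zeros 1 and 3 (sum 4, plus bias = 5.0) and the second covers 1, 2, 3 and 4 (sum 10, plus bias = 11.0). A small NumPy cross-check of that arithmetic (illustrative, not part of the patch):

import numpy as np

# Densify the input used in test2D: a 3x4 single-channel grid with non-zeros
# at (h, w) = (0, 1), (0, 3), (1, 2), (2, 3).
dense = np.zeros((3, 4), dtype=np.float32)
for (h, w), v in zip([(0, 1), (0, 3), (1, 2), (2, 3)], [1.0, 2.0, 3.0, 4.0]):
    dense[h, w] = v

kernel = np.ones((3, 3), dtype=np.float32)  # the all-ones 3x3 test kernel
bias = 1.0

# Valid (no-padding) 2D convolution over the 3x4 grid gives a 1x2 output.
out = np.array(
    [[(dense[0:3, j:j + 3] * kernel).sum() + bias for j in range(2)]],
    dtype=np.float32,
)
print(out)  # [[ 5. 11.]] -- the values 5.0 and 11.0 match correct_out_values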