未验证 提交 a08a118d 编写于 作者: X xiemoyuan 提交者: GitHub

Support list and tuple for args. (#32344)

* Support list and tuple for parameters of layer_norm, multiprocess_reader, DatasetFolder and ImageFolder.

* add unittest for layer_norm.

* add require gpu for example.
上级 85e697d7
...@@ -1080,10 +1080,12 @@ def split(x, ...@@ -1080,10 +1080,12 @@ def split(x,
import paddle import paddle
from paddle.distributed import init_parallel_env from paddle.distributed import init_parallel_env
# required: gpu
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env() init_parallel_env()
data = paddle.randint(0, 8, shape=[10,4]) data = paddle.randint(0, 8, shape=[10,4])
emb_out = padle.distributed.split( emb_out = paddle.distributed.split(
data, data,
(8, 8), (8, 8),
operation="embedding", operation="embedding",
......
...@@ -82,5 +82,60 @@ class TestDygraphLayerNormv2(unittest.TestCase): ...@@ -82,5 +82,60 @@ class TestDygraphLayerNormv2(unittest.TestCase):
self.assertTrue(np.allclose(y1, y2)) self.assertTrue(np.allclose(y1, y2))
class TestLayerNormFunction(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"):
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
def compute_v0(x):
with fluid.dygraph.guard(p):
ln = fluid.dygraph.LayerNorm(shape[1:])
y = ln(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v1(x):
with fluid.dygraph.guard(p):
x = fluid.dygraph.to_variable(x)
y = paddle.nn.functional.layer_norm(x, shape[1:])
return y.numpy()
def compute_v2(x):
with fluid.dygraph.guard(p):
x = fluid.dygraph.to_variable(x)
y = paddle.nn.functional.layer_norm(x, tuple(shape[1:]))
return y.numpy()
def compute_v3(x):
with fluid.dygraph.guard(p):
ln = fluid.dygraph.LayerNorm(shape[-1])
y = ln(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v4(x):
with fluid.dygraph.guard(p):
x = fluid.dygraph.to_variable(x)
y = paddle.nn.functional.layer_norm(x, shape[-1])
return y.numpy()
x = np.random.randn(*shape).astype("float32")
y0 = compute_v0(x)
y1 = compute_v1(x)
y2 = compute_v2(x)
self.assertTrue(np.allclose(y0, y1))
self.assertTrue(np.allclose(y0, y2))
y3 = compute_v3(x)
y4 = compute_v4(x)
self.assertTrue(np.allclose(y3, y4))
self.assertRaises(
ValueError,
paddle.nn.functional.layer_norm,
x=x,
normalized_shape=1.0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -23,6 +23,8 @@ from ...fluid.initializer import Constant ...@@ -23,6 +23,8 @@ from ...fluid.initializer import Constant
from ...fluid.param_attr import ParamAttr from ...fluid.param_attr import ParamAttr
from ...fluid import core, dygraph_utils from ...fluid import core, dygraph_utils
import numbers
__all__ = [ __all__ = [
'batch_norm', 'batch_norm',
# 'data_norm', # 'data_norm',
...@@ -289,6 +291,14 @@ def layer_norm(x, ...@@ -289,6 +291,14 @@ def layer_norm(x,
""" """
input_shape = list(x.shape) input_shape = list(x.shape)
input_ndim = len(input_shape) input_ndim = len(input_shape)
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = [normalized_shape]
elif isinstance(normalized_shape, tuple):
normalized_shape = list(normalized_shape)
elif not isinstance(normalized_shape, list):
raise ValueError(
"`normalized_shape` should be int, list of ints or tuple of ints.")
normalized_ndim = len(normalized_shape) normalized_ndim = len(normalized_shape)
begin_norm_axis = input_ndim - normalized_ndim begin_norm_axis = input_ndim - normalized_ndim
if input_ndim < normalized_ndim or input_shape[ if input_ndim < normalized_ndim or input_shape[
......
...@@ -588,7 +588,8 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -588,7 +588,8 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
sys.stderr.write("import ujson error: " + str(e) + " use json\n") sys.stderr.write("import ujson error: " + str(e) + " use json\n")
import json import json
assert type(readers) is list and len(readers) > 0 assert isinstance(readers, (list, tuple)) and len(readers) > 0, (
"`readers` must be list or tuple.")
def _read_into_queue(reader, queue): def _read_into_queue(reader, queue):
try: try:
......
...@@ -28,11 +28,14 @@ def has_valid_extension(filename, extensions): ...@@ -28,11 +28,14 @@ def has_valid_extension(filename, extensions):
Args: Args:
filename (str): path to a file filename (str): path to a file
extensions (tuple of str): extensions to consider (lowercase) extensions (list[str]|tuple[str]): extensions to consider
Returns: Returns:
bool: True if the filename ends with one of given extensions bool: True if the filename ends with one of given extensions
""" """
assert isinstance(extensions,
(list, tuple)), ("`extensions` must be list or tuple.")
extensions = tuple([x.lower() for x in extensions])
return filename.lower().endswith(extensions) return filename.lower().endswith(extensions)
...@@ -73,7 +76,7 @@ class DatasetFolder(Dataset): ...@@ -73,7 +76,7 @@ class DatasetFolder(Dataset):
Args: Args:
root (string): Root directory path. root (string): Root directory path.
loader (callable|optional): A function to load a sample given its path. loader (callable|optional): A function to load a sample given its path.
extensions (tuple[str]|optional): A list of allowed extensions. extensions (list[str]|tuple[str]|optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed. both extensions and is_valid_file should not be passed.
transform (callable|optional): A function/transform that takes in transform (callable|optional): A function/transform that takes in
a sample and returns a transformed version. a sample and returns a transformed version.
...@@ -226,7 +229,7 @@ class ImageFolder(Dataset): ...@@ -226,7 +229,7 @@ class ImageFolder(Dataset):
Args: Args:
root (string): Root directory path. root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path. loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions. extensions (list[str]|tuple[str], optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed. both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version. a sample and returns a transformed version.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册