Unverified commit 2bc5121d authored by lilong12, committed by GitHub

add the paddle.distributed.split api (#29970)

* add distributed.split, test=develop
Parent 13aef970
......@@ -21,6 +21,7 @@ from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
from ..fluid.dygraph.parallel import prepare_context
import paddle
from .fleet import fleet
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -31,6 +32,7 @@ __all__ = [
'all_gather',
'scatter',
'barrier',
'split',
'ReduceOp',
]
......@@ -485,3 +487,227 @@ def barrier(group=0):
inputs={'X': [temp]},
outputs={'Out': [temp]},
attrs={'ring_id': group})
def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr,
gather_out, inner_rank, name):
"""
Parallel Linear
"""
if not name:
name = "fc_by_row_rank_%d" % inner_rank if axis == 0 else "fc_by_col_rank_%d" % inner_rank
else:
name = name + "_by_row_rank_%d" % inner_rank if axis == 0 else name + "_by_col_rank_%d" % inner_rank
linear = paddle.nn.Linear(
num_rows,
num_cols,
weight_attr=param_attr,
bias_attr=bias_attr,
name=name)
weight = linear.weight
weight.is_distributed = True
linear_out = linear(x)
startup_block = paddle.static.default_startup_program().global_block()
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[weight.name].is_distributed = True
main_block.vars[weight.name].is_distributed = True
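# With row-parallel weights (axis == 0) each rank produces a partial sum of the full
# output, so the partial results are combined with all_reduce. With column-parallel
# weights (axis == 1) each rank produces a slice of the output columns, so the slices
# are collected with all_gather and concatenated along the last axis.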
if gather_out:
if axis == 0:
paddle.distributed.all_reduce(linear_out, group=0)
else:
output = []
paddle.distributed.all_gather(output, linear_out, group=0)
linear_out = paddle.concat(output, axis=len(linear_out.shape) - 1)
return linear_out
def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr,
inner_rank, num_partitions, name):
"""
Parallel Embedding
"""
if not name:
name = "emb_rank_%d" % inner_rank
else:
name = name + "_rank_%d" % inner_rank
origin_num_embeddings = origin_size[0]
embedding = paddle.nn.Embedding(
per_part_embeddings,
origin_size[1],
padding_idx=per_part_embeddings - 1,
sparse=False,
weight_attr=param_attr,
name=name)
origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
else:
assert origin_input_shape[-1] == 1, (
"The last dimension size of x must be 1.")
x_shard = paddle.shard_index(x, origin_num_embeddings, num_partitions,
inner_rank, per_part_embeddings - 1)
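# shard_index maps the ids owned by this partition to local row indices and maps all
# other ids to the padding index (per_part_embeddings - 1), whose lookup result is all
# zeros, so the all_reduce below reconstructs the full embedding output.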
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)
embedding.weight.is_distributed = True
emb_out = embedding(x_shard)
startup_block = paddle.static.default_startup_program().global_block()
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[embedding.weight.name].is_distributed = True
main_block.vars[embedding.weight.name].is_distributed = True
paddle.distributed.all_reduce(emb_out, group=0)
return emb_out
def split(x,
size,
operation,
axis=0,
num_partitions=1,
gather_out=True,
weight_attr=None,
bias_attr=None,
name=None):
"""
Split the weight of the specified operation across multiple devices
and do the computation in parallel.
Currently, the following three cases are supported.
Case 1: Parallel Embedding
The weight of the embedding operation is an NxM matrix with N rows and M columns.
With parallel embedding, the weight is split into num_partitions partitions, each
of which is a matrix with (N/num_partitions + 1) rows and M columns, where the last
row is used as the padding index.
Suppose we split the NxM weight into two partitions on device_0 and device_1
respectively. Then, on each device, the final weight has (N/2 + 1) rows indexed
from 0 to N/2. On device_0, all input values within [0, N/2 - 1] remain unchanged
and all other values are changed to N/2, which is the padding index and is mapped
to all zeros after the embedding lookup. In the same way, on device_1, an input
value V within [N/2, N-1] is changed to (V - N/2), and all other values are changed
to N/2 and are mapped to all zeros after the embedding lookup. Finally, the results
on the two devices are sum-reduced.
Case 2: Row Parallel Linear
The weight of the linear operation is an NxM matrix with N rows and M columns.
With row parallel linear, the weight is split into num_partitions partitions, each
of which is a matrix with N/num_partitions rows and M columns.
Case 3: Column Parallel Linear
The weight of the linear operation is an NxM matrix with N rows and M columns.
With column parallel linear, the weight is split into num_partitions partitions, each
of which is a matrix with N rows and M/num_partitions columns.
Args:
x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
axis (int, Optional): Indicates the axis along which to split the weight. Default: 0.
num_partitions (int, Optional): The number of partitions into which the weight is split. Default: 1.
gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
on each partition is gathered after computation. Default: True.
weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
weights(Parameter) of the specified operation. Default: None.
bias_attr (ParamAttr, Optional): The parameter attribute for the bias
of the specified operation. Default: None.
name (str, Optional): The default value is None. Normally there is no need for users to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor.
Examples:
.. code-block:: python
import paddle
from paddle.distributed import init_parallel_env
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
data = paddle.randint(0, 8, shape=[10,4])
emb_out = paddle.distributed.split(
data,
(8, 8),
operation="embedding",
num_partitions=2)
"""
assert isinstance(size, (list, tuple)), (
"The type of size for "
"paddle.distributed.split must be list or tuple.")
assert len(size) == 2, ("Number of elements in size of "
"paddle.distributed.split must be two.")
assert isinstance(operation, str), ("The type of operation for "
"paddle.distributed.split must be str.")
supported_operations = [
'linear',
'embedding',
]
assert operation in supported_operations, (
"The operation for "
"paddle.distributed.split must be one of {}.".format(
supported_operations))
if in_dygraph_mode():
rank = paddle.distributed.get_rank()
nranks = paddle.distributed.get_world_size()
else:
assert fleet._role_maker, ("To use paddle.distributed.split, "
"you must call fleet.init() first.")
rank = fleet.worker_index()
nranks = fleet.worker_num()
# rank within a model parallel group
inner_rank = rank % num_partitions
if operation == "embedding":
assert axis == 0, ("Only splitting the embedding weight "
"along the first axis is supported for now.")
per_part_size = (size[0] + num_partitions - 1) // num_partitions
last_part_size = size[0] - per_part_size * (num_partitions - 1)
if inner_rank == num_partitions - 1: per_part_size = last_part_size
per_part_size += 1 # make the last row as the padding index
emb_out = _parallel_embedding(x, per_part_size, size, weight_attr,
inner_rank, num_partitions, name)
return emb_out
else:
if axis == 0:
assert size[0] % num_partitions == 0, (
"Number of rows of the weight for linear ({}) must be"
" divisible by num_partitions ({})".format(size[0],
num_partitions))
per_part_size = size[0] // num_partitions
linear_size = (per_part_size, size[1])
assert x.shape[-1] == per_part_size, (
"The width ({}) of the input "
"x must be equal to the height ({}) of the weight. Maybe you "
"should split the input x using paddle.split.".format(
x.shape[-1], per_part_size))
elif axis == 1:
assert size[1] % num_partitions == 0, (
"Number of columns of the weight for linear ({}) must be"
" divisible by num_partitions ({})".format(size[1],
num_partitions))
per_part_size = size[1] // num_partitions
linear_size = (size[0], per_part_size)
else:
raise ValueError("The value of axis must be 0 or 1, but the value "
"given is {}.".format(axis))
linear_out = _parallel_linear(
x,
linear_size[0],
linear_size[1],
axis,
weight_attr,
bias_attr,
gather_out,
inner_rank,
name=name)
return linear_out
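For reference, a minimal static-graph sketch of the column-parallel linear case (Case 3), assuming a two-GPU job launched in fleet collective mode; the tensor name and shapes below are illustrative only:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)  # required before paddle.distributed.split in static mode

# The full (1000, 16) weight is split by columns into two (1000, 8) partitions,
# one per device; with gather_out=True the column slices are gathered back.
data = paddle.static.data(name='tindata', shape=[10, 1000], dtype='float32')
linear_out = paddle.distributed.split(
    data,
    size=(1000, 16),
    operation='linear',
    axis=1,
    num_partitions=2)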
......@@ -44,5 +44,9 @@ def _init_parallel_ctx():
def _broadcast_parameters(parameters):
for param in parameters:
# In model parallel mode, some parameters are split across multiple devices,
# so they should not be broadcast.
if param.is_distributed: continue
if isinstance(param, Parameter) and param.trainable:
collective._broadcast(param, 0, sync_mode=True)
......@@ -73,6 +73,10 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_split_embedding)
LIST(REMOVE_ITEM TEST_OPS test_collective_split_embedding_none_divisible)
LIST(REMOVE_ITEM TEST_OPS test_collective_split_row_linear)
LIST(REMOVE_ITEM TEST_OPS test_collective_split_col_linear)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
......@@ -816,6 +820,17 @@ if(WITH_GPU AND NOT WIN32)
set_tests_properties(test_collective_barrier_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_scatter PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_sendrecv PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_split_embedding
test_collective_split_embedding_none_divisible
test_collective_split_row_linear
test_collective_split_col_linear
test_collective_scatter_api
test_collective_barrier_api
test_collective_reduce_api
test_collective_allreduce_api
test_collective_broadcast_api
test_collective_allgather_api
PROPERTIES LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU)
set_tests_properties(test_imperative_auto_mixed_precision PROPERTIES TIMEOUT 120)
......
......@@ -47,10 +47,10 @@ class TestCollectiveScatterAPI(TestCollectiveAPIRunnerBase):
tindata = layers.data(
name="tindata",
shape=[10, 1000],
dtype='float64',
dtype='float32',
append_batch_size=False)
toutdata = layers.fill_constant(
shape=[5, 1000], dtype='float64', value=1.0)
shape=[5, 1000], dtype='float32', value=1.0)
tensor_list = None
if rank == 1:
tensor_list = paddle.split(tindata, 2, axis=0)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.base import role_maker
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestColumnParallelLinearAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
fleet.init(is_collective=True)
np.random.seed(2020)
np_array = np.random.rand(1000, 16)
data = paddle.static.data(
name='tindata', shape=[10, 1000], dtype="float32")
paddle.distributed.broadcast(data, src=0)
if rank == 0:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[:, 0:8]), )
else:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[:, 8:16]), )
linear_out = paddle.distributed.split(
data,
size=(1000, 16),
operation='linear',
axis=1,
num_partitions=2,
weight_attr=param_attr,
bias_attr=False, )
return [linear_out]
if __name__ == "__main__":
runtime_main(TestColumnParallelLinearAPI, "column_parallel_linear")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.base import role_maker
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestParallelEmbeddingAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
fleet.init(is_collective=True)
np.random.seed(2020)
np_array = np.random.rand(10, 8)
paddle.seed(2020)
data_in = paddle.randint(0, 8, shape=(10, 4))
data = paddle.static.data(
name='tindata', shape=[10, 1000], dtype="float32")
if rank == 0:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[0:5, :]), )
else:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[5:10, :]), )
emb_out = paddle.distributed.split(
data_in, (8, 8),
operation="embedding",
num_partitions=2,
weight_attr=param_attr)
return [data_in, emb_out]
if __name__ == "__main__":
runtime_main(TestParallelEmbeddingAPI, "parallel_embedding")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.base import role_maker
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestParallelEmbeddingAPINoneDivisible(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
fleet.init(is_collective=True)
np.random.seed(2020)
np_array = np.random.rand(9, 8)
paddle.seed(2020)
data_in = paddle.randint(0, 7, shape=(10, 4))
data = paddle.static.data(
name='tindata', shape=[10, 1000], dtype="float32")
if rank == 0:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[0:5, :]), )
else:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[5:9, :]), )
emb_out = paddle.distributed.split(
data_in, (7, 8),
operation="embedding",
num_partitions=2,
weight_attr=param_attr)
return [data_in, emb_out]
if __name__ == "__main__":
runtime_main(TestParallelEmbeddingAPINoneDivisible, "parallel_embedding")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.base import role_maker
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestRowParallelLinearAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
fleet.init(is_collective=True)
np.random.seed(2020)
np_array = np.random.rand(1000, 16)
data = paddle.static.data(
name='tindata', shape=[10, 1000], dtype="float32")
paddle.distributed.broadcast(data, src=0)
data = paddle.split(data, 2, axis=1)[rank]
if rank == 0:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[0:500, :]), )
else:
param_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.NumpyArrayInitializer(
np_array[500:1000, :]), )
linear_out = paddle.distributed.split(
data,
size=(1000, 8),
operation='linear',
axis=0,
num_partitions=2,
weight_attr=param_attr,
bias_attr=False, )
return [linear_out]
if __name__ == "__main__":
runtime_main(TestRowParallelLinearAPI, "row_parallel_linear")
......@@ -55,7 +55,7 @@ class TestCollectiveAPIRunnerBase(object):
exe = fluid.Executor(place)
exe.run(startup_prog)
np.random.seed(os.getpid())
indata = np.random.random((10, 1000))
indata = np.random.random((10, 1000)).astype("float32")
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
......@@ -219,5 +219,31 @@ class TestDistBase(unittest.TestCase):
self.assertTrue(
np.allclose(
tr1_out, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "parallel_embedding":
result_data = tr0_out[0]
np.random.seed(2020)
need_result = np.random.rand(10, 8)
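# Row 4 of the 10x8 reference weight is the padding row of the first partition, so
# ids >= 4 are shifted by one to index the matching non-padding row of the weight.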
for i in range(result_data.shape[0]):
for j in range(result_data.shape[1]):
data = result_data[i][j]
if data >= 4: data += 1
assert np.allclose(
tr0_out[1][i][j], need_result[data], atol=1e-08)
elif col_type == "row_parallel_linear":
result_data = tr0_out[0]
np.random.seed(2020)
weight = np.random.rand(1000, 16)
need_result = np.matmul(input1, weight)
self.assertTrue(
np.allclose(
result_data, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "column_parallel_linear":
result_data = tr0_out[0]
np.random.seed(2020)
weight = np.random.rand(1000, 16)
need_result = np.matmul(input1, weight)
self.assertTrue(
np.allclose(
result_data, need_result, rtol=1e-05, atol=1e-05))
else:
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestColParallelLinearAPI(TestDistBase):
def _setup_config(self):
pass
def test_col_parallel_linear(self):
self.check_with_place("column_parallel_linear_api.py",
"column_parallel_linear", "nccl")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestParallelEmbeddingAPI(TestDistBase):
def _setup_config(self):
pass
def test_parallel_embedding(self):
self.check_with_place("parallel_embedding_api.py", "parallel_embedding",
"nccl")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestParallelEmbeddingNoneDivisibleAPI(TestDistBase):
def _setup_config(self):
pass
def test_parallel_embedding_none_divisible(self):
self.check_with_place("parallel_embedding_api_none_divisible.py",
"parallel_embedding", "nccl")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestRowParallelLinearAPI(TestDistBase):
def _setup_config(self):
pass
def test_row_parallel_linear(self):
self.check_with_place("row_parallel_linear_api.py",
"row_parallel_linear", "nccl")
if __name__ == '__main__':
unittest.main()