未验证 提交 0bc8bdf7 编写于 作者: L lilong12 提交者: GitHub

set dim[0] to -1 if dim[0] < 0 during compiling for c_allgather op (#21402)

* set dim[0] to -1 if dim[0] < 0 and remove assertion to runtime, test=develop

* modify ENFORCE message, test=develop

* add validation for x.shape[0] > 0, test=develop

* add ut, test=develop
上级 c5f0293c
......@@ -29,6 +29,7 @@ class CAllGatherOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2");
framework::DDim dim = ctx->GetInputDim("X");
dim[0] = dim[0] * nranks;
if (dim[0] < 0) dim[0] = -1;
ctx->SetOutputDim("Out", dim);
}
};
......
......@@ -36,6 +36,11 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
int nranks = comm->nranks();
auto out_dims = in->dims();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0,
platform::errors::InvalidArgument(
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0], nranks));
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
......
......@@ -134,7 +134,7 @@ def _c_reducescatter(x, nranks, ring_id=0, use_calc_stream=False):
if not isinstance(x, Variable):
raise TypeError('x must be a Variable')
if x.shape[0] % nranks != 0:
if x.shape[0] > 0 and x.shape[0] % nranks != 0:
raise ValueError('x.shape[0](%d) cannot be evenly divided by nranks(%d)'
% (x.shape[0], nranks))
......
......@@ -29,6 +29,7 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_allreduce)
LIST(REMOVE_ITEM TEST_OPS test_broadcast)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
endif()
if(WIN32)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
class TestCollectiveReduceScatter(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks)
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveReduceScatter, "reducescatter", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from test_collective_base import TestDistBase
class TestReduceScatterAPI(TestDistBase):
def _setup_config(self):
pass
def test_reducescatter(self, col_type="reduce_scatter"):
self.check_with_place("collective_reducescatter.py", col_type)
def test_reducescatter_with_error(self):
nranks = 2
tindata = fluid.data(name="tindata", shape=[5, 1000], dtype='float32')
try:
toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks)
except ValueError:
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册