Unverified commit 2b108a04 authored by lilong12, committed by GitHub

add c_concat and c_split ops (#32486)

* add c_concat op
Parent b6f8ccd2
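Before the file-by-file contents below, here is a minimal NumPy sketch of the semantics the two new ops implement (names and shapes are illustrative, not part of the commit):

import numpy as np

# c_concat: every rank holds a [rows, cols] shard; the op gathers all shards
# and concatenates them along the LAST axis, so each rank ends up with the
# full [rows, cols * nranks] tensor.
def c_concat_ref(shards):
    return np.concatenate(shards, axis=-1)

# c_split is the inverse view: split the last axis evenly into nranks shards
# and keep only this rank's shard.
def c_split_ref(x, rank, nranks):
    return np.split(x, nranks, axis=-1)[rank]

shards = [np.random.rand(10, 1000).astype('float32') for _ in range(2)]
full = c_concat_ref(shards)                      # shape (10, 2000)
assert np.allclose(c_split_ref(full, 1, 2), shards[1])

In model parallelism, c_concat reassembles a tensor whose last dimension is sharded across trainers, and c_split is its inverse.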
paddle/fluid/operators/collective/c_concat_op.cc:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_concat_op.h"
namespace paddle {
namespace operators {
class CConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_concat");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_concat");
int nranks = ctx->Attrs().Get<int>("nranks");
int rank = ctx->Attrs().Get<int>("rank");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
"The number of ranks (%d) for c_concat "
"must be greater than 1.",
nranks));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_concat must be non-negative.", ring_id));
PADDLE_ENFORCE_GE(
rank, 0, platform::errors::InvalidArgument(
"The rank (%d) for c_concat must be non-negative.", rank));
PADDLE_ENFORCE_LT(rank, nranks,
                  platform::errors::InvalidArgument(
                      "The value of rank (%d) for c_concat must "
                      "be less than that of nranks (%d).",
                      rank, nranks));
framework::DDim dim = ctx->GetInputDim("X");
dim[dim.size() - 1] = dim[dim.size() - 1] * nranks;
if (dim[dim.size() - 1] < 0) dim[dim.size() - 1] = -1;
ctx->SetOutputDim("Out", dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
template <typename T>
class CConcatOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_split");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
class CConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be concated.");
AddOutput("Out", "(Tensor) the result of concat.");
AddAttr<int>("rank", "(int default 0) rank id.").SetDefault(0);
AddAttr<int>("nranks", "(int default 1) number of ranks.").SetDefault(1);
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddAttr<bool>(
    "use_calc_stream",
    "(bool default true) dispatch CUDA operations to the calculation stream.")
    .SetDefault(true);
AddAttr<bool>("use_model_parallel",
"(bool default true) use this op with model parallel.")
.SetDefault(true);
AddComment(R"DOC(
CConcat Operator
AllGather the tensors on different trainers and concat them along the last dimension.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_concat, ops::CConcatOp,
ops::CConcatOpGradMaker<paddle::framework::OpDesc>,
ops::CConcatOpGradMaker<paddle::imperative::OpBase>,
ops::CConcatOpMaker);
REGISTER_OP_CPU_KERNEL(c_concat, ops::CConcatOpCPUKernel<float>,
ops::CConcatOpCPUKernel<double>,
ops::CConcatOpCPUKernel<int>,
ops::CConcatOpCPUKernel<int64_t>,
ops::CConcatOpCPUKernel<plat::float16>);
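The InferShape rule above scales only the last dimension and relies on the sign of the product to propagate a dynamic dimension; a hedged Python sketch of the same arithmetic (the function name is illustrative):

def c_concat_infer_shape(x_shape, nranks):
    # Mirrors CConcatOp::InferShape: the last dim scales by nranks, and a
    # dynamic (-1) last dim stays dynamic because -1 * nranks < 0.
    out = list(x_shape)
    out[-1] *= nranks
    if out[-1] < 0:
        out[-1] = -1
    return out

assert c_concat_infer_shape([10, 1000], 2) == [10, 2000]
assert c_concat_infer_shape([10, -1], 2) == [10, -1]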
paddle/fluid/operators/collective/c_concat_op.cu.cc:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CConcatOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
PADDLE_ENFORCE_GE(rank, 0,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_concat must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks, 2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_concat must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank, nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_concat must be "
"less than that of nranks (%d).",
rank, nranks));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
    nranks, comm->nranks(),
    platform::errors::InvalidArgument(
        "The value of nranks (%d) should equal the nranks of the "
        "communicator (%d).",
        nranks, comm->nranks()));
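// Gather every rank's shard into one temporary buffer. ncclAllGather
// concatenates the per-rank tensors along dimension 0 of temp_out.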
framework::Tensor temp_out;
framework::DDim temp_out_dims = x->dims();
temp_out_dims[0] *= nranks;
temp_out.mutable_data<T>(temp_out_dims, place);
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
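// The gathered buffer is stacked along dim 0, but the op's contract is
// concatenation along the LAST axis, so slice it back into per-rank
// tensors and re-concatenate them below.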
std::vector<framework::Tensor> inputs;
int axis = x->dims().size() - 1;
auto out_dims = x->dims();
out_dims[out_dims.size() - 1] *= nranks;
int rows_per_tensor = x->dims()[0];
int offset = 0;
for (int i = 0; i < nranks; i++) {
framework::Tensor temp = temp_out.Slice(offset, offset + rows_per_tensor);
inputs.emplace_back(temp);
offset += rows_per_tensor;
}
math::ConcatFunctor<platform::CUDADeviceContext, T> functor;
out->mutable_data<T>(out_dims, place);
auto& dev_ctx2 = ctx.template device_context<platform::CUDADeviceContext>();
functor(dev_ctx2, inputs, axis, out);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_concat, ops::CConcatOpCUDAKernel<float>,
ops::CConcatOpCUDAKernel<double>,
ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>,
ops::CConcatOpCUDAKernel<plat::float16>);
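The slice-and-concat pass in the kernel above exists because ncclAllGather stacks the shards along the first axis rather than the last. A NumPy sketch of the reordering it performs (illustrative shapes):

import numpy as np

nranks, rows, cols = 2, 10, 1000
shards = [np.random.rand(rows, cols).astype('float32') for _ in range(nranks)]

# What ncclAllGather produces: shards stacked along dim 0 -> (nranks*rows, cols).
gathered = np.concatenate(shards, axis=0)

# What the kernel then does: slice per rank and concatenate along the last axis.
slices = [gathered[i * rows:(i + 1) * rows] for i in range(nranks)]
reordered = np.concatenate(slices, axis=-1)      # (rows, nranks*cols)

assert np.allclose(reordered, np.concatenate(shards, axis=-1))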
paddle/fluid/operators/collective/c_concat_op.h:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class CConcatOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
    "c_concat is not supported on CPU for now."));
}
};
} // namespace operators
} // namespace paddle
paddle/fluid/operators/collective/c_split_op.cc:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_split_op.h"
namespace paddle {
namespace operators {
class CSplitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_split");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_split");
int nranks = ctx->Attrs().Get<int>("nranks");
int rank = ctx->Attrs().Get<int>("rank");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
"The number of ranks (%d) for c_split "
"must be greater than 1.",
nranks));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_split must be non-negative.", ring_id));
PADDLE_ENFORCE_GE(
rank, 0, platform::errors::InvalidArgument(
"The rank (%d) for c_split must be non-negative.", rank));
PADDLE_ENFORCE_LT(rank, nranks,
                  platform::errors::InvalidArgument(
                      "The value of rank (%d) for c_split must "
                      "be less than that of nranks (%d).",
                      rank, nranks));
framework::DDim dim = ctx->GetInputDim("X");
// Keep a dynamic (-1) last dimension dynamic instead of dividing it,
// since -1 / nranks would silently become 0.
if (dim[dim.size() - 1] > 0) {
  dim[dim.size() - 1] = dim[dim.size() - 1] / nranks;
} else {
  dim[dim.size() - 1] = -1;
}
ctx->SetOutputDim("Out", dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
template <typename T>
class CSplitOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_allgather");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
class CSplitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be split.");
AddOutput("Out", "(Tensor) the result of split.");
AddAttr<int>("rank", "(int default 0) rank id.").SetDefault(0);
AddAttr<int>("nranks", "(int default 1) number of ranks.").SetDefault(1);
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddAttr<bool>(
    "use_calc_stream",
    "(bool default false) dispatch CUDA operations to the calculation stream.")
    .SetDefault(false);
AddAttr<bool>("use_model_parallel",
              "(bool default true) use this op with model parallel.")
    .SetDefault(true);
AddComment(R"DOC(
CSplit Operator
Split the tensor evenly according to its rank.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_split, ops::CSplitOp,
ops::CSplitOpGradMaker<paddle::framework::OpDesc>,
ops::CSplitOpGradMaker<paddle::imperative::OpBase>,
ops::CSplitOpMaker);
REGISTER_OP_CPU_KERNEL(c_split, ops::CSplitOpCPUKernel<float>,
ops::CSplitOpCPUKernel<double>,
ops::CSplitOpCPUKernel<int>,
ops::CSplitOpCPUKernel<int64_t>,
ops::CSplitOpCPUKernel<plat::float16>);
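Note how the two op definitions pair up for autograd: CConcatOpGradMaker maps the backward of c_concat to c_split, while CSplitOpGradMaker maps the backward of c_split to c_allgather. A quick NumPy check of the identity the first mapping relies on, namely that each rank's input gradient is its own shard of the output gradient (shapes illustrative):

import numpy as np

# Forward: y = concat(x0, x1) along the last axis. Backward: each rank's
# dX is the corresponding slice of dY -- exactly what c_split computes.
x0 = np.random.rand(10, 1000).astype('float32')
x1 = np.random.rand(10, 1000).astype('float32')
y = np.concatenate([x0, x1], axis=-1)
dy = np.ones_like(y)
dx0 = np.split(dy, 2, axis=-1)[0]
assert dx0.shape == x0.shape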
paddle/fluid/operators/collective/c_split_op.cu:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
namespace paddle {
namespace operators {
template <typename T>
class CSplitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
auto place = ctx.GetPlace();
PADDLE_ENFORCE_GE(rank, 0, platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks, 2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank, nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank, nranks));
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto dims = x->dims();
    int axis = dims.size() - 1;
    dims[dims.size() - 1] /= nranks;
    // Hold the per-rank shards by value so their memory is released when
    // this kernel returns.
    std::vector<framework::Tensor> shards(nranks);
    std::vector<const framework::Tensor*> shape_refer;
    std::vector<framework::Tensor*> results;
    for (int i = 0; i < nranks; i++) {
      shards[i].mutable_data<T>(dims, place);
      shape_refer.emplace_back(&shards[i]);
      results.emplace_back(&shards[i]);
    }
    math::SplitFunctor<platform::CUDADeviceContext, T> functor;
    functor(dev_ctx, *x, shape_refer, axis, &results);
    // Keep only this rank's shard as the op output.
    out->mutable_data<T>(dims, place);
    paddle::framework::TensorCopySync(*results[rank], out->place(), out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_split, ops::CSplitOpCUDAKernel<float>,
ops::CSplitOpCUDAKernel<double>,
ops::CSplitOpCUDAKernel<int>,
ops::CSplitOpCUDAKernel<int64_t>,
ops::CSplitOpCUDAKernel<plat::float16>);
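One contract worth spelling out: the kernel splits each rank's local input, so c_split assumes the input tensor is replicated across ranks, as in the unit tests below where every trainer keeps shard `rank` of its own copy. A hedged NumPy sketch of that contract:

import numpy as np

def c_split_local(x_local, rank, nranks):
    # Mirrors CSplitOpCUDAKernel: split the LOCAL tensor along the last
    # axis and keep this rank's shard.
    return np.split(x_local, nranks, axis=-1)[rank]

x = np.random.rand(10, 2000).astype('float32')  # assumed replicated on both ranks
shard0, shard1 = c_split_local(x, 0, 2), c_split_local(x, 1, 2)
assert np.allclose(np.concatenate([shard0, shard1], axis=-1), x)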
paddle/fluid/operators/collective/c_split_op.h:

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class CSplitOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
    "c_split is not supported on CPU for now."));
}
};
} // namespace operators
} // namespace paddle
python/paddle/fluid/tests/unittests/CMakeLists.txt:

@@ -71,6 +71,8 @@ endforeach()
 if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
+    LIST(REMOVE_ITEM TEST_OPS test_c_concat)
+    LIST(REMOVE_ITEM TEST_OPS test_c_split)
     LIST(REMOVE_ITEM TEST_OPS test_allgather)
     LIST(REMOVE_ITEM TEST_OPS test_allreduce)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
@@ -873,6 +875,8 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
     set_tests_properties(test_collective_reduce_api PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_reduce PROPERTIES TIMEOUT 120)
     set_tests_properties(test_allreduce PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_c_concat PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_c_split PROPERTIES TIMEOUT 120)
     set_tests_properties(test_allgather PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_scatter_api PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_barrier_api PROPERTIES TIMEOUT 120)
python/paddle/fluid/tests/unittests/collective_concat_op.py:

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from test_collective_base import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveConcat(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofconcat",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_concat",
inputs={'X': tindata},
attrs={
'ring_id': ring_id,
'rank': self.rank,
'nranks': nranks
},
outputs={'Out': toutdata})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveConcat, "concat", 0)
python/paddle/fluid/tests/unittests/collective_split_op.py:

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from test_collective_base import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveSplit(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofsplit",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_split",
inputs={'X': tindata},
attrs={
'ring_id': ring_id,
'rank': self.rank,
'nranks': nranks
},
outputs={'Out': toutdata})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveSplit, "split", 0)
python/paddle/fluid/tests/unittests/test_c_concat.py:

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from test_collective_base import TestDistBase
paddle.enable_static()
class TestConcatOp(TestDistBase):
def _setup_config(self):
pass
def test_concat(self, col_type="concat"):
self.check_with_place("collective_concat_op.py", col_type)
if __name__ == '__main__':
unittest.main()
python/paddle/fluid/tests/unittests/test_c_split.py:

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from test_collective_base import TestDistBase
paddle.enable_static()
class TestSplitOp(TestDistBase):
def _setup_config(self):
pass
def test_split(self, col_type="split"):
self.check_with_place("collective_split_op.py", col_type)
if __name__ == '__main__':
unittest.main()
python/paddle/fluid/tests/unittests/test_collective_base.py:

@@ -284,5 +284,22 @@ class TestDistBase(unittest.TestCase):
             need_result2 = np.concatenate((tmp20, tmp21), axis=1)
             self.assertTrue(np.allclose(tr0_out, need_result1))
             self.assertTrue(np.allclose(tr1_out, need_result2))
+        elif col_type == "concat":
+            need_result = np.concatenate((input1, input2), axis=1)
+            self.assertTrue(
+                np.allclose(
+                    tr0_out, need_result, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out, need_result, rtol=1e-05, atol=1e-05))
+        elif col_type == "split":
+            need_result1 = np.split(input1, 2, axis=1)[0]
+            need_result2 = np.split(input2, 2, axis=1)[1]
+            self.assertTrue(
+                np.allclose(
+                    tr0_out, need_result1, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out, need_result2, rtol=1e-05, atol=1e-05))
         else:
             pass