Unverified commit a09b9a3f authored by J jameszhang, committed by GitHub

kunlun add support for c_concat and c_split (#49757)

* kunlun add support for c_concat and c_split

* replace mutable_data() and ShareDataWith()
Parent a99c3cd4
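The two ops are duals along the last tensor axis: c_concat all-gathers each rank's shard and concatenates the shards along the last dimension, while c_split evenly splits the last dimension and keeps the chunk belonging to the calling rank. A minimal NumPy sketch of this per-rank math (illustrative only; the real kernels below run as collectives over a BKCL communicator on Kunlun XPU):

# Illustrative reference semantics, not the Paddle API.
import numpy as np

def c_concat_ref(shards):
    # all_gather each rank's shard, then concatenate along the last axis
    return np.concatenate(shards, axis=-1)

def c_split_ref(x, rank, nranks):
    # evenly split the last axis and keep this rank's chunk
    return np.split(x, nranks, axis=-1)[rank]

# Round trip: concatenating the splits recovers the tensor when the
# last dimension is divisible by nranks.
x = np.arange(24, dtype=np.float32).reshape(2, 12)
shards = [c_split_ref(x, r, nranks=3) for r in range(3)]
assert np.array_equal(c_concat_ref(shards), x)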
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs//concat_and_split_functor.h"
#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CConcatOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
BKCLDataType dtype =
platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype()));
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_concat must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_concat must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_concat must be "
"less than that of nranks (%d).",
rank,
nranks));
#if defined(PADDLE_WITH_XPU_BKCL)
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
phi::DenseTensor temp_out;
framework::DDim temp_out_dims = x->dims();
temp_out_dims[0] *= nranks;
temp_out.Resize(temp_out_dims);
dev_ctx.template Alloc(&temp_out, x->dtype());
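// Gather buffer: after the all_gather below it holds the nranks per-rank
// tensors stacked along dim 0.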
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*x);
out_tensor.push_back(temp_out);
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
auto comm = dev_ctx.bkcl_context();
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
auto stream = dev_ctx.x_context()->xpu_stream;
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_gather(
comm, send_buff, send_numel, recv_buff, dtype, stream));
}
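// temp_out now stacks every rank's tensor along dim 0; slice the per-rank
// blocks back out and concatenate them along the last axis to form Out.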
std::vector<phi::DenseTensor> inputs;
int axis = x->dims().size() - 1;
auto out_dims = x->dims();
out_dims[out_dims.size() - 1] *= nranks;
int rows_per_tensor = x->dims()[0];
int offset = 0;
for (int i = 0; i < nranks; i++) {
phi::DenseTensor temp = temp_out.Slice(offset, offset + rows_per_tensor);
inputs.emplace_back(temp);
offset += rows_per_tensor;
}
phi::funcs::ConcatFunctor<phi::XPUContext, T> functor;
out->Resize(out_dims);
dev_ctx.template Alloc(out, x->dtype());
functor(dev_ctx, inputs, axis, out);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_concat,
ops::CConcatOpXPUKernel<float>,
ops::CConcatOpXPUKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#if defined(PADDLE_WITH_XPU)
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CSplitOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using XPUType = typename XPUTypeTrait<T>::Type;
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank,
nranks));
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
auto dims = x->dims();
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
// remain dim
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
dims[dims_size - 1] /= nranks;
out->Resize(dims);
dev_ctx.template Alloc(out, x->dtype());
std::vector<XPUType*> output_list(nranks, nullptr);
output_list.at(rank) = reinterpret_cast<XPUType*>(out->data<T>());
std::vector<int64_t> split_list(nranks, dims[dims_size - 1]);
int axis = 1;
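// View x as a [remain_numel, end_size] matrix and split its columns into
// nranks equal chunks; only this rank's slot gets a destination pointer.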
auto ret = xpu::split(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x->data<T>()),
output_list,
{remain_numel, end_size},
split_list,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "split");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_split,
ops::CSplitOpXPUKernel<float>,
ops::CSplitOpXPUKernel<int>,
ops::CSplitOpXPUKernel<plat::float16>);
@@ -82,6 +82,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32})},
{"c_concat",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_identity",
XPUKernelSet({phi::DataType::FLOAT16,
@@ -89,6 +91,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
{"c_split",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32})},
{"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})},
{"cast",
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
paddle.enable_static()
class TestCollectiveConcat(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        nranks = 2
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32'
            )
            toutdata = main_prog.current_block().create_var(
                name="outofconcat",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False,
            )
            main_prog.global_block().append_op(
                type="c_concat",
                inputs={'X': tindata},
                attrs={'ring_id': ring_id, 'rank': self.rank, 'nranks': nranks},
                outputs={'Out': toutdata},
            )
        return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveConcat, "c_concat", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_collective_base_xpu import TestCollectiveRunnerBase, runtime_main
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
paddle.enable_static()
class TestCollectiveAllGather(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        nranks = 2
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32'
            )
            toutdata = main_prog.current_block().create_var(
                name="outofsplit",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False,
            )
            main_prog.global_block().append_op(
                type="c_split",
                inputs={'X': tindata},
                attrs={'ring_id': ring_id, 'rank': self.rank, 'nranks': nranks},
                outputs={'Out': toutdata},
            )
        return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveAllGather, "c_split", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_collective_base_xpu import TestDistBase
import paddle
from paddle.fluid import core
sys.path.append("..")
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
class XPUTestCConcatOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'c_concat'
        self.use_dynamic_create_class = False

    class TestConcatOp(TestDistBase):
        def _setup_config(self):
            pass

        def test_concat(self, col_type="c_concat"):
            self.check_with_place(
                "collective_concat_op.py", col_type, self.in_type_str
            )


support_types = get_xpu_op_support_types('c_concat')
for stype in support_types:
    create_test_class(
        globals(),
        XPUTestCConcatOp,
        stype,
        ignore_device_version=[core.XPUVersion.XPU1],
    )

if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_collective_base_xpu import TestDistBase
import paddle
from paddle.fluid import core
sys.path.append("..")
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
class XPUTestCSplitOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'c_split'
        self.use_dynamic_create_class = False

    class TestSplitOp(TestDistBase):
        def _setup_config(self):
            pass

        def test_split(self, col_type="c_split"):
            self.check_with_place(
                "collective_split_op.py", col_type, self.in_type_str
            )


support_types = get_xpu_op_support_types('c_split')
for stype in support_types:
    create_test_class(
        globals(),
        XPUTestCSplitOp,
        stype,
        ignore_device_version=[core.XPUVersion.XPU1],
    )

if __name__ == '__main__':
    unittest.main()