Unverified commit 788be26d authored by lil-Xing, committed by GitHub

add phi operator c_concat and ut (#55320)

* add phi operator c_concat and ut

* update create_var use

* update copyright
Parent 4b6d2f5f
......@@ -156,6 +156,16 @@
    optional : bias
  backward : depthwise_conv2d_transpose_grad

- op : dist_concat
  args : (Tensor x, int ring_id = 0, int nranks = 1)
  output : Tensor(out)
  infer_meta :
    func : DistConcatInferMeta
    param: [x, nranks]
  kernel :
    func : dist_concat
    param: [x, nranks]

- op : einsum
  args : (Tensor[] x, str equation)
  output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()}
......
......@@ -844,6 +844,14 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) {
  out->set_dtype(alpha.dtype());
}

void DistConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
  auto dim = x.dims();
  dim[dim.size() - 1] = dim[dim.size() - 1] * nranks;
  if (dim[dim.size() - 1] < 0) dim[dim.size() - 1] = -1;
  out->set_dtype(x.dtype());
  out->set_dims(dim);
}

void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->set_dtype(x.dtype());
  out->set_dims(x.dims());
......
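A minimal standalone sketch of the shape rule that DistConcatInferMeta applies (plain Python for illustration; the helper name is made up and not part of the commit): the last dimension is scaled by nranks, and a dynamic (-1) last dimension stays dynamic.

def dist_concat_out_shape(x_shape, nranks):
    # Mirror DistConcatInferMeta: scale the last dim by nranks,
    # keeping an unknown (-1) last dim unknown.
    out_shape = list(x_shape)
    out_shape[-1] = out_shape[-1] * nranks
    if out_shape[-1] < 0:
        out_shape[-1] = -1
    return out_shape

assert dist_concat_out_shape([10, 1000], 2) == [10, 2000]  # shapes used by the unit test below
assert dist_concat_out_shape([10, -1], 2) == [10, -1]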
......@@ -147,6 +147,8 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out);
void DistBroadcastInferMeta(const MetaTensor& x, MetaTensor* out);
void DistConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out);
void EmbeddingGradSparseInferMeta(const MetaTensor& x, MetaTensor* out);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void DistConcatKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      int nranks,
                      DenseTensor* out);

}  // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_concat_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void DistConcatKernel(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DenseTensor temp_out;
auto temp_out_dims = x.dims();
temp_out_dims[0] *= nranks;
temp_out.Resize(temp_out_dims);
dev_ctx.template Alloc<T>(&temp_out);
auto comm_ctx =
static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_EQ(
nranks,
comm_ctx->GetSize(),
errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
gpuStream_t stream = dev_ctx.stream();
comm_ctx->AllGather(&temp_out, x, stream);
std::vector<DenseTensor> inputs;
int axis = x.dims().size() - 1;
auto out_dims = x.dims();
out_dims[out_dims.size() - 1] *= nranks;
int rows_per_tensor = x.dims()[0];
int offset = 0;
for (int i = 0; i < nranks; i++) {
DenseTensor temp =
temp_out.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + rows_per_tensor));
inputs.emplace_back(temp);
offset += rows_per_tensor;
}
phi::funcs::ConcatFunctor<Context, T> functor;
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
functor(dev_ctx, inputs, axis, out);
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."));
#endif
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(dist_concat,
                   GPU,
                   ALL_LAYOUT,
                   phi::DistConcatKernel,
                   float,
                   double,
                   int,
                   uint8_t,
                   int8_t,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(dist_concat,
                   GPU,
                   ALL_LAYOUT,
                   phi::DistConcatKernel,
                   float,
                   double,
                   int,
                   uint8_t,
                   int8_t,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
#endif
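For intuition, a rough NumPy model of what the kernel above computes, assuming the NCCL AllGather places rank i's rows at block i of temp_out, as the kernel's per-rank slicing relies on (the function name and the two-rank setup are illustrative, not part of the commit):

import numpy as np

def dist_concat_reference(rank_inputs):
    # Model DistConcatKernel: all-gather along dim 0, then slice each
    # rank's block back out and concatenate along the last axis.
    nranks = len(rank_inputs)
    rows = rank_inputs[0].shape[0]
    temp_out = np.concatenate(rank_inputs, axis=0)  # [nranks * rows, cols]
    blocks = [temp_out[i * rows:(i + 1) * rows] for i in range(nranks)]
    return np.concatenate(blocks, axis=-1)          # [rows, nranks * cols]

x0 = np.arange(6).reshape(2, 3)      # rank 0's shard
x1 = np.arange(6, 12).reshape(2, 3)  # rank 1's shard
out = dist_concat_reference([x0, x1])
assert out.shape == (2, 6)           # last dim grows by nranks, matching the infer_meta rule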
......@@ -323,6 +323,11 @@ def is_available():
def _init_parallel_env(backend):
    master_endpoint = os.getenv("PADDLE_MASTER", None)
    if master_endpoint is None:
        master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
    assert (
        master_endpoint is not None
    ), "Please set PADDLE_MASTER environment variable."
    if master_endpoint:
        master_addr = master_endpoint.split(":")[0]
        master_port = int(master_endpoint.split(":")[1])
......
......@@ -151,6 +151,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  set_tests_properties(test_collective_broadcast_object_list_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_concat_api MODULES test_collective_concat_api ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_collective_concat_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_cpu_barrier_with_gloo MODULES
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from legacy_test.test_collective_api_base import (
    TestCollectiveAPIRunnerBase,
    runtime_main,
)

import paddle
from paddle import fluid, framework
from paddle.fluid import data_feeder

paddle.enable_static()


def concat_new(tensor, group=None):
    op_type = 'dist_concat'
    data_feeder.check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
            'int8',
            'uint8',
            'bool',
            'uint16',
        ],
        op_type,
    )

    helper = framework.LayerHelper(op_type, **locals())
    ring_id = 0 if group is None else group.id
    nranks = 2
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    helper.append_op(
        type=op_type,
        inputs={'x': [tensor]},
        outputs={'out': [out]},
        attrs={
            'ring_id': ring_id,
            'nranks': nranks,
        },
    )
    return out


class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        pass

    def get_model_new(
        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
    ):
        with fluid.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            toutdata = concat_new(tindata)
            return [toutdata]


if __name__ == "__main__":
    runtime_main(TestCollectiveConcatAPI, "concat")
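For context, each of the two trainers launched by the test harness feeds a [10, 1000] array under the name "tindata" and fetches the dist_concat output; with nranks hard-coded to 2 in concat_new, the result on every rank should be the two shards joined along the last axis. A small NumPy sketch of that expectation (array contents are illustrative; the launch and executor plumbing live in TestCollectiveAPIRunnerBase):

import numpy as np

rank0_in = np.random.random((10, 1000)).astype("float32")  # rank 0's "tindata"
rank1_in = np.random.random((10, 1000)).astype("float32")  # rank 1's "tindata"

# Expected fetch on both ranks: rank 0's columns followed by rank 1's columns.
expected = np.concatenate((rank0_in, rank1_in), axis=1)
assert expected.shape == (10, 2000)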
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from legacy_test.test_collective_api_base import TestDistBase

import paddle

paddle.enable_static()


class TestCollectiveConcatAPI(TestDistBase):
    def _setup_config(self):
        pass

    def test_concat_with_comm_context(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
            "int8",
            "uint8",
            "bool",
        ]
        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_concat_api.py",
                "dist_concat",
                "nccl",
                dtype=dtype,
                need_envs={"USE_COMM_CONTEXT": "1"},
            )


if __name__ == '__main__':
    unittest.main()
......@@ -494,6 +494,15 @@ class TestDistBase(unittest.TestCase):
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
        elif col_type == "dist_concat":
            result_data = tr0_out[0]
            need_result = np.concatenate((input1, input2), axis=1)
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
        elif col_type == "alltoall":
            need_result1 = np.vstack(
                (
......
......@@ -528,6 +528,7 @@ HIGH_PARALLEL_JOB_NEW = [
    'test_collective_reduce_api',
    'test_multiprocess_dataloader_exception',
    'test_collective_allgather_api',
    'test_collective_concat_api',
    'test_dist_fleet_ps10',
    'test_dist_sparse_tensor_load_rmsprop',
    'test_collective_split_embedding_none_divisible',
......