NCCLTools.h 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

S
ShenLiang 已提交
17
#ifdef PADDLE_WITH_CUDA
18
#include <cuda_runtime.h>
S
ShenLiang 已提交
19 20 21 22 23
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

24
#include <error.h>
25

26 27
#include <cstdio>
#include <cstdlib>
#include <string>

28
#include "paddle/fluid/distributed/collective/Types.h"
29 30
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
S
ShenLiang 已提交
31 32

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
33
#include "paddle/fluid/platform/cuda_device_guard.h"
S
ShenLiang 已提交
34 35
#endif

36
#include "paddle/fluid/platform/device_context.h"
S
ShenLiang 已提交
37 38 39 40

#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#else
41
#include "paddle/fluid/platform/dynload/nccl.h"
S
ShenLiang 已提交
42 43
#endif

44
#include "paddle/fluid/platform/enforce.h"
R
Ruibiao Chen 已提交
45
#include "paddle/utils/variant.h"
46 47 48 49

namespace paddle {
namespace distributed {

50
// Evaluates the NCCL call `cmd` exactly once. If the result is not
// ncclSuccess, prints the source location and the NCCL error string to stderr
// and terminates the process with EXIT_FAILURE.
//
// The result is held in a deliberately unusual local name so that an argument
// expression mentioning a caller variable called `r` (or similar) is not
// shadowed by the macro's own temporary. `cmd` is parenthesised so that any
// expression can be passed safely. The do/while(0) wrapper makes the macro
// behave like a single statement (safe in unbraced if/else).
#define NCCL_CHECK(cmd)                                          \
  do {                                                           \
    const ncclResult_t nccl_check_result_ = (cmd);               \
    if (nccl_check_result_ != ncclSuccess) {                     \
      fprintf(stderr,                                            \
              "Failed, NCCL error %s:%d '%s'\n",                 \
              __FILE__,                                          \
              __LINE__,                                          \
              platform::dynload::ncclGetErrorString(             \
                  nccl_check_result_));                          \
      exit(EXIT_FAILURE);                                        \
    }                                                            \
  } while (0)

L
lilong12 已提交
62
// Maps a ReduceOp value to the corresponding ncclRedOp_t.
// Only the declaration is visible in this header.
// NOTE(review): how unsupported `reduction` values are handled is decided by
// the out-of-line definition - confirm there.
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
63

L
lilong12 已提交
64 65
// Produces a std::string from the given ncclUniqueId.
// NOTE(review): the encoding of the returned string is determined by the
// out-of-line definition - confirm it before parsing or comparing results.
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);

66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
// Static check for a single tensor used in p2p communication.
//   tensor     - the tensor to validate (read-only).
//   rank       - presumably the calling process's rank; verify against caller.
//   world_size - presumably the number of ranks in the group; verify against
//                caller.
// Returns nothing. NOTE(review): how a failed check is reported is defined in
// the implementation, not in this header.
void StaticCheckTensor(const phi::DenseTensor& tensor,
                       int rank,
                       int world_size);

// Static check for the output/input tensor pair of a collective operation.
//   out_tensor      - output tensor (read-only here).
//   in_tensor       - input tensor (read-only here).
//   rank            - presumably the calling process's rank; verify against
//                     caller.
//   world_size      - presumably the number of ranks in the group; verify
//                     against caller.
//   out_size_factor - NOTE(review): looks like a multiplier applied to the
//   in_size_factor    respective tensor's size when the two are compared -
//                     confirm in the definition.
// Returns nothing; failure reporting is defined in the implementation.
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
                        const phi::DenseTensor& in_tensor,
                        int rank,
                        int world_size,
                        int out_size_factor,
                        int in_size_factor);

// Static check for an output/input tensor pair; judging by the name, intended
// for operations where both tensors are expected to have the same shape
// (TODO confirm in the definition).
// `rank` and `world_size` have the same role as in the other StaticCheck*
// declarations in this header. Returns nothing.
void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 int rank,
                                 int world_size);

// Static check for an output/input tensor pair; judging by the name, intended
// for scatter-like operations (TODO confirm the exact size relation between
// `out_tensor` and `in_tensor` in the definition).
// `rank` and `world_size` have the same role as in the other StaticCheck*
// declarations in this header. Returns nothing.
void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
                                        const phi::DenseTensor& in_tensor,
                                        int rank,
                                        int world_size);

// Static check for an output/input tensor pair; judging by the name, intended
// for gather-like operations (TODO confirm the exact size relation between
// `out_tensor` and `in_tensor` in the definition).
// `rank` and `world_size` have the same role as in the other StaticCheck*
// declarations in this header. Returns nothing.
void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
                                       const phi::DenseTensor& in_tensor,
                                       int rank,
                                       int world_size);
93 94
}  // namespace distributed
}  // namespace paddle