Unverified commit 786c6e99 authored by LiYuRio, committed by GitHub

Merge reduce type of auto_parallel and phi kernel (#56202)

Parent 94afa31a
......@@ -24,6 +24,7 @@ namespace paddle {
namespace distributed {
// TODO(shenliang03): To support AVG for reduce
// TODO(liyurui): remove this reduce op, use phi reduce op instead.
enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
struct AllreduceOptions {
......
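Note: the TODO above says this legacy ReduceOp should eventually give way to the phi reduce type. The commit does not add a converter, but a hypothetical mapping helper (the name ToPhiReduceType and its placement are assumptions) could look like the following sketch.

#include "paddle/phi/common/reduce_type.h"

namespace paddle {
namespace distributed {

// Hypothetical helper (not in this commit): map the legacy ReduceOp values
// onto the unified phi::ReduceType enum. SUM/MAX/MIN/PRODUCT/AVG line up
// one-to-one with kRedSum/kRedMax/kRedMin/kRedProd/kRedAvg.
inline phi::ReduceType ToPhiReduceType(ReduceOp op) {
  switch (op) {
    case ReduceOp::SUM:
      return phi::ReduceType::kRedSum;
    case ReduceOp::MAX:
      return phi::ReduceType::kRedMax;
    case ReduceOp::MIN:
      return phi::ReduceType::kRedMin;
    case ReduceOp::PRODUCT:
      return phi::ReduceType::kRedProd;
    case ReduceOp::AVG:
      return phi::ReduceType::kRedAvg;
  }
  return phi::ReduceType::kRedSum;  // unreachable for valid inputs
}

}  // namespace distributed
}  // namespace paddle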
......@@ -15,7 +15,18 @@
#pragma once
namespace phi {
namespace distributed {
enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
}
enum class ReduceType {
kRedSum,
kRedMax,
kRedMin,
kRedProd,
kRedAvg,
kRedAny,
kRedAll
};
constexpr const char* ReduceTypeStrings[] = {
"SUM", "MAX", "MIN", "PRODUCT", "AVG", "ANY", "ALL"};
} // namespace phi
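The string table above is ordered to match the enum, so the underlying integer value doubles as an index. A minimal, self-contained sketch of that lookup (the surrounding main() is illustrative only):

#include <iostream>

#include "paddle/phi/common/reduce_type.h"

int main() {
  // ReduceTypeStrings is indexed by the enum's underlying value.
  phi::ReduceType type = phi::ReduceType::kRedProd;
  std::cout << phi::ReduceTypeStrings[static_cast<int>(type)] << std::endl;
  // Prints "PRODUCT".
  return 0;
}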
......@@ -226,7 +226,7 @@ bool TensorDistAttr::verify_partial_status() const {
if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
return false;
}
if (itr.second < ReduceType::SUM || itr.second <= ReduceType::ALL) {
if (itr.second < ReduceType::kRedSum || itr.second <= ReduceType::kRedAll) {
return false;
}
}
......
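The updated guard compares against both ends of the new enum, with `<=` on the upper bound. A hedged sketch of the bounds check it presumably intends, factored into a standalone helper (this is an interpretation, not part of the commit):

#include "paddle/phi/common/reduce_type.h"

// Hedged sketch: accept only values inside the [kRedSum, kRedAll] range;
// `>` on the upper bound is assumed here, whereas the committed condition
// uses `<=`. Illustrative only.
inline bool IsValidReduceType(phi::ReduceType type) {
  return type >= phi::ReduceType::kRedSum && type <= phi::ReduceType::kRedAll;
}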
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
......@@ -33,18 +34,6 @@ namespace auto_parallel {
constexpr const char* kDefault = "default";
enum class ReduceType : std::uint8_t {
SUM = 0,
AVG,
MAX,
MIN,
PRODUCT,
ANY,
ALL
};
constexpr const char* ReduceTypeStrings[] = {
"SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"};
class TensorDistAttr {
public:
TensorDistAttr() = default;
......@@ -81,7 +70,7 @@ class TensorDistAttr {
// by each dim
void set_partial_status(const std::vector<int64_t>& dims,
const ReduceType& type = ReduceType::SUM);
const ReduceType& type = ReduceType::kRedSum);
// all
void clean_partial_status();
......
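With the old local enum removed, the default argument of set_partial_status becomes phi::ReduceType::kRedSum. A hedged usage sketch (illustrative names, namespaces elided; assumes the attribute was built with a process mesh in which dims 0 and 1 are valid):

// Hypothetical usage, not taken from the commit.
void MarkPartial(TensorDistAttr* dist_attr) {
  dist_attr->set_partial_status({0});                            // defaults to kRedSum
  dist_attr->set_partial_status({1}, phi::ReduceType::kRedMax);  // explicit reduce type
}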
......@@ -37,10 +37,12 @@ class GlooCommContext final : public CommContext {
const phi::DenseTensor& in_tensor,
int root,
uint32_t tag = 0);
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
uint32_t tag = 0);
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
......
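The new AllReduce overload takes the reduce type as an int. A hedged call-site sketch, assuming comm_ctx is an initialized GlooCommContext and both tensors are already allocated:

// Hypothetical call site (illustrative): the reduce type travels as the
// integer value of phi::ReduceType and is decoded inside gloo_utils.
void SumAllReduce(GlooCommContext* comm_ctx,
                  phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor) {
  comm_ctx->AllReduce(out_tensor,
                      in_tensor,
                      static_cast<int>(phi::ReduceType::kRedSum));
}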
......@@ -26,8 +26,8 @@
#include "glog/logging.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/reduce_helper.h"
namespace phi {
namespace distributed {
......@@ -120,30 +120,31 @@ void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
template <typename T, typename P>
void SetReduceFunc(P* opts, int reduce_type) {
// gloo only support mutable data input
switch (reduce_type) {
case kRedSum:
ReduceType reduce_type_enum = static_cast<ReduceType>(reduce_type);
switch (reduce_type_enum) {
case ReduceType::kRedSum:
opts->setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
break;
case kRedMax:
case ReduceType::kRedMax:
opts->setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::max<T>));
break;
case kRedMin:
case ReduceType::kRedMin:
opts->setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::min<T>));
break;
case kRedProd:
case ReduceType::kRedProd:
opts->setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::product<T>));
break;
default:
PADDLE_THROW(
errors::InvalidArgument("Invalid reduce type: %d.", reduce_type));
errors::InvalidArgument("Unsupport reduce type: %d.", reduce_type));
}
}
......
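SetReduceFunc decodes the int back into the enum class before selecting the matching gloo functor, throwing on anything it cannot map. A hedged usage sketch (illustrative; assumes opts is a gloo::AllreduceOptions built from an initialized gloo context):

// Hypothetical wrapper, not taken from the commit.
template <typename T>
void UseSumReduction(gloo::AllreduceOptions* opts) {
  SetReduceFunc<T>(opts, static_cast<int>(ReduceType::kRedSum));
}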
......@@ -14,8 +14,8 @@
#pragma once
#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/reduce_helper.h"
namespace phi {
......
......@@ -44,17 +44,17 @@ void AllReduceKernel(const Context& dev_ctx,
errors::NotFound("Should initialize NCCL firstly."));
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case distributed::kRedSum:
switch (static_cast<ReduceType>(reduce_type)) {
case ReduceType::kRedSum:
red_type = ncclSum;
break;
case distributed::kRedMax:
case ReduceType::kRedMax:
red_type = ncclMax;
break;
case distributed::kRedMin:
case ReduceType::kRedMin:
red_type = ncclMin;
break;
case distributed::kRedProd:
case ReduceType::kRedProd:
red_type = ncclProd;
break;
}
......
......@@ -49,17 +49,17 @@ void ReduceKernel(const Context& dev_ctx,
errors::NotFound("Should initialize NCCL firstly."));
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case distributed::kRedSum:
switch (static_cast<ReduceType>(reduce_type)) {
case ReduceType::kRedSum:
red_type = ncclSum;
break;
case distributed::kRedMax:
case ReduceType::kRedMax:
red_type = ncclMax;
break;
case distributed::kRedMin:
case ReduceType::kRedMin:
red_type = ncclMin;
break;
case distributed::kRedProd:
case ReduceType::kRedProd:
red_type = ncclProd;
break;
}
......
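Both the all_reduce and reduce NCCL kernels repeat the same ReduceType-to-ncclRedOp_t switch. A hypothetical shared helper (not part of this commit; the name ToNCCLRedType and the plain <nccl.h> include are assumptions) could centralize that mapping:

#include <nccl.h>

#include "paddle/phi/common/reduce_type.h"

namespace phi {

// Hypothetical helper: translate the unified ReduceType enum into the
// corresponding NCCL reduction op. Types without an NCCL counterpart fall
// back to ncclSum, mirroring the kernels above.
inline ncclRedOp_t ToNCCLRedType(int reduce_type) {
  switch (static_cast<ReduceType>(reduce_type)) {
    case ReduceType::kRedSum:
      return ncclSum;
    case ReduceType::kRedMax:
      return ncclMax;
    case ReduceType::kRedMin:
      return ncclMin;
    case ReduceType::kRedProd:
      return ncclProd;
    default:
      return ncclSum;
  }
}

}  // namespace phi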
......@@ -14,8 +14,8 @@
#pragma once
#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/reduce_helper.h"
namespace phi {
......