Unverified commit 786c6e99, authored by LiYuRio, committed by GitHub

Merge reduce type of auto_parallel and phi kernel (#56202)

Parent 94afa31a
@@ -24,6 +24,7 @@ namespace paddle {
 namespace distributed {
 // TODO(shenliang03): To support AVG for reduce
+// TODO(liyurui): remove this reduce op, use phi reduce op instead.
 enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
 struct AllreduceOptions {
...
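The fluid collective layer keeps its own `ReduceOp` for now; the new TODO earmarks it for replacement by the phi enum introduced below. A minimal usage sketch, assuming the header path and that `AllreduceOptions` carries a `reduce_op` member as in the repository (neither is visible in this hunk):

```cpp
#include "paddle/fluid/distributed/collective/types.h"  // path assumed

void ConfigureSum(paddle::distributed::AllreduceOptions* opts) {
  // ReduceOp is still the fluid-side enum; AVG exists here but is not yet
  // supported everywhere (see the shenliang03 TODO above).
  opts->reduce_op = paddle::distributed::ReduceOp::SUM;
}
```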
@@ -15,7 +15,18 @@
 #pragma once
 namespace phi {
-namespace distributed {
-enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
-}
+enum class ReduceType {
+  kRedSum,
+  kRedMax,
+  kRedMin,
+  kRedProd,
+  kRedAvg,
+  kRedAny,
+  kRedAll
+};
+constexpr const char* ReduceTypeStrings[] = {
+    "SUM", "MAX", "MIN", "PRODUCT", "AVG", "ANY", "ALL"};
 }  // namespace phi
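Since the merged enum and `ReduceTypeStrings` are declared in the same order, a string lookup is a plain index cast. A minimal sketch; `ReduceTypeToString` is an illustrative helper, not part of the commit:

```cpp
#include <iostream>

#include "paddle/phi/common/reduce_type.h"

// Illustrative helper: relies on ReduceTypeStrings being declared in the
// same order as the enumerators, as in the hunk above.
const char* ReduceTypeToString(phi::ReduceType type) {
  return phi::ReduceTypeStrings[static_cast<int>(type)];
}

int main() {
  std::cout << ReduceTypeToString(phi::ReduceType::kRedProd) << std::endl;
  // prints: PRODUCT
}
```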
@@ -226,7 +226,7 @@ bool TensorDistAttr::verify_partial_status() const {
     if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
       return false;
     }
-    if (itr.second < ReduceType::SUM || itr.second > ReduceType::ALL) {
+    if (itr.second < ReduceType::kRedSum || itr.second > ReduceType::kRedAll) {
       return false;
     }
   }
...
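The check above depends on `kRedSum` and `kRedAll` being the first and last enumerators. A standalone sketch of the same idiom, useful when an untrusted int is cast to the enum; `IsValidReduceType` is illustrative only:

```cpp
#include "paddle/phi/common/reduce_type.h"

// Enum class values keep their declaration order, so the closed interval
// [kRedSum, kRedAll] accepts exactly the seven defined reduce types.
bool IsValidReduceType(int raw) {
  auto type = static_cast<phi::ReduceType>(raw);
  return type >= phi::ReduceType::kRedSum && type <= phi::ReduceType::kRedAll;
}
```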
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "paddle/phi/common/reduce_type.h"
 #include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
 #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
 #include "paddle/phi/core/distributed/auto_parallel/utils.h"
@@ -33,18 +34,6 @@ namespace auto_parallel {
 constexpr const char* kDefault = "default";
-enum class ReduceType : std::uint8_t {
-  SUM = 0,
-  AVG,
-  MAX,
-  MIN,
-  PRODUCT,
-  ANY,
-  ALL
-};
-constexpr const char* ReduceTypeStrings[] = {
-    "SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"};
 class TensorDistAttr {
  public:
  TensorDistAttr() = default;
@@ -81,7 +70,7 @@ class TensorDistAttr {
   // by each dim
   void set_partial_status(const std::vector<int64_t>& dims,
-                          const ReduceType& type = ReduceType::SUM);
+                          const ReduceType& type = ReduceType::kRedSum);
   // all
   void clean_partial_status();
...
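With the auto_parallel enum gone, the default argument now names the merged `phi::ReduceType`. A usage sketch based only on the declaration above; the include path and enclosing namespace are assumed from the hunk headers:

```cpp
#include <cstdint>
#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"  // path assumed

namespace ap = phi::distributed::auto_parallel;  // namespace assumed

void MarkPartialOnDims(ap::TensorDistAttr* attr) {
  // Marks mesh dims 0 and 1 as partial; omitting the second argument is
  // equivalent to passing phi::ReduceType::kRedSum.
  attr->set_partial_status(std::vector<int64_t>{0, 1});
}
```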
@@ -37,10 +37,12 @@ class GlooCommContext final : public CommContext {
                  const phi::DenseTensor& in_tensor,
                  int root,
                  uint32_t tag = 0);
  void AllReduce(phi::DenseTensor* out_tensor,
                 const phi::DenseTensor& in_tensor,
                 int reduce_type,
                 uint32_t tag = 0);
  void Reduce(phi::DenseTensor* out_tensor,
              const phi::DenseTensor& in_tensor,
              int reduce_type,
...
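Both methods take the reduce type as a plain `int`, so callers cast the enum at the boundary. A minimal caller sketch; the header path is assumed and `ctx` stands for an already-initialized context:

```cpp
#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"  // path assumed

void AllReduceSum(phi::distributed::GlooCommContext* ctx,
                  phi::DenseTensor* out,
                  const phi::DenseTensor& in) {
  // The int parameter decouples the comm-context interface from the enum
  // header; SetReduceFunc casts it back to ReduceType internally.
  ctx->AllReduce(out, in, static_cast<int>(phi::ReduceType::kRedSum));
}
```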
@@ -26,8 +26,8 @@
 #include "glog/logging.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/reduce_type.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/distributed/reduce_helper.h"
 namespace phi {
 namespace distributed {
@@ -120,30 +120,31 @@ void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
 template <typename T, typename P>
 void SetReduceFunc(P* opts, int reduce_type) {
   // gloo only supports mutable data input
-  switch (reduce_type) {
-    case kRedSum:
+  ReduceType reduce_type_enum = static_cast<ReduceType>(reduce_type);
+  switch (reduce_type_enum) {
+    case ReduceType::kRedSum:
       opts->setReduceFunction(
           static_cast<void (*)(void*, const void*, const void*, size_t)>(
               &gloo::sum<T>));
       break;
-    case kRedMax:
+    case ReduceType::kRedMax:
       opts->setReduceFunction(
           static_cast<void (*)(void*, const void*, const void*, size_t)>(
               &gloo::max<T>));
       break;
-    case kRedMin:
+    case ReduceType::kRedMin:
       opts->setReduceFunction(
           static_cast<void (*)(void*, const void*, const void*, size_t)>(
               &gloo::min<T>));
       break;
-    case kRedProd:
+    case ReduceType::kRedProd:
       opts->setReduceFunction(
           static_cast<void (*)(void*, const void*, const void*, size_t)>(
               &gloo::product<T>));
       break;
     default:
       PADDLE_THROW(
-          errors::InvalidArgument("Invalid reduce type: %d.", reduce_type));
+          errors::InvalidArgument("Unsupported reduce type: %d.", reduce_type));
   }
 }
...
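For context, a sketch of how the template might be instantiated for a float all-reduce; in the real code path the options object comes from the live gloo context rather than being passed in like this:

```cpp
#include <gloo/allreduce.h>

#include "paddle/phi/common/reduce_type.h"

void ConfigureFloatSum(gloo::AllreduceOptions* opts) {
  // T selects gloo::sum<float>; the int is decoded back to ReduceType
  // inside SetReduceFunc, as shown in the hunk above.
  phi::distributed::SetReduceFunc<float>(
      opts, static_cast<int>(phi::ReduceType::kRedSum));
}
```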
@@ -14,8 +14,8 @@
 #pragma once
+#include "paddle/phi/common/reduce_type.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/distributed/reduce_helper.h"
 namespace phi {
...
@@ -44,17 +44,17 @@ void AllReduceKernel(const Context& dev_ctx,
       errors::NotFound("Should initialize NCCL firstly."));
   ncclRedOp_t red_type = ncclSum;
-  switch (reduce_type) {
-    case distributed::kRedSum:
+  switch (static_cast<ReduceType>(reduce_type)) {
+    case ReduceType::kRedSum:
       red_type = ncclSum;
       break;
-    case distributed::kRedMax:
+    case ReduceType::kRedMax:
       red_type = ncclMax;
       break;
-    case distributed::kRedMin:
+    case ReduceType::kRedMin:
       red_type = ncclMin;
       break;
-    case distributed::kRedProd:
+    case ReduceType::kRedProd:
       red_type = ncclProd;
       break;
   }
...
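This switch reappears verbatim in `ReduceKernel` below, so the shared mapping could be factored into one helper. A hypothetical refactor sketch; `ToNcclRedType` is not part of the commit, and the `ncclSum` fallback mirrors the kernels' initializer:

```cpp
#include <nccl.h>

#include "paddle/phi/common/reduce_type.h"

inline ncclRedOp_t ToNcclRedType(phi::ReduceType type) {
  switch (type) {
    case phi::ReduceType::kRedSum:
      return ncclSum;
    case phi::ReduceType::kRedMax:
      return ncclMax;
    case phi::ReduceType::kRedMin:
      return ncclMin;
    case phi::ReduceType::kRedProd:
      return ncclProd;
    default:
      return ncclSum;  // same fallback as `ncclRedOp_t red_type = ncclSum;`
  }
}
```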
@@ -49,17 +49,17 @@ void ReduceKernel(const Context& dev_ctx,
       errors::NotFound("Should initialize NCCL firstly."));
   ncclRedOp_t red_type = ncclSum;
-  switch (reduce_type) {
-    case distributed::kRedSum:
+  switch (static_cast<ReduceType>(reduce_type)) {
+    case ReduceType::kRedSum:
       red_type = ncclSum;
       break;
-    case distributed::kRedMax:
+    case ReduceType::kRedMax:
       red_type = ncclMax;
       break;
-    case distributed::kRedMin:
+    case ReduceType::kRedMin:
       red_type = ncclMin;
       break;
-    case distributed::kRedProd:
+    case ReduceType::kRedProd:
       red_type = ncclProd;
       break;
   }
...
@@ -14,8 +14,8 @@
 #pragma once
+#include "paddle/phi/common/reduce_type.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/distributed/reduce_helper.h"
 namespace phi {
...
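Taken together, the hunks collapse two parallel enums into one definition in `paddle/phi/common/reduce_type.h`. A before/after sketch of the spelling at call sites, using only names that appear in this diff:

```cpp
#include "paddle/phi/common/reduce_type.h"

// Before: phi kernels matched on the unscoped distributed::kRedSum, while
// auto_parallel used its own ReduceType::SUM. After: one scoped enum that
// crosses kernel boundaries as an int.
void Example() {
  int wire_value = static_cast<int>(phi::ReduceType::kRedSum);  // encode
  auto decoded = static_cast<phi::ReduceType>(wire_value);      // decode
  (void)decoded;
}
```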