diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 61c6ce5757aa1f6a6f9f1a780121e92a195e692d..0088312c7b98d1e038d00e10d3a45f34e8b438cf 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite( return handler; } -GeneratePass::GeneratePass(const std::string& binary_str) { +GeneratePass::GeneratePass(const std::string& binary_str, + const std::string& pass_type) { + RegisterType(pass_type); multi_pass_desc_.ParseFromString(binary_str); VerifyDesc(); } -GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) +GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type) : multi_pass_desc_(multi_pass_desc) { + RegisterType(pass_type); VerifyDesc(); } diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 192c963cfddcb38c2bba7e6807f378673a6bba1a..3a9d0f1efa71e2500307ce6a3c550e87532b3103 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -24,9 +24,11 @@ namespace ir { class GeneratePass : public Pass { public: // from binary_str - explicit GeneratePass(const std::string& binary_str); + explicit GeneratePass(const std::string& binary_str, + const std::string& pass_type = ""); // from PassDesc/MultiPassDesc - explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); + explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type = ""); protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index c7f76cb0bfdd5b369acbfd788595d1c91a63d206..1f59466e1cd802e18135d9855ced50d21433e907 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -185,6 +185,9 @@ class Pass { // Pass must be placed after this Pass. virtual void CheckPrevPass() const {} + protected: + void RegisterType(const std::string &type) { type_ = type; } + private: template friend struct PassRegistrar; @@ -207,8 +210,6 @@ class Pass { attrs_.insert(default_attr_values.begin(), default_attr_values.end()); } - void RegisterType(const std::string &type) { type_ = type; } - mutable bool applied_{false}; std::string type_; std::unordered_set required_pass_attrs_; diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index a75106231e00e60328cc40961b426779b557bfba..d5ae2f84b4ed1deda9e10d966ef4ec8e4bb15356 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { int nranks = ctx.Attr("nranks"); int rank = ctx.Attr("rank"); int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_GE(rank, 0, platform::errors::PreconditionNotMet( @@ -98,8 +100,27 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { auto task = pg->AllGather(in_tensor, out_tensor); task->Wait(); } else { - PADDLE_THROW(phi::errors::Unavailable( - "CustomDevice c_concat only support ProcessGroup")); + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + PADDLE_ENFORCE_EQ( + nranks, + comm->nranks(), + platform::errors::InvalidArgument( + "nranks: %s should equal to %s", nranks, comm->nranks())); + + int64_t send_numel = x->numel(); + const T* send_buff = x->data(); + T* recv_buff = temp_out.data(); + // should ExecutionContext for calc stream. + auto& stream = *dev_ctx.GetStream(); + phi::DeviceManager::CCLAllGather( + place.GetDeviceType(), + reinterpret_cast(const_cast(send_buff)), + recv_buff, + send_numel, + phi::ccl::ToCCLDataType(x->dtype()), + comm->comm(), + stream); } std::vector inputs; int axis = x->dims().size() - 1; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 59052a40cbc5a7080a50d3149d5abdad805ed17e..7e09266271ca75c62435851d679e4f1c3ab8066f 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle. pass_type, [pass_type, callable]() { py::gil_scoped_acquire guard; std::unique_ptr pass( - new framework::ir::GeneratePass( - py::cast(callable()))); + new framework::ir::GeneratePass(py::cast(callable()), + pass_type)); return pass; }); });