From 43d6bdca6851c258ce1e44e44da1aca455a39580 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 25 May 2023 10:09:40 +0800 Subject: [PATCH] Fix the custom pass with empty type (#54065) --- paddle/fluid/framework/ir/generate_pass.cc | 8 ++++-- paddle/fluid/framework/ir/generate_pass.h | 6 +++-- paddle/fluid/framework/ir/pass.h | 5 ++-- .../custom_device_common_op_registry.cc | 25 +++++++++++++++++-- paddle/fluid/pybind/pybind.cc | 4 +-- 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 61c6ce5757a..0088312c7b9 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite( return handler; } -GeneratePass::GeneratePass(const std::string& binary_str) { +GeneratePass::GeneratePass(const std::string& binary_str, + const std::string& pass_type) { + RegisterType(pass_type); multi_pass_desc_.ParseFromString(binary_str); VerifyDesc(); } -GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) +GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type) : multi_pass_desc_(multi_pass_desc) { + RegisterType(pass_type); VerifyDesc(); } diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 192c963cfdd..3a9d0f1efa7 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -24,9 +24,11 @@ namespace ir { class GeneratePass : public Pass { public: // from binary_str - explicit GeneratePass(const std::string& binary_str); + explicit GeneratePass(const std::string& binary_str, + const std::string& pass_type = ""); // from PassDesc/MultiPassDesc - explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); + explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type = ""); protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index c7f76cb0bfd..1f59466e1cd 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -185,6 +185,9 @@ class Pass { // Pass must be placed after this Pass. virtual void CheckPrevPass() const {} + protected: + void RegisterType(const std::string &type) { type_ = type; } + private: template friend struct PassRegistrar; @@ -207,8 +210,6 @@ class Pass { attrs_.insert(default_attr_values.begin(), default_attr_values.end()); } - void RegisterType(const std::string &type) { type_ = type; } - mutable bool applied_{false}; std::string type_; std::unordered_set required_pass_attrs_; diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index a75106231e0..d5ae2f84b4e 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { int nranks = ctx.Attr("nranks"); int rank = ctx.Attr("rank"); int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_GE(rank, 0, platform::errors::PreconditionNotMet( @@ -98,8 +100,27 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { auto task = pg->AllGather(in_tensor, out_tensor); task->Wait(); } else { - PADDLE_THROW(phi::errors::Unavailable( - "CustomDevice c_concat only support ProcessGroup")); + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + PADDLE_ENFORCE_EQ( + nranks, + comm->nranks(), + platform::errors::InvalidArgument( + "nranks: %s should equal to %s", nranks, comm->nranks())); + + int64_t send_numel = x->numel(); + const T* send_buff = x->data(); + T* recv_buff = temp_out.data(); + // should ExecutionContext for calc stream. + auto& stream = *dev_ctx.GetStream(); + phi::DeviceManager::CCLAllGather( + place.GetDeviceType(), + reinterpret_cast(const_cast(send_buff)), + recv_buff, + send_numel, + phi::ccl::ToCCLDataType(x->dtype()), + comm->comm(), + stream); } std::vector inputs; int axis = x->dims().size() - 1; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 59052a40cbc..7e09266271c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle. pass_type, [pass_type, callable]() { py::gil_scoped_acquire guard; std::unique_ptr pass( - new framework::ir::GeneratePass( - py::cast(callable()))); + new framework::ir::GeneratePass(py::cast(callable()), + pass_type)); return pass; }); }); -- GitLab