Unverified commit 43d6bdca, authored by ronnywang, committed by GitHub

Fix the custom pass with empty type (#54065)

Parent: 23baa8c6
@@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
   return handler;
 }
 
-GeneratePass::GeneratePass(const std::string& binary_str) {
+GeneratePass::GeneratePass(const std::string& binary_str,
+                           const std::string& pass_type) {
+  RegisterType(pass_type);
   multi_pass_desc_.ParseFromString(binary_str);
   VerifyDesc();
 }
 
-GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
+GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
+                           const std::string& pass_type)
     : multi_pass_desc_(multi_pass_desc) {
+  RegisterType(pass_type);
   VerifyDesc();
 }
...
@@ -24,9 +24,11 @@ namespace ir {
 class GeneratePass : public Pass {
  public:
   // from binary_str
-  explicit GeneratePass(const std::string& binary_str);
+  explicit GeneratePass(const std::string& binary_str,
+                        const std::string& pass_type = "");
   // from PassDesc/MultiPassDesc
-  explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);
+  explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
+                        const std::string& pass_type = "");
 
  protected:
   void ApplyImpl(Graph* graph) const override;
...
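The `pass_type` parameter added to both constructors defaults to an empty string, so existing one-argument call sites keep compiling unchanged, while callers that know the pass name (such as the pybind registration further down) can now forward it. A minimal standalone sketch of this source compatibility, using an illustrative stand-in class rather than Paddle's real GeneratePass:

#include <iostream>
#include <string>

// Stand-in for GeneratePass: the second parameter is new but defaulted,
// so the old single-argument construction still compiles.
struct GeneratePassLike {
  explicit GeneratePassLike(const std::string& binary_str,
                            const std::string& pass_type = "")
      : type_(pass_type) {}
  std::string type_;
};

int main() {
  GeneratePassLike old_style("proto-bytes");             // type_ == "" (old behavior)
  GeneratePassLike new_style("proto-bytes", "my_pass");  // type_ == "my_pass"
  std::cout << "'" << old_style.type_ << "' vs '" << new_style.type_ << "'\n";
  return 0;
}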
@@ -185,6 +185,9 @@ class Pass {
   // Pass must be placed after this Pass.
   virtual void CheckPrevPass() const {}
 
+ protected:
+  void RegisterType(const std::string &type) { type_ = type; }
+
  private:
   template <typename PassType>
   friend struct PassRegistrar;
@@ -207,8 +210,6 @@ class Pass {
     attrs_.insert(default_attr_values.begin(), default_attr_values.end());
   }
 
-  void RegisterType(const std::string &type) { type_ = type; }
-
   mutable bool applied_{false};
   std::string type_;
   std::unordered_set<std::string> required_pass_attrs_;
...
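Moving RegisterType from the private section into a new protected one is what makes the GeneratePass constructors above legal: previously only the PassRegistrar friend could set type_, so a GeneratePass constructed directly (as the Python-side register_pass path does) kept an empty type. A self-contained sketch of the pattern, with simplified stand-ins rather than Paddle's actual classes:

#include <iostream>
#include <string>

class PassLike {
 public:
  const std::string& Type() const { return type_; }

 protected:
  // Protected rather than private: subclasses may set their own type.
  void RegisterType(const std::string& type) { type_ = type; }

 private:
  std::string type_;
};

class GeneratePassLike : public PassLike {
 public:
  explicit GeneratePassLike(const std::string& pass_type = "") {
    RegisterType(pass_type);  // legal now that the member is protected
  }
};

int main() {
  GeneratePassLike p("my_custom_pass");
  std::cout << "type: " << p.Type() << "\n";  // prints: type: my_custom_pass
  return 0;
}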
@@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
     int nranks = ctx.Attr<int>("nranks");
     int rank = ctx.Attr<int>("rank");
     int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+
     PADDLE_ENFORCE_GE(rank,
                       0,
                       platform::errors::PreconditionNotMet(
@@ -98,8 +100,27 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
       auto task = pg->AllGather(in_tensor, out_tensor);
       task->Wait();
     } else {
-      PADDLE_THROW(phi::errors::Unavailable(
-          "CustomDevice c_concat only support ProcessGroup"));
+      auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
+                      .Get(rid, place);
+      PADDLE_ENFORCE_EQ(
+          nranks,
+          comm->nranks(),
+          platform::errors::InvalidArgument(
+              "nranks: %s should equal to %s", nranks, comm->nranks()));
+      int64_t send_numel = x->numel();
+      const T* send_buff = x->data<T>();
+      T* recv_buff = temp_out.data<T>();
+      // should ExecutionContext for calc stream.
+      auto& stream = *dev_ctx.GetStream();
+      phi::DeviceManager::CCLAllGather(
+          place.GetDeviceType(),
+          reinterpret_cast<void*>(const_cast<T*>(send_buff)),
+          recv_buff,
+          send_numel,
+          phi::ccl::ToCCLDataType(x->dtype()),
+          comm->comm(),
+          stream);
     }
     std::vector<phi::DenseTensor> inputs;
     int axis = x->dims().size() - 1;
...
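With this branch, a custom-device c_concat invoked without a ProcessGroup now falls back to the per-ring XCCL communicator instead of throwing: each rank's shard is all-gathered into temp_out, and the code that follows (the inputs vector and last-axis split) concatenates the gathered shards along the final dimension. A standalone host-side sketch of those semantics, for illustration only (plain C++, not Paddle code):

#include <cstdio>
#include <vector>

int main() {
  const int nranks = 2, rows = 2, cols = 3;

  // Simulated per-rank shards; a real run would hand device buffers to
  // phi::DeviceManager::CCLAllGather.
  std::vector<std::vector<float>> shard(nranks, std::vector<float>(rows * cols));
  for (int r = 0; r < nranks; ++r)
    for (int i = 0; i < rows * cols; ++i) shard[r][i] = r * 100 + i;

  // "AllGather": every rank ends up with all shards back to back
  // (rank-major), matching the layout written into temp_out.
  std::vector<float> gathered;
  for (int r = 0; r < nranks; ++r)
    gathered.insert(gathered.end(), shard[r].begin(), shard[r].end());

  // The final concat interleaves shards along the last axis:
  // out[i][r * cols + j] = shard[r][i][j], giving shape [rows, cols * nranks].
  std::vector<float> out(rows * cols * nranks);
  for (int r = 0; r < nranks; ++r)
    for (int i = 0; i < rows; ++i)
      for (int j = 0; j < cols; ++j)
        out[i * cols * nranks + r * cols + j] = gathered[(r * rows + i) * cols + j];

  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols * nranks; ++j) printf("%6.0f", out[i * cols * nranks + j]);
    printf("\n");
  }
  return 0;
}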
@@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle.
             pass_type, [pass_type, callable]() {
               py::gil_scoped_acquire guard;
               std::unique_ptr<framework::ir::Pass> pass(
-                  new framework::ir::GeneratePass(
-                      py::cast<std::string>(callable())));
+                  new framework::ir::GeneratePass(py::cast<std::string>(callable()),
+                                                  pass_type));
               return pass;
             });
       });
...
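This pybind change closes the loop on the empty-type bug: the factory lambda registered under pass_type now forwards that name into the GeneratePass constructor, so a pass created from Python reports the type it was registered under rather than an empty string. A self-contained sketch of the capture-and-forward pattern (a simplified registry, not Paddle's actual one):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct PassLike {
  explicit PassLike(std::string type) : type_(std::move(type)) {}
  const std::string& Type() const { return type_; }
  std::string type_;
};

int main() {
  std::map<std::string, std::function<std::unique_ptr<PassLike>()>> registry;

  std::string pass_type = "generate_fuse_pass";
  registry.emplace(pass_type, [pass_type]() {
    // Before the fix, the equivalent lambda constructed the pass without
    // pass_type, so Type() came back empty.
    return std::make_unique<PassLike>(pass_type);
  });

  std::cout << registry[pass_type]()->Type() << "\n";  // generate_fuse_pass
  return 0;
}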