Unverified commit 43d6bdca, authored by ronnywang, committed by GitHub

Fix the custom pass with empty type (#54065)
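Custom passes registered at runtime are built directly as GeneratePass objects rather than through PassRegistrar, so their type_ stayed empty. The hunks below add an optional pass_type parameter to both GeneratePass constructors, have them call RegisterType (which moves from the private to the protected section of Pass), and make the Python register_pass binding forward the registered name, so such passes now report a proper type. The custom-device c_concat kernel is also changed to fall back to an XCCL all-gather instead of throwing when no ProcessGroup is available.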

Parent 23baa8c6
@@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
   return handler;
 }
 
-GeneratePass::GeneratePass(const std::string& binary_str) {
+GeneratePass::GeneratePass(const std::string& binary_str,
+                           const std::string& pass_type) {
+  RegisterType(pass_type);
   multi_pass_desc_.ParseFromString(binary_str);
   VerifyDesc();
 }
 
-GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
+GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
+                           const std::string& pass_type)
     : multi_pass_desc_(multi_pass_desc) {
+  RegisterType(pass_type);
   VerifyDesc();
 }
@@ -24,9 +24,11 @@ namespace ir {
 class GeneratePass : public Pass {
  public:
   // from binary_str
-  explicit GeneratePass(const std::string& binary_str);
+  explicit GeneratePass(const std::string& binary_str,
+                        const std::string& pass_type = "");
   // from PassDesc/MultiPassDesc
-  explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);
+  explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
+                        const std::string& pass_type = "");
 
  protected:
   void ApplyImpl(Graph* graph) const override;
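To make the effect of the new pass_type parameter concrete, here is a minimal, self-contained C++ sketch. The Pass and GeneratePass classes below are simplified stand-ins that only mirror the type-registration behavior shown in the hunks above, not the real Paddle implementations, and the pass name "my_fusion_pass" is made up for illustration.

#include <iostream>
#include <string>

// Simplified stand-ins for the classes in the diff; illustrative only.
class Pass {
 public:
  const std::string& Type() const { return type_; }

 protected:
  void RegisterType(const std::string& type) { type_ = type; }

 private:
  std::string type_;
};

class GeneratePass : public Pass {
 public:
  explicit GeneratePass(const std::string& binary_str,
                        const std::string& pass_type = "") {
    RegisterType(pass_type);  // new: the pass records its type at construction
    (void)binary_str;         // ParseFromString / VerifyDesc omitted in sketch
  }
};

int main() {
  GeneratePass before("<serialized MultiPassDesc>");  // Type() stays ""
  GeneratePass after("<serialized MultiPassDesc>", "my_fusion_pass");
  std::cout << "before: '" << before.Type() << "'\n"
            << "after:  '" << after.Type() << "'\n";
  return 0;
}

With the old single-argument constructors nothing ever called RegisterType for a GeneratePass built from a serialized desc, so Type() stayed empty; the default argument keeps existing call sites compiling while letting register_pass supply a real name.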
@@ -185,6 +185,9 @@ class Pass {
   // Pass must be placed after this Pass.
   virtual void CheckPrevPass() const {}
 
+ protected:
+  void RegisterType(const std::string &type) { type_ = type; }
+
  private:
   template <typename PassType>
   friend struct PassRegistrar;
@@ -207,8 +210,6 @@ class Pass {
     attrs_.insert(default_attr_values.begin(), default_attr_values.end());
   }
 
-  void RegisterType(const std::string &type) { type_ = type; }
-
   mutable bool applied_{false};
   std::string type_;
   std::unordered_set<std::string> required_pass_attrs_;
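The relocation of RegisterType in the two pass.h hunks is what makes the constructor change possible. Built-in passes get their type set via PassRegistrar, which is declared a friend of Pass, so a private RegisterType was enough for them; a GeneratePass constructed directly (as register_pass does) is only a subclass, not a friend, and needs the member to be protected. A small sketch under those assumptions, again with simplified stand-in classes:

#include <string>

class Pass {
 public:
  const std::string& Type() const { return type_; }

 protected:
  // After the change: visible to subclasses such as GeneratePass.
  void RegisterType(const std::string& type) { type_ = type; }

 private:
  // Friendship still lets registrars created by REGISTER_PASS reach type_.
  template <typename PassType>
  friend struct PassRegistrar;

  std::string type_;
};

template <typename PassType>
struct PassRegistrar {};  // stub; the real registrar builds and names passes

class GeneratePass : public Pass {
 public:
  explicit GeneratePass(const std::string& pass_type = "") {
    // Compiles only because RegisterType is now protected; with the old
    // private member this call would be an access error.
    RegisterType(pass_type);
  }
};

int main() {
  GeneratePass p("generate_custom_pass");  // hypothetical pass name
  return p.Type().empty() ? 1 : 0;
}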
@@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
     int nranks = ctx.Attr<int>("nranks");
     int rank = ctx.Attr<int>("rank");
     int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+
     PADDLE_ENFORCE_GE(rank,
                       0,
                       platform::errors::PreconditionNotMet(
@@ -98,8 +100,27 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
       auto task = pg->AllGather(in_tensor, out_tensor);
       task->Wait();
     } else {
-      PADDLE_THROW(phi::errors::Unavailable(
-          "CustomDevice c_concat only support ProcessGroup"));
+      auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
+                      .Get(rid, place);
+      PADDLE_ENFORCE_EQ(
+          nranks,
+          comm->nranks(),
+          platform::errors::InvalidArgument(
+              "nranks: %s should equal to %s", nranks, comm->nranks()));
+
+      int64_t send_numel = x->numel();
+      const T* send_buff = x->data<T>();
+      T* recv_buff = temp_out.data<T>();
+      // should ExecutionContext for calc stream.
+      auto& stream = *dev_ctx.GetStream();
+      phi::DeviceManager::CCLAllGather(
+          place.GetDeviceType(),
+          reinterpret_cast<void*>(const_cast<T*>(send_buff)),
+          recv_buff,
+          send_numel,
+          phi::ccl::ToCCLDataType(x->dtype()),
+          comm->comm(),
+          stream);
     }
     std::vector<phi::DenseTensor> inputs;
     int axis = x->dims().size() - 1;
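For the kernel change above: when no ProcessGroup is available, the kernel now falls back to the per-ring XCCL communicator and issues a raw CCLAllGather into temp_out. The gathered buffer holds the nranks shards back to back (rank-major), and the code that follows this hunk (the inputs vector and axis = x->dims().size() - 1) stitches those shards together along the last axis. The plain C++ sketch below, which uses no Paddle types and a made-up helper name, spells out that index arithmetic:

#include <cstdint>
#include <vector>

// Sketch of the data movement c_concat performs after the all-gather.
// `gathered` holds the nranks local shards back to back, each of shape
// [rows, cols]; the result concatenates them along the last axis into
// shape [rows, nranks * cols].
std::vector<float> ConcatGatheredAlongLastAxis(const std::vector<float>& gathered,
                                               int64_t nranks,
                                               int64_t rows,
                                               int64_t cols) {
  std::vector<float> out(static_cast<size_t>(nranks * rows * cols));
  for (int64_t rank = 0; rank < nranks; ++rank) {
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        // shard `rank`, element (r, c)  ->  out(r, rank * cols + c)
        out[r * nranks * cols + rank * cols + c] =
            gathered[(rank * rows + r) * cols + c];
      }
    }
  }
  return out;
}

int main() {
  // Two ranks, each contributing a [2, 1] shard: {1, 2} and {3, 4}.
  std::vector<float> gathered = {1, 2, 3, 4};
  std::vector<float> out = ConcatGatheredAlongLastAxis(gathered, 2, 2, 1);
  // out is {1, 3, 2, 4}: shape [2, 2], shards side by side on the last axis.
  return (out[1] == 3 && out[2] == 2) ? 0 : 1;
}

So a local shard of shape [rows, cols] on each of nranks ranks becomes an output of shape [rows, nranks * cols], which is the c_concat contract.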
@@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle.
         pass_type, [pass_type, callable]() {
           py::gil_scoped_acquire guard;
           std::unique_ptr<framework::ir::Pass> pass(
-              new framework::ir::GeneratePass(
-                  py::cast<std::string>(callable())));
+              new framework::ir::GeneratePass(py::cast<std::string>(callable()),
+                                              pass_type));
           return pass;
         });
   });
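The pybind hunk is where the fix takes effect: the factory that register_pass installs in the pass registry now forwards pass_type into the GeneratePass constructor instead of relying on a type that was never set. A self-contained sketch of that registration pattern follows; PassFactory, PassRegistry, the simplified Pass/GeneratePass classes, and the name "generate_custom_fuse_pass" are illustrative stand-ins, not the real Paddle or pybind11 code.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

class Pass {
 public:
  virtual ~Pass() = default;
  const std::string& Type() const { return type_; }

 protected:
  void RegisterType(const std::string& type) { type_ = type; }

 private:
  std::string type_;
};

class GeneratePass : public Pass {
 public:
  explicit GeneratePass(const std::string& binary_str,
                        const std::string& pass_type = "") {
    RegisterType(pass_type);
    (void)binary_str;  // parsing of the serialized MultiPassDesc is omitted
  }
};

using PassFactory = std::function<std::unique_ptr<Pass>()>;

// Stand-in for the pass registry: a map from pass name to factory.
std::unordered_map<std::string, PassFactory>& PassRegistry() {
  static std::unordered_map<std::string, PassFactory> registry;
  return registry;
}

int main() {
  // Mirrors the lambda registered in the pybind hunk: the factory captures
  // pass_type and forwards it, so the constructed pass has a non-empty Type().
  std::string pass_type = "generate_custom_fuse_pass";
  PassRegistry()[pass_type] = [pass_type]() {
    return std::make_unique<GeneratePass>("<serialized MultiPassDesc>",
                                          pass_type);
  };

  auto pass = PassRegistry()[pass_type]();
  std::cout << pass->Type() << "\n";  // prints the registered name, not ""
  return 0;
}

Nothing downstream of the registry calls RegisterType again, so whatever type the factory bakes in at construction time is the type the framework sees.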