未验证 提交 e5506be6 编写于 作者: Z zhangyuqin1998 提交者: GitHub

fix graph_reindex (#52930)

* fix graph_reindex

* fix

* Update op_compat.yaml
上级 d9edb233
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class GraphReindexOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
class GraphReindexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The destination nodes of the input graph.");
AddInput("Neighbors", "The neighbor nodes of the destination nodes `X`.");
AddInput("Count", "The number of neighbor nodes of each destination node.");
// Note(daisiming): If using buffer hashtable, we must ensure the number of
// nodes of the input graph should be no larger than maximum(int32).
AddInput("HashTable_Value",
"One of the buffer tensor of hashtable for reindex")
.AsDispensable();
AddInput("HashTable_Index",
"One of the buffer tensor of hashtable for reindex")
.AsDispensable();
AddAttr<bool>("flag_buffer_hashtable",
"Define whether using the buffer hashtable.")
.SetDefault(false);
AddOutput("Reindex_Src",
"The source node index of graph edges after reindex.");
AddOutput("Reindex_Dst",
"The destination node index of graph edges after reindex.");
AddOutput("Out_Nodes", "The original index of graph nodes before reindex");
AddComment(R"DOC(
Graph Reindex operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_reindex,
GraphReindexInferShapeFunctor,
PD_INFER_META(phi::GraphReindexInferMeta));
REGISTER_OPERATOR(
graph_reindex,
ops::GraphReindexOP,
ops::GraphReindexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GraphReindexInferShapeFunctor);
......@@ -2383,6 +2383,12 @@
attrs:
pivot : pivots
- op: reindex_graph (graph_reindex)
inputs :
{x : X, neighbors : Neighbors, count : Count, hashtable_value : HashTable_Value, hashtable_index : HashTable_Index}
outputs :
{reindex_src : Reindex_Src, reindex_dst : Reindex_Dst, out_nodes : Out_Nodes}
- op: sigmoid_cross_entropy_with_logits
backward: sigmoid_cross_entropy_with_logits_grad
inputs :
......
......@@ -1496,6 +1496,16 @@
inplace : (x -> out)
backward : reciprocal_grad
- op : reindex_graph
args : (Tensor x, Tensor neighbors, Tensor count, Tensor hashtable_value, Tensor hashtable_index)
output : Tensor(reindex_src), Tensor(reindex_dst), Tensor(out_nodes)
infer_meta :
func : GraphReindexInferMeta
kernel :
func : graph_reindex
data_type : x
optional : hashtable_value, hashtable_index
- op : relu
args : (Tensor x)
output : Tensor(out)
......
......@@ -1336,10 +1336,11 @@ void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes) {
bool flag_buffer_hashtable =
hashtable_value.initialized() && hashtable_index.initialized();
auto GraphReindexShapeCheck = [](const phi::DDim& dims,
std::string tensor_name) {
if (dims.size() == 2) {
......
......@@ -288,7 +288,6 @@ void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& count,
const MetaTensor& hashtable_value,
const MetaTensor& hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes);
......
......@@ -29,7 +29,6 @@ void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& count,
const paddle::optional<DenseTensor>& hashtable_value,
const paddle::optional<DenseTensor>& hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes) {
......
......@@ -381,10 +381,11 @@ void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& count,
const paddle::optional<DenseTensor>& hashtable_value,
const paddle::optional<DenseTensor>& hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes) {
bool flag_buffer_hashtable =
hashtable_value.is_initialized() && hashtable_index.is_initialized();
const T* x_data = x.data<T>();
const T* neighbors_data = neighbors.data<T>();
const int* count_data = count.data<int>();
......
......@@ -51,7 +51,6 @@ void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& count,
const paddle::optional<DenseTensor>& hashtable_value,
const paddle::optional<DenseTensor>& hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GraphReindexOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_reindex",
{"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"},
{"flag_buffer_hashtable"},
{"Reindex_Src", "Reindex_Dst", "Out_Nodes"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_reindex, phi::GraphReindexOpArgumentMapping);
......@@ -13,7 +13,7 @@
# limitations under the License.
import paddle
from paddle import _legacy_C_ops
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import Variable, _non_static_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -87,14 +87,12 @@ def reindex_graph(
)
if _non_static_mode():
reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
x,
neighbors,
count,
value_buffer,
index_buffer,
"flag_buffer_hashtable",
use_buffer_hashtable,
)
return reindex_src, reindex_dst, out_nodes
......@@ -130,7 +128,6 @@ def reindex_graph(
"Reindex_Dst": reindex_dst,
"Out_Nodes": out_nodes,
},
attrs={"flag_buffer_hashtable": use_buffer_hashtable},
)
return reindex_src, reindex_dst, out_nodes
......@@ -211,14 +208,12 @@ def reindex_heter_graph(
if _non_static_mode():
neighbors = paddle.concat(neighbors, axis=0)
count = paddle.concat(count, axis=0)
reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
x,
neighbors,
count,
value_buffer,
index_buffer,
"flag_buffer_hashtable",
use_buffer_hashtable,
)
return reindex_src, reindex_dst, out_nodes
......@@ -264,6 +259,5 @@ def reindex_heter_graph(
"Reindex_Dst": reindex_dst,
"Out_Nodes": out_nodes,
},
attrs={"flag_buffer_hashtable": use_buffer_hashtable},
)
return reindex_src, reindex_dst, out_nodes
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _legacy_C_ops
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -117,14 +117,12 @@ def graph_reindex(
)
if _non_static_mode():
reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
x,
neighbors,
count,
value_buffer,
index_buffer,
"flag_buffer_hashtable",
flag_buffer_hashtable,
)
return reindex_src, reindex_dst, out_nodes
......@@ -160,6 +158,5 @@ def graph_reindex(
"Reindex_Dst": reindex_dst,
"Out_Nodes": out_nodes,
},
attrs={"flag_buffer_hashtable": flag_buffer_hashtable},
)
return reindex_src, reindex_dst, out_nodes
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册