Unverified commit 67b46e45, authored by Siming Dai and committed by GitHub

[phi] Update graph_send_recv OP (#40509)

* add out_size shape for graph_send_recv

* fix bug in register kernel: no const int& support

* add out_size in infermeta

* change unittest

* fix unittest

* fix out_size default value

* fix doc

* delete arg mapping

* add sig

* move -1 to 0

* move -1 to 0
Parent: dec9094d
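For orientation before the diff: the user-visible effect of this commit is a new optional out_size argument on paddle.incubate.graph_send_recv, which caps the first dimension of Out. A minimal sketch (data taken from the new unit test further down):

import numpy as np
import paddle

x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32")
src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32")
dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32")

# Default: Out keeps the first dimension of X -> shape [3, 3], last row all zeros.
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, "sum")

# With out_size = max(dst_index) + 1 = 2, the unused trailing row is dropped -> shape [2, 3].
out_sized = paddle.incubate.graph_send_recv(x, src_index, dst_index, "sum", out_size=2)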
@@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
-   auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+   auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
@@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
"tensors of Dst_index.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
+   AddAttr<int64_t>(
+       "out_size",
+       "(int64_t, default 0)"
+       "Define the first dimension of the output tensor. "
+       "If set to the default 0, the shape of Out is the same as that of X.")
+       .SetDefault(0);
AddComment(R"DOC(
Graph Learning Send_Recv combine operator.
@@ -93,6 +99,7 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetType("graph_send_recv_grad");
op->SetInput("Src_index", this->Input("Src_index"));
op->SetInput("Dst_index", this->Input("Dst_index"));
op->SetInput("X", this->Input("X"));
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") {
op->SetInput("Dst_count", this->Output("Dst_count"));
@@ -100,7 +107,6 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" ||
BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") {
op->SetInput("X", this->Input("X"));
op->SetInput("Out", this->Output("Out"));
}
@@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
@@ -187,11 +188,23 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
"Src_index and Dst_index should have the same shape."));
auto dims = x.dims();
+   if (out_size <= 0) {
+     out->set_dims(dims);
+   } else {
+     std::vector<int64_t> dims_ = phi::vectorize(dims);
+     if (dims_.size() > 0) {
+       dims_[0] = out_size;
+     }
+     out->set_dims(phi::make_ddim(dims_));
+   }
out->set_dtype(x.dtype());
if (pool_type == "MEAN") {
+     if (out_size <= 0) {
+       dst_count->set_dims({dims[0]});
+     } else {
+       dst_count->set_dims({out_size});
+     }
dst_count->set_dtype(DataType::INT32);
}
}
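As a reading aid, the new shape rule in GraphSendRecvInferMeta boils down to the following Python sketch (illustrative only, not Paddle code):

def infer_out_shape(x_shape, out_size):
    # out_size <= 0: Out has exactly the shape of X.
    if out_size <= 0:
        return list(x_shape)
    # out_size > 0: only the first dimension is overridden.
    return [out_size] + list(x_shape[1:])

# For MEAN pooling, Dst_count has shape [infer_out_shape(x_shape, out_size)[0]].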
@@ -51,6 +51,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
MetaTensor* out,
MetaTensor* dst_count);
@@ -23,15 +23,14 @@
namespace phi {
template <typename T, typename IndexT, typename Functor>
- void GraphSendRecvCpuGradLoop(const int& input_size,
-                               const int& index_size,
+ void GraphSendRecvCpuGradLoop(const int& index_size,
const IndexT* s_index,
const IndexT* d_index,
const DenseTensor& src,
+ const DenseTensor& input,
DenseTensor* dst,
const std::string& pool_type,
const int* dst_count = nullptr,
- const DenseTensor* input = nullptr,
const DenseTensor* output = nullptr) {
if (pool_type == "SUM") {
Functor functor;
@@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size,
for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i];
-     auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
+     auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1);
auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
auto eigen_input = phi::EigenVector<T>::Flatten(input_slice);
auto eigen_output = phi::EigenVector<T>::Flatten(output_slice);
@@ -73,18 +72,18 @@ template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
+ const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr,
- const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>();
- const auto& src_dims = out_grad.dims();
+ const auto& src_dims = x.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
@@ -97,29 +96,22 @@ void GraphSendRecvGradOpKernelLaunchHelper(
if (pool_type == "SUM") {
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
-       src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
+       index_size, d_index, s_index, out_grad, x, x_grad, pool_type);
} else if (pool_type == "MEAN") {
const int* s_count = dst_count->data<int>();
// Functor not used here.
-   GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
-                                                                   index_size,
-                                                                   d_index,
-                                                                   s_index,
-                                                                   out_grad,
-                                                                   x_grad,
-                                                                   pool_type,
-                                                                   s_count);
+   GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
+       index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
// Functor not used here.
-   GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0],
-                                                                   index_size,
+   GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size,
d_index,
s_index,
out_grad,
+ x,
x_grad,
pool_type,
nullptr,
- x,
out);
}
}
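Note why x is now threaded through unconditionally: x_grad must be allocated and zero-filled with the shape of x, and out_grad.dims() no longer equals x.dims() once out_size truncates Out. A numpy sketch of the SUM backward pass under this convention (hypothetical helper, not the Paddle kernel):

import numpy as np

def send_recv_sum_grad(out_grad, x, src_index, dst_index):
    # x_grad takes the shape of x, which can exceed out_grad's first
    # dimension when out_size < x.shape[0].
    x_grad = np.zeros_like(x)
    for i in range(len(src_index)):
        # Forward scattered x[src] into out[dst]; backward gathers it back.
        x_grad[src_index[i]] += out_grad[dst_index[i]]
    return x_grad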
@@ -127,7 +119,7 @@ void GraphSendRecvGradOpKernelLaunchHelper(
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
- paddle::optional<const DenseTensor&> x,
+ const DenseTensor& x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
@@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
+ x,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
- x.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
+ x,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
- x.get_ptr(),
out.get_ptr());
}
}
@@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
@@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
-   for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
+   if (out_size <= 0) {
+     for (int i = 0; i < src_dims.size(); ++i) {
+       memset_size *= src_dims[i];
+     }
+   } else {
+     memset_size = out_size;
+     for (int i = 1; i < src_dims.size(); ++i) {
+       memset_size *= src_dims[i];
+     }
+   }
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
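In effect, both branches zero-fill exactly the element count of the inferred output shape; compare the infer_out_shape sketch after the InferMeta hunk above:

def memset_elements(x_shape, out_size):
    # Product of the inferred output shape: [rows] + x_shape[1:].
    rows = x_shape[0] if out_size <= 0 else out_size
    n = rows
    for d in x_shape[1:]:
        n *= d
    return n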
@@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(
-       ctx, x, src_index, dst_index, pool_type, out, dst_count);
+       ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(
-       ctx, x, src_index, dst_index, pool_type, out, dst_count);
+       ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
}
}
@@ -28,19 +28,19 @@ template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
+ const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr,
- const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>();
- const auto& src_dims = out_grad.dims();
+ const auto& src_dims = x.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
@@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") {
-     const T* ptr_input = x->data<T>();
+     const T* ptr_input = x.data<T>();
const T* ptr_output = out->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src,
@@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
- paddle::optional<const DenseTensor&> x,
+ const DenseTensor& x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
@@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
+ x,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
- x.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
+ x,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
- x.get_ptr(),
out.get_ptr());
}
}
@@ -32,6 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
@@ -39,9 +40,16 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
+   if (out_size <= 0) {
+     for (int i = 0; i < src_dims.size(); ++i) {
+       memset_size *= src_dims[i];
+     }
+   } else {
+     memset_size = out_size;
+     for (int i = 1; i < src_dims.size(); ++i) {
+       memset_size *= src_dims[i];
+     }
+   }
const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
@@ -100,6 +108,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
+   if (out_size > 0) {
+     input_size = out_size;
+   }
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
@@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
+   if (out_size > 0) {
+     input_size = out_size;
+   }
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
@@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
ctx.template Alloc<int32_t>(dst_count);
int32_t* p_dst_count = dst_count->data<int32_t>();
+     if (out_size > 0) {
+       input_size = out_size;
+     }
#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
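For MEAN pooling, dst_count records how many source rows landed on each output row, so its length must follow the (possibly truncated) first dimension of the output as well; conceptually (numpy sketch with hypothetical names, assuming 2-D x):

import numpy as np

def send_recv_mean(x, src_index, dst_index, out_rows):
    out = np.zeros((out_rows,) + x.shape[1:], dtype=x.dtype)
    dst_count = np.zeros(out_rows, dtype=np.int32)
    for s, d in zip(src_index, dst_index):
        out[d] += x[s]
        dst_count[d] += 1
    # Guard against division by zero for rows that received nothing.
    return out / np.maximum(dst_count, 1)[:, None], dst_count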
@@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
-       ctx, x, src_index, dst_index, pool_type, out, dst_count);
+       ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
-       ctx, x, src_index, dst_index, pool_type, out, dst_count);
+       ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
}
}
@@ -23,7 +23,7 @@ namespace phi {
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
- paddle::optional<const DenseTensor&> x,
+ const DenseTensor& x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
@@ -25,6 +25,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
+ int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count);
@@ -16,6 +16,14 @@ limitations under the License. */
namespace phi {
+ KernelSignature GraphSendRecvOpArgumentMapping(
+     const ArgumentMappingContext& ctx) {
+   return KernelSignature("graph_send_recv",
+                          {"X", "Src_index", "Dst_index"},
+                          {"pool_type", "out_size"},
+                          {"Out", "Dst_count"});
+ }
KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
@@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
} // namespace phi
+ PD_REGISTER_ARG_MAPPING_FN(graph_send_recv,
+                            phi::GraphSendRecvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad,
phi::GraphSendRecvGradOpArgumentMapping);
@@ -304,6 +304,35 @@ class API_GraphSendRecvOpTest(unittest.TestCase):
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
+    def test_set_outsize_gpu(self):
+        if paddle.fluid.core.is_compiled_with_cuda():
+            x = paddle.to_tensor(
+                np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32")
+            src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32")
+            dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32")
+            res = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                  "sum")
+            out_size = paddle.max(dst_index) + 1
+            res_set_outsize = paddle.incubate.graph_send_recv(
+                x, src_index, dst_index, "sum", out_size)
+            np_res = np.array(
+                [[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32")
+            np_res_set_outsize = np.array(
+                [[0, 2, 3], [1, 6, 8]], dtype="float32")
+            self.assertTrue(
+                np.allclose(np_res, res, atol=1e-6),
+                "two values are\
+                {}\n{}, check diff!".format(np_res, res))
+            self.assertTrue(
+                np.allclose(np_res_set_outsize, res_set_outsize, atol=1e-6),
+                "two values are\
+                {}\n{}, check diff!".format(np_res_set_outsize,
+                                            res_set_outsize))
if __name__ == '__main__':
unittest.main()
@@ -19,7 +19,12 @@ from paddle.fluid import core
from paddle import _C_ops
- def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
+ def graph_send_recv(x,
+                     src_index,
+                     dst_index,
+                     pool_type="sum",
+                     out_size=None,
+                     name=None):
r"""
Graph Learning Send_Recv combine operator.
@@ -27,7 +32,7 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor
- in different pooling types, like sum, mean, max, or min.
+ in different pooling types, like sum, mean, max, or min. In addition, `out_size` can be set to specify the first dimension of the output.
.. code-block:: text
@@ -43,6 +48,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
pool_type = "sum"
+ out_size = None
Then:
Out = [[0, 2, 3],
@@ -56,6 +63,9 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
The available data type is int32, int64.
pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`.
Default value is `sum`.
+ out_size (int64|None): Specify the first dimension of the output. If None or not larger
+                        than 0, this attribute is not used. If set, it should be greater
+                        than or equal to max(dst_index) + 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
@@ -75,6 +85,21 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]
+ x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+ indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
+ src_index = indexes[:, 0]
+ dst_index = indexes[:, 1]
+ out_size = paddle.max(dst_index) + 1
+ out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size)
+ # Outputs: [[0., 2., 3.], [2., 8., 10.]]
+
+ x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+ indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
+ src_index = indexes[:, 0]
+ dst_index = indexes[:, 1]
+ out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
+ # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]]
"""
if pool_type not in ["sum", "mean", "max", "min"]:
@@ -82,9 +107,16 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
"pool_type should be `sum`, `mean`, `max` or `min`, but received %s"
% pool_type)
+ # TODO(daisiming): Should we add a check that out_size >= max(dst_index) + 1?
if in_dygraph_mode():
-     out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type',
-                                       pool_type.upper())
+     if out_size is None or out_size <= 0:
+         out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index,
+                                           'pool_type', pool_type.upper())
+     else:
+         out, tmp = _C_ops.graph_send_recv(
+             x, src_index, dst_index, 'pool_type',
+             pool_type.upper(), 'out_size', out_size)
return out
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"),
@@ -105,5 +137,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
"Dst_index": dst_index},
outputs={"Out": out,
"Dst_count": dst_count},
attrs={"pool_type": pool_type.upper()})
attrs={
"pool_type": pool_type.upper(),
"out_size": 0 if out_size is None or out_size <= 0 else out_size
})
return out
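Putting the pieces together, the full forward semantics with out_size can be checked against a small numpy reference (a sketch for validation only; graph_send_recv_ref is a hypothetical name, not part of Paddle, and the mean normalization assumes 2-D x):

import numpy as np

def graph_send_recv_ref(x, src_index, dst_index, pool_type="sum", out_size=None):
    rows = x.shape[0] if out_size is None or out_size <= 0 else out_size
    out = np.zeros((rows,) + x.shape[1:], dtype=x.dtype)
    count = np.zeros(rows, dtype=np.int64)
    for s, d in zip(src_index, dst_index):
        if pool_type in ("sum", "mean"):
            out[d] += x[s]
        elif pool_type == "max":
            out[d] = np.maximum(out[d], x[s]) if count[d] else x[s]
        elif pool_type == "min":
            out[d] = np.minimum(out[d], x[s]) if count[d] else x[s]
        count[d] += 1
    if pool_type == "mean":
        out = out / np.maximum(count, 1)[:, None]
    return out

# Example: with x = [[0, 2, 3], [1, 4, 5], [2, 6, 6]], src_index = [0, 0, 1],
# dst_index = [0, 1, 1] and out_size=2, this returns [[0, 2, 3], [1, 6, 8]],
# matching the new unit test above.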