From 67b46e4575afaf737bb293fc05dca503cee14cd2 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 22 Mar 2022 11:28:16 +0800 Subject: [PATCH] [phi] Update graph_send_recv OP (#40509) * add out_size shape for graph_send_recv * fix bug in register kernel: no const int& support * add out_size in infermeta * change unittest * fix unittest * fix out_size default value * fix doc * delete arg mapping * add sig * move -1 to 0 * move -1 to 0 --- paddle/fluid/operators/graph_send_recv_op.cc | 10 ++++- paddle/phi/infermeta/ternary.cc | 17 ++++++- paddle/phi/infermeta/ternary.h | 1 + .../cpu/graph_send_recv_grad_kernel.cc | 34 ++++++-------- .../phi/kernels/cpu/graph_send_recv_kernel.cc | 17 +++++-- .../gpu/graph_send_recv_grad_kernel.cu | 12 ++--- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 26 +++++++++-- .../phi/kernels/graph_send_recv_grad_kernel.h | 2 +- paddle/phi/kernels/graph_send_recv_kernel.h | 1 + paddle/phi/ops/compat/graph_send_recv_sig.cc | 11 +++++ .../unittests/test_graph_send_recv_op.py | 29 ++++++++++++ .../incubate/operators/graph_send_recv.py | 45 ++++++++++++++++--- 12 files changed, 161 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index f7c006dbcb1..f67dea74028 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - auto in_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto in_dims = ctx->GetInputDim("X"); ctx->SetOutputDim(framework::GradVarName("X"), in_dims); } @@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { "tensors of Dst_index.") .SetDefault("SUM") .InEnum({"SUM", "MEAN", "MIN", "MAX"}); + AddAttr( + "out_size", + "(int64_t, default 0)" + "Define the first dimension of Output tensor." + "If set default 0, then the shape of Out is the same with X.") + .SetDefault(0); AddComment(R"DOC( Graph Learning Send_Recv combine operator. 
@@ -93,6 +99,7 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { op->SetType("graph_send_recv_grad"); op->SetInput("Src_index", this->Input("Src_index")); op->SetInput("Dst_index", this->Input("Dst_index")); + op->SetInput("X", this->Input("X")); if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { op->SetInput("Dst_count", this->Output("Dst_count")); @@ -100,7 +107,6 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { - op->SetInput("X", this->Input("X")); op->SetInput("Out", this->Output("Out")); } diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 556fb874470..a72b8d913f8 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, + int64_t out_size, MetaTensor* out, MetaTensor* dst_count) { auto src_index_dims = src_index.dims(); @@ -187,11 +188,23 @@ void GraphSendRecvInferMeta(const MetaTensor& x, "Src_index and Dst_index should have the same shape.")); auto dims = x.dims(); - out->set_dims(dims); + if (out_size <= 0) { + out->set_dims(dims); + } else { + std::vector dims_ = phi::vectorize(dims); + if (dims_.size() > 0) { + dims_[0] = out_size; + } + out->set_dims(phi::make_ddim(dims_)); + } out->set_dtype(x.dtype()); if (pool_type == "MEAN") { - dst_count->set_dims({dims[0]}); + if (out_size <= 0) { + dst_count->set_dims({dims[0]}); + } else { + dst_count->set_dims({out_size}); + } dst_count->set_dtype(DataType::INT32); } } diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 42a0f35dc1d..8521a1ee855 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -51,6 +51,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, + int64_t out_size, MetaTensor* out, MetaTensor* dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc index 8538461b1b8..6a83cee1ae4 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc @@ -23,15 +23,14 @@ namespace phi { template -void GraphSendRecvCpuGradLoop(const int& input_size, - const int& index_size, +void GraphSendRecvCpuGradLoop(const int& index_size, const IndexT* s_index, const IndexT* d_index, const DenseTensor& src, + const DenseTensor& input, DenseTensor* dst, const std::string& pool_type, const int* dst_count = nullptr, - const DenseTensor* input = nullptr, const DenseTensor* output = nullptr) { if (pool_type == "SUM") { Functor functor; @@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size, for (int i = 0; i < index_size; ++i) { const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_dst_idx = s_index[i]; - auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1); + auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1); auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1); auto eigen_input = phi::EigenVector::Flatten(input_slice); auto eigen_output = phi::EigenVector::Flatten(output_slice); @@ -73,18 +72,18 @@ template void GraphSendRecvGradOpKernelLaunchHelper( const 
Context& ctx, const DenseTensor& out_grad, + const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, - const DenseTensor* x = nullptr, const DenseTensor* out = nullptr) { const int& index_size = dst_index.dims()[0]; ctx.template Alloc(x_grad); T* p_output = x_grad->data(); - const auto& src_dims = out_grad.dims(); + const auto& src_dims = x.dims(); int64_t memset_size = 1; for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; const size_t& memset_bytes = memset_size * sizeof(T); @@ -97,29 +96,22 @@ void GraphSendRecvGradOpKernelLaunchHelper( if (pool_type == "SUM") { GraphSendRecvCpuGradLoop>( - src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type); + index_size, d_index, s_index, out_grad, x, x_grad, pool_type); } else if (pool_type == "MEAN") { const int* s_count = dst_count->data(); // Functor not used here. - GraphSendRecvCpuGradLoop>(src_dims[0], - index_size, - d_index, - s_index, - out_grad, - x_grad, - pool_type, - s_count); + GraphSendRecvCpuGradLoop>( + index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count); } else if (pool_type == "MIN" || pool_type == "MAX") { // Functor not used here. - GraphSendRecvCpuGradLoop>(src_dims[0], - index_size, + GraphSendRecvCpuGradLoop>(index_size, d_index, s_index, out_grad, + x, x_grad, pool_type, nullptr, - x, out); } } @@ -127,7 +119,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, @@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx, GraphSendRecvGradOpKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } else if (index_type == phi::DataType::INT64) { GraphSendRecvGradOpKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } } diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index fecbd4b1d7a..8f71ba12cc4 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; + if (out_size <= 0) { + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } else { + memset_size = out_size; + for (int i = 1; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } const size_t& memset_bytes = memset_size * sizeof(T); memset(p_output, 0, memset_bytes); @@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { 
GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } else if (index_type == phi::DataType::INT64) { GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu index 75692966b46..8bd3337280d 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu @@ -28,19 +28,19 @@ template void GraphSendRecvGradOpCUDAKernelLaunchHelper( const Context& ctx, const DenseTensor& out_grad, + const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, - const DenseTensor* x = nullptr, const DenseTensor* out = nullptr) { const int& index_size = dst_index.dims()[0]; ctx.template Alloc(x_grad); T* p_output = x_grad->data(); - const auto& src_dims = out_grad.dims(); + const auto& src_dims = x.dims(); int64_t memset_size = 1; for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; @@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ManipulateMeanGradCUDAKernel<<>>( p_src, d_index, s_index, p_output, index_size, slice_size, s_count); } else if (pool_type == "MAX" || pool_type == "MIN") { - const T* ptr_input = x->data(); + const T* ptr_input = x.data(); const T* ptr_output = out->data(); ManipulateMinMaxGradCUDAKernel<<>>( p_src, @@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, @@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx, GraphSendRecvGradOpCUDAKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } else if (index_type == phi::DataType::INT64) { GraphSendRecvGradOpCUDAKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index fab306f831a..2826c071d6e 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -32,6 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -39,8 +40,15 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) { - memset_size *= src_dims[i]; + if (out_size <= 0) { + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } else { + memset_size = out_size; + for (int i = 1; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } } const size_t& memset_bytes = memset_size * sizeof(T); if (pool_type == "SUM" || pool_type == "MEAN") { @@ -100,6 +108,9 @@ void 
GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); + if (out_size > 0) { + input_size = out_size; + } int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_max = grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; @@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); + if (out_size > 0) { + input_size = out_size; + } int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_min = grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; @@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(dst_count); int32_t* p_dst_count = dst_count->data(); + if (out_size > 0) { + input_size = out_size; + } #ifdef PADDLE_WITH_HIP hipMemset(p_dst_count, 0, input_size * sizeof(int)); @@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } else if (index_type == phi::DataType::INT64) { GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } } diff --git a/paddle/phi/kernels/graph_send_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_recv_grad_kernel.h index d163e6e278a..3694c8f1e6c 100644 --- a/paddle/phi/kernels/graph_send_recv_grad_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_grad_kernel.h @@ -23,7 +23,7 @@ namespace phi { template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index 95dbdc4443a..51768fbc18f 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -25,6 +25,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index dacb8b25a89..fa4da0704c9 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -16,6 +16,14 @@ limitations under the License. 
*/ namespace phi { +KernelSignature GraphSendRecvOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("graph_send_recv", + {"X", "Src_index", "Dst_index"}, + {"pool_type", "out_size"}, + {"Out", "Dst_count"}); +} + KernelSignature GraphSendRecvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( @@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(graph_send_recv, + phi::GraphSendRecvOpArgumentMapping); + PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad, phi::GraphSendRecvGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index 68b354775d1..30f943e3248 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -304,6 +304,35 @@ class API_GraphSendRecvOpTest(unittest.TestCase): "two value is\ {}\n{}, check diff!".format(np_res, ret_res)) + def test_set_outsize_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + x = paddle.to_tensor( + np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32") + src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32") + res = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum") + out_size = paddle.max(dst_index) + 1 + res_set_outsize = paddle.incubate.graph_send_recv( + x, src_index, dst_index, "sum", out_size) + + np_res = np.array( + [[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + np_res_set_outsize = np.array( + [[0, 2, 3], [1, 6, 8]], dtype="float32") + + self.assertTrue( + np.allclose( + np_res, res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, res)) + self.assertTrue( + np.allclose( + np_res_set_outsize, res_set_outsize, atol=1e-6), + "two value is\ + {}\n{}, check diff!" + .format(np_res_set_outsize, res_set_outsize)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 45810621e42..05f6a80a442 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -19,7 +19,12 @@ from paddle.fluid import core from paddle import _C_ops -def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): +def graph_send_recv(x, + src_index, + dst_index, + pool_type="sum", + out_size=None, + name=None): r""" Graph Learning Send_Recv combine operator. @@ -27,7 +32,7 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor - in different pooling types, like sum, mean, max, or min. + in different pooling types, like sum, mean, max, or min. Besides, we can set `out_size` to get necessary output shape. .. 
code-block:: text @@ -43,6 +48,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): pool_type = "sum" + out_size = None + Then: Out = [[0, 2, 3], @@ -56,6 +63,9 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): The available data type is int32, int64. pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. + out_size (int64|None): We can set `out_size` to get necessary output shape. If not set, then this + attribute will not be used. If set, it should be equal with or larger than + max(dst_index) + 1. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -75,6 +85,21 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out_size = paddle.max(dst_index) + 1 + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size) + # Outputs: [[0., 2., 3.], [[2., 8., 10.]]] + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") + # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] + """ if pool_type not in ["sum", "mean", "max", "min"]: @@ -82,9 +107,16 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): "pool_type should be `sum`, `mean`, `max` or `min`, but received %s" % pool_type) + # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. + if in_dygraph_mode(): - out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type', - pool_type.upper()) + if out_size is None or out_size <= 0: + out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, + 'pool_type', pool_type.upper()) + else: + out, tmp = _C_ops.graph_send_recv( + x, src_index, dst_index, 'pool_type', + pool_type.upper(), 'out_size', out_size) return out check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), @@ -105,5 +137,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): "Dst_index": dst_index}, outputs={"Out": out, "Dst_count": dst_count}, - attrs={"pool_type": pool_type.upper()}) + attrs={ + "pool_type": pool_type.upper(), + "out_size": 0 if out_size is None or out_size <= 0 else out_size + }) return out -- GitLab
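A minimal dygraph usage sketch of the new `out_size` argument, adapted from the `test_set_outsize_gpu` case added in this patch (the unit test exercises it on a CUDA build, but the patch adds the same `out_size` handling to the CPU kernel; the commented values are the ones the test asserts):

.. code-block:: python

    import numpy as np
    import paddle

    # Node features: one row per node.
    x = paddle.to_tensor(
        np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32")
    # Edge i gathers x[src_index[i]] and accumulates it into row dst_index[i].
    src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32")
    dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32")

    # Default behaviour: the output keeps x's first dimension (3 rows),
    # so row 2, which receives nothing, stays zero.
    res = paddle.incubate.graph_send_recv(x, src_index, dst_index, "sum")
    # [[0., 2., 3.], [1., 6., 8.], [0., 0., 0.]]

    # With out_size = max(dst_index) + 1 the unused trailing rows are dropped.
    out_size = paddle.max(dst_index) + 1
    res_set_outsize = paddle.incubate.graph_send_recv(
        x, src_index, dst_index, "sum", out_size)
    # [[0., 2., 3.], [1., 6., 8.]]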
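For readers who want the semantics without running Paddle, here is a rough NumPy sketch of what `out_size` means for the SUM reduction, following the op-maker description ("Define the first dimension of Output tensor", default 0) and the `GraphSendRecvInferMeta` change; `send_recv_sum` is a hypothetical reference helper, not part of Paddle:

.. code-block:: python

    import numpy as np

    def send_recv_sum(x, src_index, dst_index, out_size=0):
        # out_size <= 0: keep x's first dimension, matching the attribute's
        # default value of 0 registered in the op maker.
        # out_size > 0: use it as the first dimension of the output; it has
        # to cover max(dst_index) + 1, as the Python docstring notes.
        rows = x.shape[0] if out_size <= 0 else out_size
        out = np.zeros((rows,) + x.shape[1:], dtype=x.dtype)
        for s, d in zip(src_index, dst_index):
            out[d] += x[s]  # gather from the source row, add at the destination
        return out

With the tensors from the previous snippet converted to plain NumPy arrays, `send_recv_sum(x, src_index, dst_index)` and `send_recv_sum(x, src_index, dst_index, out_size=2)` reproduce both commented outputs.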