未验证 提交 67b46e45 编写于 作者: S Siming Dai 提交者: GitHub

[phi] Update graph_send_recv OP (#40509)

* add out_size shape for graph_send_recv

* fix bug in register kernel: no const int& support

* add out_size in infermeta

* change unittest

* fix unittest

* fix out_size default value

* fix doc

* delete arg mapping

* add sig

* move -1 to 0

* move -1 to 0
上级 dec9094d
...@@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel { ...@@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), in_dims); ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
} }
...@@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
"tensors of Dst_index.") "tensors of Dst_index.")
.SetDefault("SUM") .SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"}); .InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddAttr<int64_t>(
"out_size",
"(int64_t, default 0)"
"Define the first dimension of Output tensor."
"If set default 0, then the shape of Out is the same with X.")
.SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
Graph Learning Send_Recv combine operator. Graph Learning Send_Recv combine operator.
...@@ -93,6 +99,7 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -93,6 +99,7 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetType("graph_send_recv_grad"); op->SetType("graph_send_recv_grad");
op->SetInput("Src_index", this->Input("Src_index")); op->SetInput("Src_index", this->Input("Src_index"));
op->SetInput("Dst_index", this->Input("Dst_index")); op->SetInput("Dst_index", this->Input("Dst_index"));
op->SetInput("X", this->Input("X"));
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") {
op->SetInput("Dst_count", this->Output("Dst_count")); op->SetInput("Dst_count", this->Output("Dst_count"));
...@@ -100,7 +107,6 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -100,7 +107,6 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" ||
BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") {
op->SetInput("X", this->Input("X"));
op->SetInput("Out", this->Output("Out")); op->SetInput("Out", this->Output("Out"));
} }
......
...@@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, ...@@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index, const MetaTensor& src_index,
const MetaTensor& dst_index, const MetaTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
MetaTensor* out, MetaTensor* out,
MetaTensor* dst_count) { MetaTensor* dst_count) {
auto src_index_dims = src_index.dims(); auto src_index_dims = src_index.dims();
...@@ -187,11 +188,23 @@ void GraphSendRecvInferMeta(const MetaTensor& x, ...@@ -187,11 +188,23 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
"Src_index and Dst_index should have the same shape.")); "Src_index and Dst_index should have the same shape."));
auto dims = x.dims(); auto dims = x.dims();
out->set_dims(dims); if (out_size <= 0) {
out->set_dims(dims);
} else {
std::vector<int64_t> dims_ = phi::vectorize(dims);
if (dims_.size() > 0) {
dims_[0] = out_size;
}
out->set_dims(phi::make_ddim(dims_));
}
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
if (pool_type == "MEAN") { if (pool_type == "MEAN") {
dst_count->set_dims({dims[0]}); if (out_size <= 0) {
dst_count->set_dims({dims[0]});
} else {
dst_count->set_dims({out_size});
}
dst_count->set_dtype(DataType::INT32); dst_count->set_dtype(DataType::INT32);
} }
} }
......
...@@ -51,6 +51,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, ...@@ -51,6 +51,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index, const MetaTensor& src_index,
const MetaTensor& dst_index, const MetaTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
MetaTensor* out, MetaTensor* out,
MetaTensor* dst_count); MetaTensor* dst_count);
......
...@@ -23,15 +23,14 @@ ...@@ -23,15 +23,14 @@
namespace phi { namespace phi {
template <typename T, typename IndexT, typename Functor> template <typename T, typename IndexT, typename Functor>
void GraphSendRecvCpuGradLoop(const int& input_size, void GraphSendRecvCpuGradLoop(const int& index_size,
const int& index_size,
const IndexT* s_index, const IndexT* s_index,
const IndexT* d_index, const IndexT* d_index,
const DenseTensor& src, const DenseTensor& src,
const DenseTensor& input,
DenseTensor* dst, DenseTensor* dst,
const std::string& pool_type, const std::string& pool_type,
const int* dst_count = nullptr, const int* dst_count = nullptr,
const DenseTensor* input = nullptr,
const DenseTensor* output = nullptr) { const DenseTensor* output = nullptr) {
if (pool_type == "SUM") { if (pool_type == "SUM") {
Functor functor; Functor functor;
...@@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size, ...@@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size,
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i]; const IndexT& forward_dst_idx = s_index[i];
auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1); auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1);
auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1); auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
auto eigen_input = phi::EigenVector<T>::Flatten(input_slice); auto eigen_input = phi::EigenVector<T>::Flatten(input_slice);
auto eigen_output = phi::EigenVector<T>::Flatten(output_slice); auto eigen_output = phi::EigenVector<T>::Flatten(output_slice);
...@@ -73,18 +72,18 @@ template <typename Context, typename T, typename IndexT> ...@@ -73,18 +72,18 @@ template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper( void GraphSendRecvGradOpKernelLaunchHelper(
const Context& ctx, const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
DenseTensor* x_grad, DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr, const DenseTensor* dst_count = nullptr,
const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) { const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0]; const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad); ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>(); T* p_output = x_grad->data<T>();
const auto& src_dims = out_grad.dims(); const auto& src_dims = x.dims();
int64_t memset_size = 1; int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T); const size_t& memset_bytes = memset_size * sizeof(T);
...@@ -97,29 +96,22 @@ void GraphSendRecvGradOpKernelLaunchHelper( ...@@ -97,29 +96,22 @@ void GraphSendRecvGradOpKernelLaunchHelper(
if (pool_type == "SUM") { if (pool_type == "SUM") {
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>( GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type); index_size, d_index, s_index, out_grad, x, x_grad, pool_type);
} else if (pool_type == "MEAN") { } else if (pool_type == "MEAN") {
const int* s_count = dst_count->data<int>(); const int* s_count = dst_count->data<int>();
// Functor not used here. // Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0], GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
index_size, index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count);
d_index,
s_index,
out_grad,
x_grad,
pool_type,
s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") { } else if (pool_type == "MIN" || pool_type == "MAX") {
// Functor not used here. // Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0], GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size,
index_size,
d_index, d_index,
s_index, s_index,
out_grad, out_grad,
x,
x_grad, x_grad,
pool_type, pool_type,
nullptr, nullptr,
x,
out); out);
} }
} }
...@@ -127,7 +119,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( ...@@ -127,7 +119,7 @@ void GraphSendRecvGradOpKernelLaunchHelper(
template <typename T, typename Context> template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx, void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x, const DenseTensor& x,
paddle::optional<const DenseTensor&> out, paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
...@@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>( GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx, ctx,
out_grad, out_grad,
x,
src_index, src_index,
dst_index, dst_index,
pool_type, pool_type,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr()); out.get_ptr());
} else if (index_type == phi::DataType::INT64) { } else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>( GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx, ctx,
out_grad, out_grad,
x,
src_index, src_index,
dst_index, dst_index,
pool_type, pool_type,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr()); out.get_ptr());
} }
} }
......
...@@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, ...@@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count = nullptr) { DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0]; const int& index_size = src_index.dims()[0];
...@@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, ...@@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
T* p_output = out->data<T>(); T* p_output = out->data<T>();
const auto& src_dims = x.dims(); const auto& src_dims = x.dims();
int64_t memset_size = 1; int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; if (out_size <= 0) {
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
} else {
memset_size = out_size;
for (int i = 1; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
}
const size_t& memset_bytes = memset_size * sizeof(T); const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes); memset(p_output, 0, memset_bytes);
...@@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count) { DenseTensor* dst_count) {
auto index_type = src_index.dtype(); auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>( GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count); ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} else if (index_type == phi::DataType::INT64) { } else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>( GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count); ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} }
} }
......
...@@ -28,19 +28,19 @@ template <typename Context, typename T, typename IndexT> ...@@ -28,19 +28,19 @@ template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpCUDAKernelLaunchHelper( void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const Context& ctx, const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
DenseTensor* x_grad, DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr, const DenseTensor* dst_count = nullptr,
const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) { const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0]; const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad); ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>(); T* p_output = x_grad->data<T>();
const auto& src_dims = out_grad.dims(); const auto& src_dims = x.dims();
int64_t memset_size = 1; int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) { for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i]; memset_size *= src_dims[i];
...@@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ...@@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>( ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, s_count); p_src, d_index, s_index, p_output, index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") { } else if (pool_type == "MAX" || pool_type == "MIN") {
const T* ptr_input = x->data<T>(); const T* ptr_input = x.data<T>();
const T* ptr_output = out->data<T>(); const T* ptr_output = out->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>( ManipulateMinMaxGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, p_src,
...@@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ...@@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
template <typename T, typename Context> template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx, void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x, const DenseTensor& x,
paddle::optional<const DenseTensor&> out, paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
...@@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>( GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, ctx,
out_grad, out_grad,
x,
src_index, src_index,
dst_index, dst_index,
pool_type, pool_type,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr()); out.get_ptr());
} else if (index_type == phi::DataType::INT64) { } else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>( GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, ctx,
out_grad, out_grad,
x,
src_index, src_index,
dst_index, dst_index,
pool_type, pool_type,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr()); out.get_ptr());
} }
} }
......
...@@ -32,6 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -32,6 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count = nullptr) { DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0]; const int& index_size = src_index.dims()[0];
...@@ -39,8 +40,15 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -39,8 +40,15 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
T* p_output = out->data<T>(); T* p_output = out->data<T>();
const auto& src_dims = x.dims(); const auto& src_dims = x.dims();
int64_t memset_size = 1; int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) { if (out_size <= 0) {
memset_size *= src_dims[i]; for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
} else {
memset_size = out_size;
for (int i = 1; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
} }
const size_t& memset_bytes = memset_size * sizeof(T); const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") { if (pool_type == "SUM" || pool_type == "MEAN") {
...@@ -100,6 +108,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -100,6 +108,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
IndexT>><<<grid, block, 0, ctx.stream()>>>( IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor); p_src, s_index, d_index, p_output, index_size, slice_size, functor);
if (out_size > 0) {
input_size = out_size;
}
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max = int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
...@@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
IndexT>><<<grid, block, 0, ctx.stream()>>>( IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor); p_src, s_index, d_index, p_output, index_size, slice_size, functor);
if (out_size > 0) {
input_size = out_size;
}
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min = int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
...@@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
ctx.template Alloc<int32_t>(dst_count); ctx.template Alloc<int32_t>(dst_count);
int32_t* p_dst_count = dst_count->data<int32_t>(); int32_t* p_dst_count = dst_count->data<int32_t>();
if (out_size > 0) {
input_size = out_size;
}
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int)); hipMemset(p_dst_count, 0, input_size * sizeof(int));
...@@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count) { DenseTensor* dst_count) {
auto index_type = src_index.dtype(); auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>( GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count); ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} else if (index_type == phi::DataType::INT64) { } else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>( GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count); ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
} }
} }
......
...@@ -23,7 +23,7 @@ namespace phi { ...@@ -23,7 +23,7 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx, void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x, const DenseTensor& x,
paddle::optional<const DenseTensor&> out, paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
......
...@@ -25,6 +25,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -25,6 +25,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& pool_type,
int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count); DenseTensor* dst_count);
......
...@@ -16,6 +16,14 @@ limitations under the License. */ ...@@ -16,6 +16,14 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature GraphSendRecvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"},
{"pool_type", "out_size"},
{"Out", "Dst_count"});
}
KernelSignature GraphSendRecvGradOpArgumentMapping( KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
...@@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( ...@@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
} // namespace phi } // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv,
phi::GraphSendRecvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad, PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad,
phi::GraphSendRecvGradOpArgumentMapping); phi::GraphSendRecvGradOpArgumentMapping);
...@@ -304,6 +304,35 @@ class API_GraphSendRecvOpTest(unittest.TestCase): ...@@ -304,6 +304,35 @@ class API_GraphSendRecvOpTest(unittest.TestCase):
"two value is\ "two value is\
{}\n{}, check diff!".format(np_res, ret_res)) {}\n{}, check diff!".format(np_res, ret_res))
def test_set_outsize_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
x = paddle.to_tensor(
np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32")
src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32")
dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32")
res = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"sum")
out_size = paddle.max(dst_index) + 1
res_set_outsize = paddle.incubate.graph_send_recv(
x, src_index, dst_index, "sum", out_size)
np_res = np.array(
[[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32")
np_res_set_outsize = np.array(
[[0, 2, 3], [1, 6, 8]], dtype="float32")
self.assertTrue(
np.allclose(
np_res, res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, res))
self.assertTrue(
np.allclose(
np_res_set_outsize, res_set_outsize, atol=1e-6),
"two value is\
{}\n{}, check diff!"
.format(np_res_set_outsize, res_set_outsize))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,7 +19,12 @@ from paddle.fluid import core ...@@ -19,7 +19,12 @@ from paddle.fluid import core
from paddle import _C_ops from paddle import _C_ops
def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): def graph_send_recv(x,
src_index,
dst_index,
pool_type="sum",
out_size=None,
name=None):
r""" r"""
Graph Learning Send_Recv combine operator. Graph Learning Send_Recv combine operator.
...@@ -27,7 +32,7 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -27,7 +32,7 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor
in different pooling types, like sum, mean, max, or min. in different pooling types, like sum, mean, max, or min. Besides, we can set `out_size` to get necessary output shape.
.. code-block:: text .. code-block:: text
...@@ -43,6 +48,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -43,6 +48,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
pool_type = "sum" pool_type = "sum"
out_size = None
Then: Then:
Out = [[0, 2, 3], Out = [[0, 2, 3],
...@@ -56,6 +63,9 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -56,6 +63,9 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
The available data type is int32, int64. The available data type is int32, int64.
pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`.
Default value is `sum`. Default value is `sum`.
out_size (int64|None): We can set `out_size` to get necessary output shape. If not set, then this
attribute will not be used. If set, it should be equal with or larger than
max(dst_index) + 1.
name (str, optional): Name for the operation (optional, default is None). name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
...@@ -75,6 +85,21 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -75,6 +85,21 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out_size = paddle.max(dst_index) + 1
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size)
# Outputs: [[0., 2., 3.], [[2., 8., 10.]]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]]
""" """
if pool_type not in ["sum", "mean", "max", "min"]: if pool_type not in ["sum", "mean", "max", "min"]:
...@@ -82,9 +107,16 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -82,9 +107,16 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
"pool_type should be `sum`, `mean`, `max` or `min`, but received %s" "pool_type should be `sum`, `mean`, `max` or `min`, but received %s"
% pool_type) % pool_type)
# TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1.
if in_dygraph_mode(): if in_dygraph_mode():
out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type', if out_size is None or out_size <= 0:
pool_type.upper()) out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index,
'pool_type', pool_type.upper())
else:
out, tmp = _C_ops.graph_send_recv(
x, src_index, dst_index, 'pool_type',
pool_type.upper(), 'out_size', out_size)
return out return out
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"),
...@@ -105,5 +137,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): ...@@ -105,5 +137,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
"Dst_index": dst_index}, "Dst_index": dst_index},
outputs={"Out": out, outputs={"Out": out,
"Dst_count": dst_count}, "Dst_count": dst_count},
attrs={"pool_type": pool_type.upper()}) attrs={
"pool_type": pool_type.upper(),
"out_size": 0 if out_size is None or out_size <= 0 else out_size
})
return out return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册