Unverified commit 88724a53, authored by Siming Dai, committed by GitHub

[geometric] Add paddle.geometric.send_uv API (#44848)

* initial commit

* fix op maker bug

* fix mul grad bug

* add unittest

* fix add grad bug, add cpu kernel

* add paddle.geometric.message_passing

* add paddle.geometric.send_uv api, add unittest

* add fp16 judgement

* fix file typo, move compute_type to message_op

* add impl file

* fix unittest timeout time

* add review revise
Parent 6a15d407
......@@ -135,6 +135,16 @@
func : fft_r2c
backward : fft_r2c_grad
- api : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : GraphSendUVInferMeta
kernel :
func : graph_send_uv
data_type : x
backward : graph_send_uv_grad
- api : lgamma
args : (Tensor x)
output : Tensor(out)
......
......@@ -147,6 +147,17 @@
data_type: out_grad
no_need_buffer: x
- backward_api : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_uv_grad
data_type : x
- backward_api : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
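For orientation, the gradients that graph_send_uv_grad has to produce, and which the CPU/GPU kernels later in this diff implement, can be sketched (ignoring broadcasting) as:

\frac{\partial L}{\partial x_u} = \sum_{i:\ \mathrm{src}_i = u} \frac{\partial L}{\partial \mathrm{out}_i}, \qquad \frac{\partial L}{\partial y_v} = \sum_{i:\ \mathrm{dst}_i = v} \frac{\partial L}{\partial \mathrm{out}_i} \qquad (\texttt{message\_op} = \mathrm{ADD})

\frac{\partial L}{\partial x_u} = \sum_{i:\ \mathrm{src}_i = u} \frac{\partial L}{\partial \mathrm{out}_i} \odot y_{\mathrm{dst}_i}, \qquad \frac{\partial L}{\partial y_v} = \sum_{i:\ \mathrm{dst}_i = v} \frac{\partial L}{\partial \mathrm{out}_i} \odot x_{\mathrm{src}_i} \qquad (\texttt{message\_op} = \mathrm{MUL})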
......@@ -2687,6 +2687,76 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims_array));
}
void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
// Infer out's shape according to x and y (broadcasting may be needed).
out->set_dtype(x.dtype());
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
out->set_dims(phi::make_ddim(out_dims_array));
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
......
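To make the shape inference above concrete, here is a minimal NumPy sketch (the helper name infer_send_uv_out_shape is hypothetical): only the trailing dimensions of x and y are broadcast against each other, and the leading dimension of the output becomes the number of edges, i.e. src_index.shape[0].

import numpy as np

# Hypothetical helper mirroring GraphSendUVInferMeta's shape logic.
def infer_send_uv_out_shape(x_shape, y_shape, num_edges):
    # Broadcast every dimension except the leading (node) dimension.
    # np.broadcast_shapes requires NumPy >= 1.20.
    trailing = np.broadcast_shapes(tuple(x_shape[1:]), tuple(y_shape[1:]))
    # The output's leading dimension is the number of edges.
    return (num_edges,) + trailing

# Example: x is (100, 1), y is (100, 20), 15 edges -> out is (15, 20).
print(infer_send_uv_out_shape((100, 1), (100, 20), 15))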
......@@ -476,4 +476,11 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* dst_count);
void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);
} // namespace phi
......@@ -24,7 +24,7 @@
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
......
......@@ -22,7 +22,7 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
namespace phi {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void CalculateGrad(const Context& ctx,
const T* out_grad,
const IndexT* s_index,
const IndexT* d_index,
const phi::DDim& out_grad_dims,
const phi::DDim& x_grad_dims,
const std::string& message_op,
int64_t index_size,
int64_t slice_size,
T* x_grad,
const DenseTensor& out_grad_tensor,
const DenseTensor& y) {
std::vector<int64_t> reduce_idx;
bool reduce = ReduceGrad(out_grad_dims, x_grad_dims, reduce_idx);
if (message_op == "ADD") {
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT dst = d_index[i];
T* x_grad_off = x_grad + dst * slice_size;
const T* out_grad_off = out_grad + i * slice_size;
for (int64_t j = 0; j < slice_size; j++) {
if (out_grad_off[j] != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += out_grad_off[j];
}
}
}
} else {
const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, x_grad_dims);
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.emplace(out_grad_dims_2.begin(), x_grad_dims[0]);
DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT dst = d_index[i];
T* x_grad_off = x_grad_v2_data + dst * bcast_info.out_len;
const T* out_grad_off = out_grad + i * bcast_info.out_len;
for (int64_t j = 0; j < bcast_info.out_len; j++) {
if (out_grad_off[j] != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += out_grad_off[j];
}
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
} else if (message_op == "MUL") {
const auto& bcast = phi::CalcBCastInfo(y.dims(), out_grad_dims);
const T* y_data = y.data<T>();
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad + dst * bcast.out_len;
const T* y_off = y_data + src * bcast.l_len;
const T* out_grad_off = out_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t o_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = y_off[y_add] * out_grad_off[o_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
} else {
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.emplace(out_grad_dims_2.begin(), x_grad_dims[0]);
DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad_v2_data + dst * bcast.out_len;
const T* y_off = y_data + src * bcast.l_len;
const T* out_grad_off = out_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t o_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = y_off[y_add] * out_grad_off[o_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUVGradOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
const int64_t& index_size = dst_index.dims()[0];
PADDLE_ENFORCE_GT(
index_size,
0,
errors::InvalidArgument("The first dimension of src_index or dst_index "
"shoule be greater than 0, but received %d.",
index_size));
ctx.template Alloc<T>(x_grad);
T* x_grad_data = x_grad->data<T>();
ctx.template Alloc<T>(y_grad);
T* y_grad_data = y_grad->data<T>();
const auto& x_grad_dims = x_grad->dims();
const auto& y_grad_dims = y_grad->dims();
int64_t memset_size_x = 1, memset_size_y = 1;
int64_t slice_size_x = 1, slice_size_y = 1;
for (int i = 0; i < x_grad_dims.size(); i++) {
memset_size_x *= x_grad_dims[i];
if (i > 0) slice_size_x *= x_grad_dims[i];
}
for (int i = 0; i < y_grad_dims.size(); i++) {
memset_size_y *= y_grad_dims[i];
if (i > 0) slice_size_y *= y_grad_dims[i];
}
const size_t& memset_bytes_x = memset_size_x * sizeof(T);
const size_t& memset_bytes_y = memset_size_y * sizeof(T);
memset(x_grad_data, 0, memset_bytes_x);
memset(y_grad_data, 0, memset_bytes_y);
const T* out_grad_data = out_grad.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
const auto& out_grad_dims = out_grad.dims();
// Calculate X Grad.
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
d_index,
s_index,
out_grad_dims,
x_grad_dims,
message_op,
index_size,
slice_size_x,
x_grad_data,
out_grad,
y);
// Calculate Y Grad.
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
s_index,
d_index,
out_grad_dims,
y_grad_dims,
message_op,
index_size,
slice_size_y,
y_grad_data,
out_grad,
x);
}
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
} else if (index_type == phi::DataType::INT64) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendUVGradKernel,
float,
double,
int,
int64_t) {}
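As a cross-check of the gradient logic above, a NumPy sketch of the non-broadcast path (the helper name send_uv_grad_reference is hypothetical, and x and y are assumed to have identical shapes): for message_op="ADD" the x gradient is a scatter-add of out_grad over src_index, while for "MUL" each contribution is additionally scaled by the gathered feature of the other input.

import numpy as np

# Hypothetical reference gradients for graph_send_uv (no broadcasting).
def send_uv_grad_reference(x, y, src_index, dst_index, out_grad, message_op="ADD"):
    x_grad = np.zeros_like(x)
    y_grad = np.zeros_like(y)
    for i in range(src_index.shape[0]):
        s, d = src_index[i], dst_index[i]
        if message_op == "ADD":
            x_grad[s] += out_grad[i]
            y_grad[d] += out_grad[i]
        elif message_op == "MUL":
            x_grad[s] += out_grad[i] * y[d]
            y_grad[d] += out_grad[i] * x[s]
    return x_grad, y_grad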
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
namespace phi {
template <typename T, typename IndexT, typename ComputeFunctor>
void GraphSendUVCpuKernel(const BroadCastInfo& bcast,
const T* x_data,
const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
int64_t index_size,
ComputeFunctor cfunctor) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = src_indices[i];
IndexT dst = dst_indices[i];
T* out_off = output + i * bcast.out_len;
const T* x_off = x_data + src * bcast.l_len;
const T* y_off = y_data + dst * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = cfunctor(x_off[x_add], y_off[y_add]);
out_off[j] = val;
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUVOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
const int& index_size = src_index.dims()[0];
PADDLE_ENFORCE_GT(
index_size,
0,
errors::InvalidArgument("The first dimension of src_index or dst_index "
"shoule be greater than 0, but received %d.",
index_size));
auto out_dims = out->dims();
int64_t memset_size = 1;
for (int i = 0; i < out_dims.size(); i++) {
memset_size *= out_dims[i];
}
ctx.template Alloc<T>(out);
T* out_data = out->data<T>();
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (message_op == "ADD") {
GraphAddFunctor<T> add_functor;
GraphSendUVCpuKernel<T, IndexT, GraphAddFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
add_functor);
} else if (message_op == "MUL") {
GraphMulFunctor<T> mul_functor;
GraphSendUVCpuKernel<T, IndexT, GraphMulFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
mul_functor);
}
}
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, y, src_index, dst_index, message_op, out);
} else if (index_type == phi::DataType::INT64) {
GraphSendUVOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, y, src_index, dst_index, message_op, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv,
CPU,
ALL_LAYOUT,
phi::GraphSendUVKernel,
float,
double,
int,
int64_t) {}
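The CPU forward kernel above produces one output row per edge: it gathers x[src_index[i]] and y[dst_index[i]] and combines them with the selected functor. Ignoring broadcasting, a NumPy equivalent (essentially the same reference computation used by the unit test later in this diff; the helper name is hypothetical) is:

import numpy as np

def send_uv_reference(x, y, src_index, dst_index, message_op="ADD"):
    gather_x = x[src_index]  # (num_edges, ...) rows gathered from the source nodes
    gather_y = y[dst_index]  # (num_edges, ...) rows gathered from the destination nodes
    return gather_x + gather_y if message_op == "ADD" else gather_x * gather_y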
......@@ -19,7 +19,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
namespace phi {
......
......@@ -21,7 +21,7 @@
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
......
......@@ -15,7 +15,7 @@
#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
template <typename T, typename IndexT>
__global__ void GraphSendUVGradCUDAKernel(const T* out_grad,
const IndexT* src_indices,
const IndexT* dst_indices,
int64_t index_size,
int64_t slice_size,
T* x_grad) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* out_grad_off = out_grad + ty * slice_size;
T* x_grad_off = x_grad + dst * slice_size;
while (tx < slice_size) {
paddle::platform::CudaAtomicAdd(x_grad_off + tx, out_grad_off[tx]);
tx += stride_x;
}
ty += stride_y;
}
}
template <typename Context, typename T, typename IndexT>
void CalculateGrad(const Context& ctx,
const T* out_grad,
const IndexT* s_index,
const IndexT* d_index,
const phi::DDim& out_grad_dims,
const phi::DDim& x_grad_dims,
const std::string& message_op,
int64_t index_size,
int64_t slice_size,
T* x_grad,
const DenseTensor& out_grad_tensor,
const DenseTensor& y) {
std::vector<int64_t> reduce_idx;
bool reduce = ReduceGrad(out_grad_dims, x_grad_dims, reduce_idx);
if (message_op == "ADD") {
if (!reduce) {
const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (slice_size + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty;
const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT>
<<<grid_tmp, block_tmp, 0, ctx.stream()>>>(
out_grad, d_index, s_index, index_size, slice_size, x_grad);
} else {
const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, x_grad_dims);
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
const int ntx =
FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (bcast_info.out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty;
const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT>
<<<grid_tmp, block_tmp, 0, ctx.stream()>>>(out_grad,
d_index,
s_index,
index_size,
bcast_info.out_len,
x_grad_v2_data);
// Run reduce sum
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
#ifdef PADDLE_WITH_HIP
hipMemcpy(x_grad,
x_grad_out.data<T>(),
x_grad_out.numel() * sizeof(T),
hipMemcpyDeviceToDevice);
#else
cudaMemcpy(x_grad,
x_grad_out.data<T>(),
x_grad_out.numel() * sizeof(T),
cudaMemcpyDeviceToDevice);
#endif
}
} else if (message_op == "MUL") {
const auto& bcast_info = phi::CalcBCastInfo(y.dims(), out_grad_dims);
thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty;
const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty);
funcs::MultiplyFunctor<T> mul_functor;
GraphSendUERecvSumCUDAFunctor<T> sum_functor;
const T* y_data = y.data<T>();
if (!reduce) {
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvSumCUDAFunctor<T>,
funcs::MultiplyFunctor<T>>
<<<grid_, block_, 0, ctx.stream()>>>(
y_data,
out_grad,
d_index,
s_index,
thrust::raw_pointer_cast(l_bcastoff.data()),
thrust::raw_pointer_cast(r_bcastoff.data()),
x_grad,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor,
sum_functor);
} else {
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvSumCUDAFunctor<T>,
funcs::MultiplyFunctor<T>>
<<<grid_, block_, 0, ctx.stream()>>>(
y_data,
out_grad,
d_index,
s_index,
thrust::raw_pointer_cast(l_bcastoff.data()),
thrust::raw_pointer_cast(r_bcastoff.data()),
x_grad_v2_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor,
sum_functor);
// Run reduce_sum
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
#ifdef PADDLE_WITH_HIP
hipMemcpy(x_grad,
x_grad_out.data<T>(),
x_grad_out.numel() * sizeof(T),
hipMemcpyDeviceToDevice);
#else
cudaMemcpy(x_grad,
x_grad_out.data<T>(),
x_grad_out.numel() * sizeof(T),
cudaMemcpyDeviceToDevice);
#endif
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUVGradOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
const int64_t& index_size = dst_index.dims()[0];
PADDLE_ENFORCE_GT(
index_size,
0,
errors::InvalidArgument("The first dimension of src_index or dst_index "
"shoule be greater than 0, but received %d.",
index_size));
ctx.template Alloc<T>(x_grad);
T* x_grad_data = x_grad->data<T>();
ctx.template Alloc<T>(y_grad);
T* y_grad_data = y_grad->data<T>();
const auto& x_grad_dims = x_grad->dims();
const auto& y_grad_dims = y_grad->dims();
int64_t memset_size_x = 1, memset_size_y = 1;
int64_t slice_size_x = 1, slice_size_y = 1;
for (int i = 0; i < x_grad_dims.size(); i++) {
memset_size_x *= x_grad_dims[i];
if (i > 0) slice_size_x *= x_grad_dims[i];
}
for (int i = 0; i < y_grad_dims.size(); i++) {
memset_size_y *= y_grad_dims[i];
if (i > 0) slice_size_y *= y_grad_dims[i];
}
const size_t& memset_bytes_x = memset_size_x * sizeof(T);
const size_t& memset_bytes_y = memset_size_y * sizeof(T);
#ifdef PADDLE_WITH_HIP
hipMemset(x_grad_data, 0, memset_bytes_x);
hipMemset(y_grad_data, 0, memset_bytes_y);
#else
cudaMemset(x_grad_data, 0, memset_bytes_x);
cudaMemset(y_grad_data, 0, memset_bytes_y);
#endif
const T* out_grad_data = out_grad.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
// Calculate X grad.
const auto& out_grad_dims = out_grad.dims();
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
s_index,
d_index,
out_grad_dims,
x_grad_dims,
message_op,
index_size,
slice_size_x,
x_grad_data,
out_grad,
y);
// Calculate Y grad.
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
d_index,
s_index,
out_grad_dims,
y_grad_dims,
message_op,
index_size,
slice_size_y,
y_grad_data,
out_grad,
x);
}
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
} else if (index_type == phi::DataType::INT64) {
GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv_grad,
GPU,
ALL_LAYOUT,
phi::GraphSendUVGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_uv_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include <thrust/device_vector.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace phi {
template <typename T, typename IndexT, typename ComputeFunctor>
__global__ void GraphSendUVCUDAKernel(const T* x_data,
const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
const int64_t* xbcast_off,
const int64_t* ybcast_off,
T* output,
int64_t index_size,
int64_t x_len,
int64_t y_len,
int64_t out_len,
bool use_bcast,
ComputeFunctor cfunctor) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + src * x_len;
const T* y_off = y_data + dst * y_len;
T* out_off = output + ty * out_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t y_add = use_bcast ? ybcast_off[tx] : tx;
T val = cfunctor(x_off[x_add], y_off[y_add]);
out_off[tx] = val;
tx += stride_x;
}
ty += stride_y;
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
const int64_t& index_size = src_index.dims()[0];
PADDLE_ENFORCE_GT(
index_size,
0,
errors::InvalidArgument("The first dimension of src_index or dst_index "
"shoule be greater than 0, but received %d.",
index_size));
auto out_dims = out->dims();
int64_t memset_size = 1;
for (int i = 0; i < out_dims.size(); i++) {
memset_size *= out_dims[i];
}
ctx.template Alloc<T>(out);
T* out_data = out->data<T>();
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
thrust::device_vector<int64_t> x_bcastoff, y_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, x_bcastoff, y_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty;
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
if (message_op == "ADD") {
funcs::AddFunctor<T> add_functor;
GraphSendUVCUDAKernel<T, IndexT, funcs::AddFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
y_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(y_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
add_functor);
} else if (message_op == "MUL") {
funcs::MultiplyFunctor<T> mul_functor;
GraphSendUVCUDAKernel<T, IndexT, funcs::MultiplyFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
y_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(y_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor);
}
}
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, x, y, src_index, dst_index, message_op, out);
} else if (index_type == phi::DataType::INT64) {
GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, x, y, src_index, dst_index, message_op, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_uv,
GPU,
ALL_LAYOUT,
phi::GraphSendUVKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
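The GPU kernels above use a 2-D launch: the grid's y dimension walks over edges and the x dimension over the per-edge output length. A small sketch of how the grid and block sizes are derived (find_num_threads below is a hypothetical stand-in for phi's FindNumThreads, which is assumed to return a power of two close to the feature length and bounded by the per-block thread limit):

# Hypothetical sketch of the launch configuration used above.
def find_num_threads(dim, max_threads=1024):
    # Assumed behaviour: shrink from the device limit to a power of two
    # that is not unnecessarily larger than `dim`.
    n = max_threads
    while n // 2 >= max(dim, 1):
        n //= 2
    return n

def launch_config(out_len, index_size, max_threads=1024):
    ntx = find_num_threads(out_len, max_threads)  # threads along the feature dim
    nty = max_threads // ntx                      # threads along the edge dim
    nbx = (out_len + ntx - 1) // ntx              # blocks along the feature dim
    nby = (index_size + nty - 1) // nty           # blocks along the edge dim
    return (nbx, nby), (ntx, nty)                 # (grid, block)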
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUVKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* out);
} // namespace phi
......@@ -1565,6 +1565,7 @@ set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
if(WITH_DISTRIBUTE
AND WITH_GPU
AND WITH_NCCL)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from op_test import OpTest
def compute_graph_send_uv(inputs, attributes):
x = inputs['x']
y = inputs['y']
src_index = inputs['src_index']
dst_index = inputs['dst_index']
message_op = attributes['message_op']
gather_x = x[src_index]
gather_y = y[dst_index]
# Calculate forward output.
if message_op == "ADD":
results = gather_x + gather_y
elif message_op == "MUL":
results = gather_x * gather_y
return results
def graph_send_uv_wrapper(x, y, src_index, dst_index, message_op="add"):
return paddle.geometric.send_uv(x, y, src_index, dst_index,
message_op.lower())
class TestGraphSendUVOp(OpTest):
def setUp(self):
paddle.enable_static()
self.python_api = graph_send_uv_wrapper
self.python_out_sig = ['out']
self.op_type = "graph_send_uv"
self.set_config()
self.inputs = {
'x': self.x,
'y': self.y,
'src_index': self.src_index,
'dst_index': self.dst_index
}
self.attrs = {'message_op': self.message_op}
out = compute_graph_send_uv(self.inputs, self.attrs)
self.outputs = {'out': out}
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['x', 'y'], 'out', check_eager=True)
def set_config(self):
self.x = np.random.random((10, 20)).astype("float64")
self.y = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'ADD'
class TestCase1(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((10, 20)).astype("float64")
self.y = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'MUL'
class TestCase2(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((100, 1)).astype("float64")
self.y = np.random.random((100, 20)).astype("float64")
index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'ADD'
class TestCase3(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((100, 20)).astype("float64")
self.y = np.random.random((100, 1)).astype("float64")
index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'ADD'
class TestCase4(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((100, 1)).astype("float64")
self.y = np.random.random((100, 20)).astype("float64")
index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'MUL'
class TestCase5(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((100, 20)).astype("float64")
self.y = np.random.random((100, 1)).astype("float64")
index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'MUL'
class TestCase6(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((10, 10, 1)).astype("float64")
self.y = np.random.random((10, 10, 10))
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'ADD'
class TestCase7(TestGraphSendUVOp):
def set_config(self):
self.x = np.random.random((10, 10, 1)).astype("float64")
self.y = np.random.random((10, 10, 10))
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.message_op = 'MUL'
class API_GeometricSendUVTest(unittest.TestCase):
def test_compute_all_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32")
dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32")
res_add = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="add")
res_sub = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="sub")
res_mul = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="mul")
res_div = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="div")
res = [res_add, res_sub, res_mul, res_div]
np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]],
dtype="float32")
np_sub = np.array([[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]],
dtype="float32")
np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]],
dtype="float32")
np_div = np.array(
[[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6], [1, 2, 7 / 4], [0, 2, 1.5]],
dtype="float32")
for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res):
self.assertTrue(
np.allclose(np_res, paddle_res, atol=1e-6), "two value is\
{}\n{}, check diff!".format(np_res, paddle_res))
def test_compute_all_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
y = paddle.static.data(name="y", shape=[3, 3], dtype="float32")
src_index = paddle.static.data(name="src", shape=[4], dtype="int32")
dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32")
res_add = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="add")
res_sub = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="sub")
res_mul = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="mul")
res_div = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
message_op="div")
exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
data2 = np.array([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
data3 = np.array([0, 1, 2, 0], dtype="int32")
data4 = np.array([1, 2, 1, 0], dtype="int32")
np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]],
dtype="float32")
np_sub = np.array(
[[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]],
dtype="float32")
np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]],
dtype="float32")
np_div = np.array([[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6],
[1, 2, 7 / 4], [0, 2, 1.5]],
dtype="float32")
ret = exe.run(feed={
'x': data1,
'y': data2,
'src': data3,
'dst': data4,
},
fetch_list=[res_add, res_sub, res_mul, res_div])
for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div],
ret):
self.assertTrue(
np.allclose(np_res, paddle_res, atol=1e-6), "two value is\
{}\n{}, check diff!".format(np_res, paddle_res))
def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_compute_all_dygraph()
......@@ -14,8 +14,10 @@
from .message_passing import send_u_recv # noqa: F401
from .message_passing import send_ue_recv # noqa: F401
from .message_passing import send_uv # noqa: F401
__all__ = [
'send_u_recv',
'send_ue_recv',
'send_uv',
]
......@@ -14,3 +14,10 @@
from .send_recv import send_u_recv # noqa: F401
from .send_recv import send_ue_recv # noqa: F401
from .send_recv import send_uv # noqa: F401
__all__ = [
'send_u_recv',
'send_ue_recv',
'send_uv',
]
......@@ -21,6 +21,8 @@ from paddle import _C_ops
from .utils import convert_out_size_to_list, get_out_size_tensor_inputs, reshape_lhs_rhs
__all__ = []
def send_u_recv(x,
src_index,
......@@ -336,3 +338,116 @@ def send_ue_recv(x,
},
attrs=attrs)
return out
def send_uv(x, y, src_index, dst_index, message_op="add", name=None):
"""
Graph Learning message passing API.
This API is mainly used in the Graph Learning domain, and its main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the source node feature tensor and `y` as
the destination node feature tensor. Then `src_index` and `dst_index` are used to gather the corresponding node
features, which are combined into edge features with one of the message ops `add`, `sub`, `mul` or `div`.
.. code-block:: text
Given:
x = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
y = [[0, 1, 2],
[2, 3, 4],
[4, 5, 6]]
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]
message_op = "add"
Then:
out = [[2, 5, 7],
[5, 9, 11],
[4, 9, 11],
[0, 3, 5]]
Args:
x (Tensor): The source node feature tensor. The available data types are float32, float64, int32 and int64; float16 is also supported in the GPU version.
y (Tensor): The destination node feature tensor. The available data types are float32, float64, int32 and int64; float16 is also supported in the GPU version.
src_index (Tensor): A 1-D tensor. The available data types are int32 and int64.
dst_index (Tensor): A 1-D tensor, which should have the same shape as `src_index`.
The available data types are int32 and int64.
message_op (str): The message operation applied to `x` and `y`, one of `add`, `sub`, `mul` and `div`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): The output tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.geometric.send_uv(x, y, src_index, dst_index, message_op="add")
# Outputs: [[2., 5., 7.], [5., 9., 11.], [4., 9., 11.], [0., 3., 5.]]
"""
if message_op not in ['add', 'sub', 'mul', 'div']:
raise ValueError(
"message_op should be `add`, `sub`, `mul`, `div`, but received %s" %
message_op)
x, y = reshape_lhs_rhs(x, y)
if message_op == 'sub':
message_op = 'add'
y = -y
if message_op == 'div':
message_op = 'mul'
y = 1. / y
if in_dygraph_mode():
return _C_ops.final_state_graph_send_uv(x, y, src_index, dst_index,
message_op.upper())
else:
if _in_legacy_dygraph():
return _C_ops.graph_send_uv(x, y, src_index, dst_index,
"message_op", message_op.upper())
else:
helper = LayerHelper("send_uv", **locals())
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float32', 'float64', 'float16'],
'graph_send_uv')
check_variable_and_dtype(
y, 'y', ['int32', 'int64', 'float32', 'float64', 'float16'],
'graph_send_uv')
check_variable_and_dtype(src_index, 'src_index', ['int32', 'int64'],
'graph_send_uv')
check_variable_and_dtype(dst_index, 'dst_index', ['int32', 'int64'],
'graph_send_uv')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
inputs = {
'x': x,
'y': y,
'src_index': src_index,
'dst_index': dst_index
}
attrs = {'message_op': message_op.upper()}
helper.append_op(type="graph_send_uv",
inputs=inputs,
attrs=attrs,
outputs={"out": out})
return out
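Note how `sub` and `div` are folded into `add` and `mul` by negating or inverting `y` before dispatch. A quick way to sanity-check that rewrite from user code (a usage sketch, not part of this diff):

import numpy as np
import paddle

x = paddle.to_tensor([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]])
y = paddle.to_tensor([[1., 1., 2.], [2., 3., 4.], [4., 5., 6.]])
src = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

# "sub" should match "add" with -y, and "div" should match "mul" with 1/y.
sub = paddle.geometric.send_uv(x, y, src, dst, message_op="sub")
add_neg = paddle.geometric.send_uv(x, -y, src, dst, message_op="add")
assert np.allclose(sub.numpy(), add_neg.numpy())

div = paddle.geometric.send_uv(x, y, src, dst, message_op="div")
mul_inv = paddle.geometric.send_uv(x, 1. / y, src, dst, message_op="mul")
assert np.allclose(div.numpy(), mul_inv.numpy())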