Unverified commit 6bd2d2b1, authored by wawltor, committed by GitHub

[Phi] move the graph_send_recv op to the phi (#40092)

* [Phi] transfer old kernel to pten kernel for the graph_send_recv op

* update the code for the define of graph_send_recv

* fix the gradient problem for graph_send_recv

* fix the compile problem

* update the enforce message for Windows

* update the code for the compiler

* fix the compiler problem for Windows

* update the code for Windows

* fix some format problem
Parent 413a743e
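Before the diff itself, here is a minimal standalone sketch of what graph_send_recv computes, using plain std::vector instead of Paddle tensors and covering only the SUM and MEAN pool types. It illustrates the gather-from-Src_index / scatter-to-Dst_index pattern that the kernels below implement; the helper name SendRecvSketch is purely illustrative and is not part of the commit.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
// Sketch of graph_send_recv on a feature matrix x of shape
// [num_nodes, slice_size]: for every edge e, the row x[src_index[e]]
// is reduced into out[dst_index[e]]. Only SUM and MEAN are shown.
std::vector<float> SendRecvSketch(const std::vector<float>& x,
                                  int64_t num_nodes, int64_t slice_size,
                                  const std::vector<int64_t>& src_index,
                                  const std::vector<int64_t>& dst_index,
                                  const std::string& pool_type) {
  assert(src_index.size() == dst_index.size());
  std::vector<float> out(num_nodes * slice_size, 0.0f);
  std::vector<int> dst_count(num_nodes, 0);
  for (size_t e = 0; e < src_index.size(); ++e) {
    const int64_t src = src_index[e];
    const int64_t dst = dst_index[e];
    for (int64_t j = 0; j < slice_size; ++j) {
      out[dst * slice_size + j] += x[src * slice_size + j];
    }
    ++dst_count[dst];
  }
  if (pool_type == "MEAN") {
    for (int64_t i = 0; i < num_nodes; ++i) {
      if (dst_count[i] <= 1) continue;  // rows hit 0 or 1 times stay as-is
      for (int64_t j = 0; j < slice_size; ++j) {
        out[i * slice_size + j] /= static_cast<float>(dst_count[i]);
      }
    }
  }
  return out;
}
int main() {
  // Two nodes with 2 features each; both edges point at node 0.
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<int64_t> src = {0, 1}, dst = {0, 0};
  auto out = SendRecvSketch(x, 2, 2, src, dst, "MEAN");
  std::printf("%.1f %.1f\n", out[0], out[1]);  // expected: 2.0 3.0
  return 0;
}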
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/graph_send_recv_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -171,13 +171,3 @@ REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP,
ops::GraphSendRecvGradOpMaker<paddle::framework::OpDesc>,
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp);
REGISTER_OP_CPU_KERNEL(graph_send_recv, ops::GraphSendRecvOpKernel<CPU, float>,
ops::GraphSendRecvOpKernel<CPU, double>,
ops::GraphSendRecvOpKernel<CPU, int>,
ops::GraphSendRecvOpKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(graph_send_recv_grad,
ops::GraphSendRecvGradOpKernel<CPU, float>,
ops::GraphSendRecvGradOpKernel<CPU, double>,
ops::GraphSendRecvGradOpKernel<CPU, int>,
ops::GraphSendRecvGradOpKernel<CPU, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/graph_send_recv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename IndexT>
struct GraphSendRecvSumCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMaxCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMinCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT, typename Functor>
__global__ void GraphSendRecvCUDAKernel(const T* params,
const IndexT* src_indices,
const IndexT* dst_indices, T* output,
size_t index_size, size_t slice_size,
Functor functor) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
functor(params, output, in_i, out_i);
}
}
// For max
template <typename T>
__global__ void InputResetMaxCUDAKernel(T* output, size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::min()) {
*(output + i) = 0;
}
}
}
// For min
template <typename T>
__global__ void InputResetMinCUDAKernel(T* output, size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::max()) {
*(output + i) = 0;
}
}
}
// Get dst_count
template <typename T, typename IndexT>
__global__ void ComputeCountCUDAKernel(int* count, const IndexT* dst_indices,
size_t index_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
IndexT dst_i = dst_indices[i];
paddle::platform::CudaAtomicAdd(count + dst_i, 1);
}
}
// For forward mean
template <typename T>
__global__ void ManipulateMeanCUDAKernel(T* output, int* count,
size_t input_size, size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
int64_t c_index = i / slice_size;
if (*(count + c_index) > 1) {
*(output + i) = *(output + i) / *(count + c_index);
}
}
}
// For backward mean
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernel(
const T* params, const IndexT* src_indices, const IndexT* dst_indices,
T* output, size_t index_size, size_t slice_size, const int* dst_count) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(output + out_i,
*(params + in_i) / dst_count[src_i]);
}
}
// For backward min and max
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernel(
const T* params, const IndexT* src_indices, const IndexT* dst_indices,
T* output, size_t index_size, size_t slice_size, const T* ptr_input,
const T* ptr_output) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(
output + out_i,
*(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i)));
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvOpCUDAKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index,
const Tensor& dst_index) {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Output<Tensor>("Out");
std::string pool_type = ctx.Attr<std::string>("pool_type");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
} else if (pool_type == "MAX") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
std::numeric_limits<T>::min());
} else if (pool_type == "MIN") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
std::numeric_limits<T>::max());
}
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = X->data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
const auto& dev_ctx = ctx.cuda_device_context();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
} else if (pool_type == "MAX") {
GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvMaxCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
InputResetMaxCUDAKernel<
T><<<grid_max, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, input_size, slice_size);
} else if (pool_type == "MIN") {
GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvMinCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
InputResetMinCUDAKernel<
T><<<grid_min, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, input_size, slice_size);
} else if (pool_type == "MEAN") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
auto* dst_count = ctx.Output<Tensor>("Dst_count");
int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
#else
cudaMemset(p_dst_count, 0, input_size * sizeof(int));
#endif
int64_t grid_count = (index_size + block - 1) / block;
ComputeCountCUDAKernel<
T, IndexT><<<grid_count, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_dst_count, d_index, index_size);
int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_mean =
grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx;
ManipulateMeanCUDAKernel<
T><<<grid_mean, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, p_dst_count, input_size, slice_size);
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index,
const Tensor& dst_index) {
auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* Y = ctx.Output<Tensor>(framework::GradVarName("X"));
std::string pool_type = ctx.Attr<std::string>("pool_type");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = X->data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
const auto& dev_ctx = ctx.cuda_device_context();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Input<Tensor>("Dst_count");
const int* s_count = dst_count->data<int>();
ManipulateMeanGradCUDAKernel<T, IndexT><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Input<Tensor>("Out");
const T* ptr_input = input->data<T>();
const T* ptr_output = output->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, ptr_input,
ptr_output);
}
}
template <typename DeviceContext, typename T>
class GraphSendRecvOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Src_index");
auto* dst_index = ctx.Input<Tensor>("Dst_index");
auto index_type = framework::TransToProtoVarType(src_index->dtype());
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<DeviceContext, T, int>(
ctx, *src_index, *dst_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index, *dst_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index dtype, expected int, int64, but "
"got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class GraphSendRecvGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Dst_index");
auto* dst_index = ctx.Input<Tensor>("Src_index");
auto index_type = framework::TransToProtoVarType(src_index->dtype());
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<DeviceContext, T, int>(
ctx, *src_index, *dst_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index, *dst_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index dtype, expected int, int64, but "
"got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(graph_send_recv,
ops::GraphSendRecvOpCUDAKernel<CUDA, float>,
ops::GraphSendRecvOpCUDAKernel<CUDA, double>,
ops::GraphSendRecvOpCUDAKernel<CUDA, int>,
ops::GraphSendRecvOpCUDAKernel<CUDA, int64_t>);
REGISTER_OP_CUDA_KERNEL(graph_send_recv_grad,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, float>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, double>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, int>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, int64_t>);
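The CUDA path above flattens the (edge, feature) pair into a single loop index. The host-side sketch below mirrors the arithmetic of GraphSendRecvCUDAKernel so the in_i/out_i offsets are easy to verify; it is a plain C++ illustration, not code from the commit, and the constants are made-up example values.
#include <cstdint>
#include <cstdio>
// Mirrors the index arithmetic of GraphSendRecvCUDAKernel: a flat index i
// over index_size * slice_size is split into an edge id and a column,
// which give the read offset into the source row and the write offset
// into the destination row.
int main() {
  const int64_t index_size = 3;   // number of edges
  const int64_t slice_size = 4;   // features per node row
  const int64_t src_indices[] = {2, 0, 1};
  const int64_t dst_indices[] = {1, 1, 0};
  for (int64_t i = 0; i < index_size * slice_size; ++i) {
    int64_t indices_i = i / slice_size;            // which edge
    int64_t slice_i = i - indices_i * slice_size;  // which column
    int64_t in_i = src_indices[indices_i] * slice_size + slice_i;
    int64_t out_i = dst_indices[indices_i] * slice_size + slice_i;
    std::printf("i=%lld edge=%lld col=%lld in=%lld out=%lld\n",
                (long long)i, (long long)indices_i, (long long)slice_i,
                (long long)in_i, (long long)out_i);
  }
  return 0;
}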
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct GraphSendRecvSumFunctor {
void operator()(const bool& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
eigen_dst += eigen_src;
}
};
template <typename T>
struct GraphSendRecvMinFunctor {
void operator()(const bool& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMin(eigen_src);
}
}
};
template <typename T>
struct GraphSendRecvMaxFunctor {
void operator()(const int& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMax(eigen_src);
}
}
};
template <typename T, typename IndexT, typename Functor>
void elementwise_inner_operation(const Tensor& src, Tensor* dst,
const IndexT& src_index,
const IndexT& dst_index,
const bool& first_flag, Functor functor) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dst_slice = dst->Slice(dst_index, dst_index + 1);
functor(first_flag, src_slice, &dst_slice);
}
template <typename T, typename IndexT, typename Functor>
void graph_send_recv_cpu_for_loop(const int& input_size, const int& index_size,
const IndexT* s_index, const IndexT* d_index,
const Tensor& src, Tensor* dst,
const std::string& pool_type,
int* dst_count = nullptr) {
Functor functor;
if (pool_type == "SUM") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
for (int i = 0; i < index_size; ++i) {
IndexT dst_idx = d_index[i];
*(dst_count + dst_idx) += 1;
}
for (int i = 0; i < input_size; ++i) {
if (*(dst_count + i) == 0) continue;
auto dst_slice = dst->Slice(i, i + 1);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
std::set<IndexT> existed_dst;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
if (!in_set) {
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, true, functor);
existed_dst.emplace(dst_idx);
} else {
elementwise_inner_operation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
}
}
}
template <typename T, typename IndexT, typename Functor>
void graph_send_recv_cpu_for_loop_grad(
const int& input_size, const int& index_size, const IndexT* s_index,
const IndexT* d_index, const Tensor& src, Tensor* dst,
const std::string& pool_type, const int* dst_count = nullptr,
const Tensor* input = nullptr, const Tensor* output = nullptr) {
if (pool_type == "SUM") {
Functor functor;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
auto src_slice = src.Slice(src_idx, src_idx + 1);
auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i];
auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
auto eigen_input = framework::EigenVector<T>::Flatten(input_slice);
auto eigen_output = framework::EigenVector<T>::Flatten(output_slice);
auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst += eigen_src * (eigen_output == eigen_input);
}
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvOpKernelLaunchHelper(const framework::ExecutionContext& ctx,
const Tensor& src_index) {
auto* X = ctx.Input<Tensor>("X");
auto* dst_index = ctx.Input<Tensor>("Dst_index");
auto* Y = ctx.Output<Tensor>("Out");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index->data<IndexT>();
const std::string& pool_type = ctx.Attr<std::string>("pool_type");
if (pool_type == "SUM") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MIN") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MAX") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Output<Tensor>("Dst_count");
int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
memset(p_dst_count, 0, src_dims[0] * sizeof(int));
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type,
p_dst_count);
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index) {
auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dst_index = ctx.Input<Tensor>("Src_index");
auto* Y = ctx.Output<Tensor>(framework::GradVarName("X"));
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index->data<IndexT>();
const std::string& pool_type = ctx.Attr<std::string>("pool_type");
if (pool_type == "SUM") {
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Input<Tensor>("Dst_count");
const int* s_count = dst_count->data<int>();
// Functor not used here.
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
const auto* input = ctx.Input<Tensor>("X");
const auto* output = ctx.Input<Tensor>("Out");
// Functor not used here.
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, nullptr,
input, output);
}
}
template <typename DeviceContext, typename T>
class GraphSendRecvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Src_index");
auto index_type = framework::TransToProtoVarType(src_index->dtype());
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvOpKernelLaunchHelper<DeviceContext, T, int>(ctx, *src_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvOpKernelLaunchHelper<DeviceContext, T, int64_t>(ctx,
*src_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index type, Expected int, int64, but "
"got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class GraphSendRecvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Dst_index");
auto index_type = framework::TransToProtoVarType(src_index->dtype());
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvGradOpKernelLaunchHelper<DeviceContext, T, int>(ctx,
*src_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvGradOpKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index type, Expected int, int64, but "
"got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
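For MIN and MAX, graph_send_recv_cpu_for_loop above uses a std::set to detect the first time a destination row is written, so the first contribution is copied (the functor adds into the zero-initialized row) and later contributions are combined with cwiseMin/cwiseMax. The following is a plain-array sketch of that first-visit handling for MAX with slice_size == 1, independent of Paddle; the helper name ScatterMaxSketch is illustrative only.
#include <algorithm>
#include <cstdint>
#include <set>
#include <vector>
// First-visit handling for a MAX reduction over edges: the first value
// scattered into a destination row is copied, later values are combined
// with std::max. Rows that never receive a value stay at 0, matching the
// zero-initialized output of the CPU kernel.
std::vector<float> ScatterMaxSketch(const std::vector<float>& x,
                                    int64_t num_nodes,
                                    const std::vector<int64_t>& src_index,
                                    const std::vector<int64_t>& dst_index) {
  std::vector<float> out(num_nodes, 0.0f);
  std::set<int64_t> seen_dst;
  for (size_t e = 0; e < src_index.size(); ++e) {
    const int64_t src = src_index[e];
    const int64_t dst = dst_index[e];
    if (seen_dst.insert(dst).second) {
      out[dst] = x[src];                      // first contribution: copy
    } else {
      out[dst] = std::max(out[dst], x[src]);  // later ones: reduce
    }
  }
  return out;
}
For example, ScatterMaxSketch({5, 2, 7}, 3, {0, 1, 2}, {0, 0, 1}) yields {5, 7, 0}: node 0 keeps the larger of 5 and 2, node 1 receives 7, and node 2 is never written so it stays at 0.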
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
struct GraphSendRecvSumFunctor {
void operator()(const bool& first_flag,
const DenseTensor& src_slice,
DenseTensor* dst_slice) {
auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = phi::EigenVector<T>::Flatten(*dst_slice);
eigen_dst += eigen_src;
}
};
template <typename T>
struct GraphSendRecvMinFunctor {
void operator()(const bool& first_flag,
const DenseTensor& src_slice,
DenseTensor* dst_slice) {
auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = phi::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMin(eigen_src);
}
}
};
template <typename T>
struct GraphSendRecvMaxFunctor {
void operator()(const int& first_flag,
const DenseTensor& src_slice,
DenseTensor* dst_slice) {
auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = phi::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMax(eigen_src);
}
}
};
template <typename T, typename IndexT, typename Functor>
void ElementwiseInnerOperation(const DenseTensor& src,
DenseTensor* dst,
const IndexT& src_index,
const IndexT& dst_index,
const bool& first_flag,
Functor functor) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dst_slice = dst->Slice(dst_index, dst_index + 1);
functor(first_flag, src_slice, &dst_slice);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename IndexT, typename Functor>
void GraphSendRecvCpuGradLoop(const int& input_size,
const int& index_size,
const IndexT* s_index,
const IndexT* d_index,
const DenseTensor& src,
DenseTensor* dst,
const std::string& pool_type,
const int* dst_count = nullptr,
const DenseTensor* input = nullptr,
const DenseTensor* output = nullptr) {
if (pool_type == "SUM") {
Functor functor;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
auto src_slice = src.Slice(src_idx, src_idx + 1);
auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i];
auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
auto eigen_input = phi::EigenVector<T>::Flatten(input_slice);
auto eigen_output = phi::EigenVector<T>::Flatten(output_slice);
auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
eigen_dst += eigen_src * (eigen_output == eigen_input);
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr,
const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>();
const auto& src_dims = out_grad.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM") {
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
} else if (pool_type == "MEAN") {
const int* s_count = dst_count->data<int>();
// Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
index_size,
d_index,
s_index,
out_grad,
x_grad,
pool_type,
s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
// Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0],
index_size,
d_index,
s_index,
out_grad,
x_grad,
pool_type,
nullptr,
x,
out);
}
}
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
paddle::optional<const DenseTensor&> dst_count,
const std::string& pool_type,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr());
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendRecvGradKernel,
float,
double,
int,
int64_t) {}
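In the MEAN branch above, the forward index pair is swapped before the loop (d_index is passed as the source and s_index as the destination), so the net effect is that each input row receives out_grad[dst] / dst_count[dst] for every forward edge it participates in. Below is a plain-array sketch of that rule with slice_size == 1, independent of Paddle; the helper name MeanGradSketch is illustrative only.
#include <cstdint>
#include <vector>
// Backward of the MEAN pool: for every edge e (src -> dst in the forward
// pass), the gradient flowing back to x[src] is out_grad[dst] divided by
// the number of edges that landed on dst in the forward pass.
std::vector<float> MeanGradSketch(const std::vector<float>& out_grad,
                                  int64_t num_nodes,  // rows of x / x_grad
                                  const std::vector<int64_t>& src_index,
                                  const std::vector<int64_t>& dst_index,
                                  const std::vector<int>& dst_count) {
  std::vector<float> x_grad(num_nodes, 0.0f);
  for (size_t e = 0; e < src_index.size(); ++e) {
    const int64_t src = src_index[e];
    const int64_t dst = dst_index[e];
    x_grad[src] += out_grad[dst] / static_cast<float>(dst_count[dst]);
  }
  return x_grad;
}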
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include <algorithm>
#include <set>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename IndexT, typename Functor>
void GraphSendRecvCpuLoop(const int& input_size,
const int& index_size,
const IndexT* s_index,
const IndexT* d_index,
const DenseTensor& src,
DenseTensor* dst,
const std::string& pool_type,
int* dst_count = nullptr) {
Functor functor;
if (pool_type == "SUM") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
for (int i = 0; i < index_size; ++i) {
IndexT dst_idx = d_index[i];
*(dst_count + dst_idx) += 1;
}
for (int i = 0; i < input_size; ++i) {
if (*(dst_count + i) == 0) continue;
auto dst_slice = dst->Slice(i, i + 1);
auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
std::set<IndexT> existed_dst;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
if (!in_set) {
ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, true, functor);
existed_dst.emplace(dst_idx);
} else {
ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MIN") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MAX") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MEAN") {
ctx.template Alloc<int>(dst_count);
int* p_dst_count = dst_count->data<int>();
memset(p_dst_count, 0, src_dims[0] * sizeof(int));
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
index_size,
s_index,
d_index,
x,
out,
pool_type,
p_dst_count);
}
}
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv,
CPU,
ALL_LAYOUT,
phi::GraphSendRecvKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
namespace phi {
template <typename T, typename IndexT>
struct GraphSendRecvSumCUDAFunctor {
DEVICE inline void operator()(const T* params,
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMaxCUDAFunctor {
DEVICE inline void operator()(const T* params,
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMinCUDAFunctor {
DEVICE inline void operator()(const T* params,
T* output,
const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT, typename Functor>
__global__ void GraphSendRecvCUDAKernel(const T* params,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
size_t index_size,
size_t slice_size,
Functor functor) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
functor(params, output, in_i, out_i);
}
}
// For max
template <typename T>
__global__ void InputResetMaxCUDAKernel(T* output,
size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::min()) {
*(output + i) = 0;
}
}
}
// For min
template <typename T>
__global__ void InputResetMinCUDAKernel(T* output,
size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::max()) {
*(output + i) = 0;
}
}
}
// Get dst_count
template <typename T, typename IndexT>
__global__ void ComputeCountCUDAKernel(int32_t* count,
const IndexT* dst_indices,
size_t index_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
IndexT dst_i = dst_indices[i];
paddle::platform::CudaAtomicAdd(count + dst_i, 1);
}
}
// For forward mean
template <typename T>
__global__ void ManipulateMeanCUDAKernel(T* output,
int32_t* count,
size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
int64_t c_index = i / slice_size;
if (*(count + c_index) > 1) {
*(output + i) = *(output + i) / *(count + c_index);
}
}
}
// For backward mean
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernel(const T* params,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
size_t index_size,
size_t slice_size,
const int32_t* dst_count) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(output + out_i,
*(params + in_i) / dst_count[src_i]);
}
}
// For backward min and max
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernel(const T* params,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
size_t index_size,
size_t slice_size,
const T* ptr_input,
const T* ptr_output) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(
output + out_i,
*(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i)));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr,
const DenseTensor* x = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad);
T* p_output = x_grad->data<T>();
const auto& src_dims = out_grad.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = out_grad.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<
T,
IndexT,
GraphSendRecvSumCUDAFunctor<T,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, functor);
} else if (pool_type == "MEAN") {
const int32_t* s_count = dst_count->data<int32_t>();
ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") {
const T* ptr_input = x->data<T>();
const T* ptr_output = out->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src,
d_index,
s_index,
p_output,
index_size,
slice_size,
ptr_input,
ptr_output);
}
}
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
paddle::optional<const DenseTensor&> dst_count,
const std::string& pool_type,
DenseTensor* x_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
src_index,
dst_index,
pool_type,
x_grad,
dst_count.get_ptr(),
x.get_ptr(),
out.get_ptr());
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv_grad,
GPU,
ALL_LAYOUT,
phi::GraphSendRecvGradKernel,
float,
double,
int,
int64_t) {}
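ManipulateMinMaxGradCUDAKernel routes the incoming gradient only to the input elements that actually produced the forward MIN/MAX, by comparing the saved forward input x with the forward output out. The host-side sketch below states the same rule with slice_size == 1; it is a plain C++ illustration rather than the kernel, and the helper name MinMaxGradSketch is illustrative only.
#include <cstdint>
#include <vector>
// Backward of MIN/MAX pooling: gradient flows from out_grad[dst] to
// x_grad[src] only where the forward input x[src] equals the forward
// output out[dst], i.e. where that element was the one selected.
// If several inputs tie at the extremum, each of them receives the full
// gradient, matching the (x == out) mask used in the kernels above.
std::vector<float> MinMaxGradSketch(const std::vector<float>& out_grad,
                                    const std::vector<float>& x,
                                    const std::vector<float>& out,
                                    const std::vector<int64_t>& src_index,
                                    const std::vector<int64_t>& dst_index) {
  std::vector<float> x_grad(x.size(), 0.0f);
  for (size_t e = 0; e < src_index.size(); ++e) {
    const int64_t src = src_index[e];
    const int64_t dst = dst_index[e];
    if (x[src] == out[dst]) {
      x_grad[src] += out_grad[dst];
    }
  }
  return x_grad;
}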
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/graph_send_recv_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
} else if (pool_type == "MAX") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device,
p_output_ptr,
p_output_ptr + memset_size,
std::numeric_limits<T>::min());
} else if (pool_type == "MIN") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device,
p_output_ptr,
p_output_ptr + memset_size,
std::numeric_limits<T>::max());
}
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = x.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<
T,
IndexT,
GraphSendRecvSumCUDAFunctor<T,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
} else if (pool_type == "MAX") {
GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<
T,
IndexT,
GraphSendRecvMaxCUDAFunctor<T,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
InputResetMaxCUDAKernel<T><<<grid_max, block, 0, ctx.stream()>>>(
p_output, input_size, slice_size);
} else if (pool_type == "MIN") {
GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<
T,
IndexT,
GraphSendRecvMinCUDAFunctor<T,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
InputResetMinCUDAKernel<T><<<grid_min, block, 0, ctx.stream()>>>(
p_output, input_size, slice_size);
} else if (pool_type == "MEAN") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<
T,
IndexT,
GraphSendRecvSumCUDAFunctor<T,
IndexT>><<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);
ctx.template Alloc<int32_t>(dst_count);
int32_t* p_dst_count = dst_count->data<int32_t>();
#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
#else
cudaMemset(p_dst_count, 0, input_size * sizeof(int));
#endif
int64_t grid_count = (index_size + block - 1) / block;
ComputeCountCUDAKernel<T, IndexT><<<grid_count, block, 0, ctx.stream()>>>(
p_dst_count, d_index, index_size);
int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_mean =
grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx;
ManipulateMeanCUDAKernel<T><<<grid_mean, block, 0, ctx.stream()>>>(
p_output, p_dst_count, input_size, slice_size);
}
}
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out, dst_count);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_recv,
GPU,
ALL_LAYOUT,
phi::GraphSendRecvKernel,
float,
double,
int,
int64_t) {}
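For MAX (and symmetrically MIN), the GPU path above first fills the output with std::numeric_limits<T>::min(), scatters every edge with an atomic max, and then resets rows that no edge touched back to zero via InputResetMaxCUDAKernel. The serial sketch below reproduces that three-step structure in plain C++ so the sentinel-reset step is explicit; it is an illustration under the same sentinel convention as the kernels, not the kernel itself, and ScatterMaxWithResetSketch is an illustrative name.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>
// Three-step MAX pooling, mirroring the GPU path: (1) fill with the
// sentinel numeric_limits<float>::min() (for floating-point types this is
// the smallest positive normal value, as in the kernels above),
// (2) scatter-max every edge, (3) reset rows still equal to the sentinel
// (i.e. never written) to zero.
std::vector<float> ScatterMaxWithResetSketch(
    const std::vector<float>& x, int64_t num_nodes,
    const std::vector<int64_t>& src_index,
    const std::vector<int64_t>& dst_index) {
  const float sentinel = std::numeric_limits<float>::min();
  std::vector<float> out(num_nodes, sentinel);        // step 1
  for (size_t e = 0; e < src_index.size(); ++e) {     // step 2
    out[dst_index[e]] = std::max(out[dst_index[e]], x[src_index[e]]);
  }
  for (float& v : out) {                              // step 3
    if (v == sentinel) v = 0.0f;
  }
  return out;
}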
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> x,
paddle::optional<const DenseTensor&> out,
const DenseTensor& src_index,
const DenseTensor& dst_index,
paddle::optional<const DenseTensor&> dst_count,
const std::string& pool_type,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
DenseTensor* out,
DenseTensor* dst_count);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_recv_grad",
{GradVarName("Out"), "X", "Out", "Src_index", "Dst_index", "Dst_count"},
{"pool_type"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad,
phi::GraphSendRecvGradOpArgumentMapping);
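The hunk above only adds the gradient argument mapping; a forward mapping would follow the same pattern, pairing the fluid op's inputs, attributes, and outputs with the phi kernel arguments. The snippet below is a hypothetical sketch of such a forward mapping in the same style, not code taken from this commit; the function name and argument lists are assumptions for illustration.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Hypothetical sketch: maps the fluid graph_send_recv op onto the phi
// forward kernel arguments (inputs, attributes, outputs), in the same
// style as GraphSendRecvGradOpArgumentMapping above.
KernelSignature GraphSendRecvOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("graph_send_recv",
                         {"X", "Src_index", "Dst_index"},
                         {"pool_type"},
                         {"Out", "Dst_count"});
}
}  // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_send_recv,
                           phi::GraphSendRecvOpArgumentMapping);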