未验证 提交 109f8a8e 编写于 作者: S Siming Dai 提交者: GitHub

[cherry-pick] Add paddle.incubate.graph_send_recv API(#37205) (#37343)

* Add paddle.incubate.graph_send_recv API

* fix bug in CudaAtomicMin and CudaAtomicMax

* add empty line
上级 604b6fc0
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/graph_send_recv_op.h"
namespace paddle {
namespace operators {
class GraphSendRecvOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Src_index"), "Input", "Src_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasInput("Dst_index"), "Input", "Dst_index",
"GraphSendRecv");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GraphSendRecv");
auto src_index_dims = ctx->GetInputDim("Src_index");
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(), 1,
platform::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = ctx->GetInputDim("Dst_index");
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(), 1,
platform::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(
src_index_dims[0], dst_index_dims[0],
platform::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dims);
if (ctx->Attrs().Get<std::string>("pool_type") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("Dst_count"), "Output", "Dst_count",
"GraphSendRecv");
ctx->SetOutputDim("Dst_count", {dims[0]});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class GraphSendRecvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor with data type float32, float64, int32, int64.");
AddInput("Src_index", "The source index tensor.");
AddInput("Dst_index", "The destination index tensor.");
AddOutput("Out", "Output tensor of graph_send_recv op.");
AddOutput("Dst_count",
"Count tensor of Dst_index, mainly for MEAN pool_type.")
.AsIntermediate();
AddAttr<std::string>("pool_type",
"(string, default 'SUM')"
"Define different pool types to receive the result "
"tensors of Dst_index.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddComment(R"DOC(
Graph Learning Send_Recv combine operator.
$Out = Recv(Send(X, Src_index), Dst_index, pool_type)$
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce
intermediate memory consumption in the process of message passing.
Take `x` as the input tensor, we first use `src_index` to gather corresponding data,
and then use `dst_index` to update the corresponding position of output tensor in different
pooling types, like sum, mean, max, or min.
)DOC");
}
};
template <typename T>
class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("graph_send_recv_grad");
op->SetInput("Src_index", this->Input("Src_index"));
op->SetInput("Dst_index", this->Input("Dst_index"));
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") {
op->SetInput("Dst_count", this->Output("Dst_count"));
}
if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" ||
BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") {
op->SetInput("X", this->Input("X"));
op->SetInput("Out", this->Output("Out"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP,
ops::GraphSendRecvOpMaker,
ops::GraphSendRecvGradOpMaker<paddle::framework::OpDesc>,
ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp);
REGISTER_OP_CPU_KERNEL(graph_send_recv, ops::GraphSendRecvOpKernel<CPU, float>,
ops::GraphSendRecvOpKernel<CPU, double>,
ops::GraphSendRecvOpKernel<CPU, int>,
ops::GraphSendRecvOpKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(graph_send_recv_grad,
ops::GraphSendRecvGradOpKernel<CPU, float>,
ops::GraphSendRecvGradOpKernel<CPU, double>,
ops::GraphSendRecvGradOpKernel<CPU, int>,
ops::GraphSendRecvGradOpKernel<CPU, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/graph_send_recv_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename IndexT>
struct GraphSendRecvSumCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMaxCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT>
struct GraphSendRecvMinCUDAFunctor {
DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
const IndexT& out_i) {
paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i));
}
};
template <typename T, typename IndexT, typename Functor>
__global__ void GraphSendRecvCUDAKernel(const T* params,
const IndexT* src_indices,
const IndexT* dst_indices, T* output,
size_t index_size, size_t slice_size,
Functor functor) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
functor(params, output, in_i, out_i);
}
}
// For max
template <typename T>
__global__ void InputResetMaxCUDAKernel(T* output, size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::min()) {
*(output + i) = 0;
}
}
}
// For min
template <typename T>
__global__ void InputResetMinCUDAKernel(T* output, size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
if (*(output + i) == std::numeric_limits<T>::max()) {
*(output + i) = 0;
}
}
}
// Get dst_count
template <typename T, typename IndexT>
__global__ void ComputeCountCUDAKernel(int* count, const IndexT* dst_indices,
size_t index_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
IndexT dst_i = dst_indices[i];
paddle::platform::CudaAtomicAdd(count + dst_i, 1);
}
}
// For forward mean
template <typename T>
__global__ void ManipulateMeanCUDAKernel(T* output, int* count,
size_t input_size, size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
int64_t c_index = i / slice_size;
if (*(count + c_index) > 1) {
*(output + i) = *(output + i) / *(count + c_index);
}
}
}
// For backward mean
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernel(
const T* params, const IndexT* src_indices, const IndexT* dst_indices,
T* output, size_t index_size, size_t slice_size, const int* dst_count) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(output + out_i,
*(params + in_i) / dst_count[src_i]);
}
}
// For backward min and max
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernel(
const T* params, const IndexT* src_indices, const IndexT* dst_indices,
T* output, size_t index_size, size_t slice_size, const T* ptr_input,
const T* ptr_output) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size;
IndexT src_i = src_indices[indices_i];
IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(
output + out_i,
*(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i)));
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvOpCUDAKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index,
const Tensor& dst_index) {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Output<Tensor>("Out");
std::string pool_type = ctx.Attr<std::string>("pool_type");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
} else if (pool_type == "MAX") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
std::numeric_limits<T>::min());
} else if (pool_type == "MIN") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
std::numeric_limits<T>::max());
}
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = X->data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
const auto& dev_ctx = ctx.cuda_device_context();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
} else if (pool_type == "MAX") {
GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvMaxCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
InputResetMaxCUDAKernel<
T><<<grid_max, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, input_size, slice_size);
} else if (pool_type == "MIN") {
GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvMinCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
InputResetMinCUDAKernel<
T><<<grid_min, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, input_size, slice_size);
} else if (pool_type == "MEAN") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
auto* dst_count = ctx.Output<Tensor>("Dst_count");
int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
#else
cudaMemset(p_dst_count, 0, input_size * sizeof(int));
#endif
int64_t grid_count = (index_size + block - 1) / block;
ComputeCountCUDAKernel<
T, IndexT><<<grid_count, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_dst_count, d_index, index_size);
int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_mean =
grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx;
ManipulateMeanCUDAKernel<
T><<<grid_mean, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_output, p_dst_count, input_size, slice_size);
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index,
const Tensor& dst_index) {
auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* Y = ctx.Output<Tensor>(framework::GradVarName("X"));
std::string pool_type = ctx.Attr<std::string>("pool_type");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
const size_t& memset_bytes = memset_size * sizeof(T);
#ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes);
#else
cudaMemset(p_output, 0, memset_bytes);
#endif
if (index_size == 0) return;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
const T* p_src = X->data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int64_t n = slice_size * index_size;
const auto& dev_ctx = ctx.cuda_device_context();
int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0];
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT,
GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, functor);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Input<Tensor>("Dst_count");
const int* s_count = dst_count->data<int>();
ManipulateMeanGradCUDAKernel<T, IndexT><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Input<Tensor>("Out");
const T* ptr_input = input->data<T>();
const T* ptr_output = output->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(p_src, s_index, d_index, p_output,
index_size, slice_size, ptr_input,
ptr_output);
}
}
template <typename DeviceContext, typename T>
class GraphSendRecvOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Src_index");
auto* dst_index = ctx.Input<Tensor>("Dst_index");
auto index_type = src_index->type();
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<DeviceContext, T, int>(
ctx, *src_index, *dst_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index, *dst_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index dtype, expected int, int64, but "
"got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class GraphSendRecvGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Dst_index");
auto* dst_index = ctx.Input<Tensor>("Src_index");
auto index_type = src_index->type();
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<DeviceContext, T, int>(
ctx, *src_index, *dst_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvGradOpCUDAKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index, *dst_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index dtype, expected int, int64, but "
"got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(graph_send_recv,
ops::GraphSendRecvOpCUDAKernel<CUDA, float>,
ops::GraphSendRecvOpCUDAKernel<CUDA, double>,
ops::GraphSendRecvOpCUDAKernel<CUDA, int>,
ops::GraphSendRecvOpCUDAKernel<CUDA, int64_t>);
REGISTER_OP_CUDA_KERNEL(graph_send_recv_grad,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, float>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, double>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, int>,
ops::GraphSendRecvGradOpCUDAKernel<CUDA, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct GraphSendRecvSumFunctor {
void operator()(const bool& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
eigen_dst += eigen_src;
}
};
template <typename T>
struct GraphSendRecvMinFunctor {
void operator()(const bool& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMin(eigen_src);
}
}
};
template <typename T>
struct GraphSendRecvMaxFunctor {
void operator()(const int& first_flag, const Tensor& src_slice,
Tensor* dst_slice) {
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
if (first_flag) {
eigen_dst += eigen_src;
} else {
eigen_dst = eigen_dst.cwiseMax(eigen_src);
}
}
};
template <typename T, typename IndexT, typename Functor>
void elementwise_inner_operation(const Tensor& src, Tensor* dst,
const IndexT& src_index,
const IndexT& dst_index,
const bool& first_flag, Functor functor) {
auto src_slice = src.Slice(src_index, src_index + 1);
auto dst_slice = dst->Slice(dst_index, dst_index + 1);
functor(first_flag, src_slice, &dst_slice);
}
template <typename T, typename IndexT, typename Functor>
void graph_send_recv_cpu_for_loop(const int& input_size, const int& index_size,
const IndexT* s_index, const IndexT* d_index,
const Tensor& src, Tensor* dst,
const std::string& pool_type,
int* dst_count = nullptr) {
Functor functor;
if (pool_type == "SUM") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
for (int i = 0; i < index_size; ++i) {
IndexT dst_idx = d_index[i];
*(dst_count + dst_idx) += 1;
}
for (int i = 0; i < input_size; ++i) {
if (*(dst_count + i) == 0) continue;
auto dst_slice = dst->Slice(i, i + 1);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
std::set<IndexT> existed_dst;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
if (!in_set) {
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, true, functor);
existed_dst.emplace(dst_idx);
} else {
elementwise_inner_operation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor);
}
}
}
}
template <typename T, typename IndexT, typename Functor>
void graph_send_recv_cpu_for_loop_grad(
const int& input_size, const int& index_size, const IndexT* s_index,
const IndexT* d_index, const Tensor& src, Tensor* dst,
const std::string& pool_type, const int* dst_count = nullptr,
const Tensor* input = nullptr, const Tensor* output = nullptr) {
if (pool_type == "SUM") {
Functor functor;
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
dst_idx, false, functor);
}
} else if (pool_type == "MEAN") {
for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i];
auto src_slice = src.Slice(src_idx, src_idx + 1);
auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
}
} else if (pool_type == "MIN" || pool_type == "MAX") {
for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i];
auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
auto eigen_input = framework::EigenVector<T>::Flatten(input_slice);
auto eigen_output = framework::EigenVector<T>::Flatten(output_slice);
auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
eigen_dst += eigen_src * (eigen_output == eigen_input);
}
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvOpKernelLaunchHelper(const framework::ExecutionContext& ctx,
const Tensor& src_index) {
auto* X = ctx.Input<Tensor>("X");
auto* dst_index = ctx.Input<Tensor>("Dst_index");
auto* Y = ctx.Output<Tensor>("Out");
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index->data<IndexT>();
const std::string& pool_type = ctx.Attr<std::string>("pool_type");
if (pool_type == "SUM") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MIN") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MAX") {
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Output<Tensor>("Dst_count");
int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
memset(p_dst_count, 0, src_dims[0] * sizeof(int));
graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type,
p_dst_count);
}
}
template <typename DeviceContext, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper(
const framework::ExecutionContext& ctx, const Tensor& src_index) {
auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dst_index = ctx.Input<Tensor>("Src_index");
auto* Y = ctx.Output<Tensor>(framework::GradVarName("X"));
const int& index_size = src_index.dims()[0];
T* p_output = Y->mutable_data<T>(ctx.GetPlace());
const auto& src_dims = X->dims();
int64_t memset_size = 1;
for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);
if (index_size == 0) return;
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index->data<IndexT>();
const std::string& pool_type = ctx.Attr<std::string>("pool_type");
if (pool_type == "SUM") {
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
} else if (pool_type == "MEAN") {
auto* dst_count = ctx.Input<Tensor>("Dst_count");
const int* s_count = dst_count->data<int>();
// Functor not used here.
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
const auto* input = ctx.Input<Tensor>("X");
const auto* output = ctx.Input<Tensor>("Out");
// Functor not used here.
graph_send_recv_cpu_for_loop_grad<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, nullptr,
input, output);
}
}
template <typename DeviceContext, typename T>
class GraphSendRecvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Src_index");
auto index_type = src_index->type();
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvOpKernelLaunchHelper<DeviceContext, T, int>(ctx, *src_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvOpKernelLaunchHelper<DeviceContext, T, int64_t>(ctx,
*src_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index type, Expected int, int64, but "
"got %s.",
index_type));
}
}
};
template <typename DeviceContext, typename T>
class GraphSendRecvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* src_index = ctx.Input<Tensor>("Dst_index");
auto index_type = src_index->type();
if (index_type == framework::proto::VarType::INT32) {
GraphSendRecvGradOpKernelLaunchHelper<DeviceContext, T, int>(ctx,
*src_index);
} else if (index_type == framework::proto::VarType::INT64) {
GraphSendRecvGradOpKernelLaunchHelper<DeviceContext, T, int64_t>(
ctx, *src_index);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported Src_index or Dst_index type, Expected int, int64, but "
"got %s.",
index_type));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -178,9 +178,17 @@ CUDA_ATOMIC_WRAPPER(Max, int64_t) {
// Here, we check long long int must be int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicMax(
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
long long int res = *address; // NOLINT
while (val > res) {
long long int old = res; // NOLINT
res = (long long int)atomicCAS((unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val); // NOLINT
if (res == old) {
break;
}
}
return res;
}
CUDA_ATOMIC_WRAPPER(Max, float) {
......@@ -254,9 +262,17 @@ CUDA_ATOMIC_WRAPPER(Min, int64_t) {
// Here, we check long long int must be int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicMin(
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
long long int res = *address; // NOLINT
while (val < res) {
long long int old = res; // NOLINT
res = (long long int)atomicCAS((unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val); // NOLINT
if (res == old) {
break;
}
}
return res;
}
CUDA_ATOMIC_WRAPPER(Min, float) {
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest
class TestGraphSendRecvMaxOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "graph_send_recv"
x = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
src_index = index[:, 0]
dst_index = index[:, 1]
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MAX'}
out, self.gradient = compute_graph_send_recv_for_min_max(self.inputs,
self.attrs)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
class TestGraphSendRecvMinOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "graph_send_recv"
x = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
src_index = index[:, 0]
dst_index = index[:, 1]
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MIN'}
out, self.gradient = compute_graph_send_recv_for_min_max(self.inputs,
self.attrs)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
class TestGraphSendRecvSumOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "graph_send_recv"
x = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
src_index = index[:, 0]
dst_index = index[:, 1]
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'SUM'}
out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestGraphSendRecvMeanOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "graph_send_recv"
x = np.random.random((10, 20)).astype("float64")
index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
src_index = index[:, 0]
dst_index = index[:, 1]
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MEAN'}
out, dst_count = compute_graph_send_recv_for_sum_mean(self.inputs,
self.attrs)
self.outputs = {'Out': out, 'Dst_count': dst_count}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def compute_graph_send_recv_for_sum_mean(inputs, attributes):
x = inputs['X']
src_index = inputs['Src_index']
dst_index = inputs['Dst_index']
pool_type = attributes['pool_type']
gather_x = x[src_index]
target_shape = list(x.shape)
results = np.zeros(target_shape, dtype=x.dtype)
if pool_type == 'SUM':
for index, s_id in enumerate(dst_index):
results[s_id, :] += gather_x[index, :]
elif pool_type == 'MEAN':
count = np.zeros(target_shape[0], dtype=np.int32)
for index, s_id in enumerate(dst_index):
results[s_id, :] += gather_x[index, :]
count[s_id] += 1
results = results / count.reshape([-1, 1])
results[np.isnan(results)] = 0
else:
raise ValueError("Invalid pool_type, only SUM, MEAN supported!")
count = np.zeros(target_shape[0], dtype=np.int32)
for index, s_id in enumerate(dst_index):
count[s_id] += 1
return results, count
def compute_graph_send_recv_for_min_max(inputs, attributes):
x = inputs['X']
src_index = inputs['Src_index']
dst_index = inputs['Dst_index']
pool_type = attributes['pool_type']
gather_x = x[src_index]
target_shape = list(x.shape)
results = np.zeros(target_shape, dtype=x.dtype)
gradient = np.zeros_like(x)
# Calculate forward output
if pool_type == "MAX":
first_set = set()
for index, s_id in enumerate(dst_index):
if s_id not in first_set:
results[s_id, :] += gather_x[index, :]
first_set.add(s_id)
else:
results[s_id, :] = np.maximum(results[s_id, :],
gather_x[index, :])
elif pool_type == "MIN":
first_set = set()
for index, s_id in enumerate(dst_index):
if s_id not in first_set:
results[s_id, :] += gather_x[index, :]
first_set.add(s_id)
else:
results[s_id, :] = np.minimum(results[s_id, :],
gather_x[index, :])
else:
raise ValueError("Invalid pool_type, only MAX, MIN supported!")
# Calculate backward gradient
index_size = len(src_index)
for i in range(index_size):
forward_src_idx = src_index[i]
forward_dst_idx = dst_index[i]
gradient[forward_src_idx] += 1 * (
x[forward_src_idx] == results[forward_dst_idx])
return results, gradient / results.size
class API_GraphSendRecvOpTest(unittest.TestCase):
def test_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
src_index = paddle.static.data(name="src", shape=[4], dtype="int32")
dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32")
res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"sum")
res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"mean")
res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"max")
res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"min")
exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype='float32')
data2 = np.array([0, 1, 2, 0], dtype="int32")
data3 = np.array([1, 2, 1, 0], dtype="int32")
np_sum = np.array(
[[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32")
np_mean = np.array(
[[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32")
np_max = np.array(
[[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32")
np_min = np.array(
[[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32")
ret = exe.run(feed={'x': data1,
'src': data2,
'dst': data3},
fetch_list=[res_sum, res_mean, res_max, res_min])
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
def test_dygraph(self):
device = paddle.CPUPlace()
with paddle.fluid.dygraph.guard(device):
x = paddle.to_tensor(
np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), dtype="float32")
src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32")
dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32")
res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"sum")
res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"mean")
res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"max")
res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"min")
np_sum = np.array(
[[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32")
np_mean = np.array(
[[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32")
np_max = np.array(
[[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32")
np_min = np.array(
[[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32")
ret = [res_sum, res_mean, res_max, res_min]
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
def test_int32_input(self):
device = paddle.CPUPlace()
with paddle.fluid.dygraph.guard(device):
x = paddle.to_tensor(
np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="int32")
src_index = paddle.to_tensor(
np.array([0, 1, 2, 0, 1]), dtype="int32")
dst_index = paddle.to_tensor(
np.array([1, 2, 1, 0, 1]), dtype="int32")
res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"sum")
res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"mean")
res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"max")
res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index,
"min")
np_sum = np.array(
[[0, 2, 3], [3, 12, 14], [1, 4, 5]], dtype="int32")
np_mean = np.array([[0, 2, 3], [1, 4, 4], [1, 4, 5]], dtype="int32")
np_max = np.array([[0, 2, 3], [2, 6, 6], [1, 4, 5]], dtype="int32")
np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="int32")
ret = [res_sum, res_mean, res_max, res_min]
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(
np_res, ret_res, atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,7 @@ from .checkpoint import auto_checkpoint # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401
from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
from .operators import softmax_mask_fuse # noqa: F401
from .operators import graph_send_recv
from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
......@@ -30,6 +31,7 @@ __all__ = [
'ModelAverage',
'softmax_mask_fuse_upper_triangle',
'softmax_mask_fuse',
'graph_send_recv',
'segment_sum',
'segment_mean',
'segment_max',
......
......@@ -15,3 +15,4 @@
from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401
from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
from .resnet_unit import ResNetUnit #noqa: F401
from .graph_send_recv import graph_send_recv #noqa: F401
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
r"""
Graph Learning Send_Recv combine operator.
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor
in different pooling types, like sum, mean, max, or min.
.. code-block:: text
Given:
X = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]
pool_type = "sum"
Then:
Out = [[0, 2, 3],
[2, 8, 10],
[1, 4, 5]]
Args:
x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64.
src_index (Tensor): An 1-D tensor, and the available data type is int32, int64.
dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`.
The available data type is int32, int64.
pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`.
Default value is `sum`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]
"""
if pool_type not in ["sum", "mean", "max", "min"]:
raise ValueError(
"pool_type should be `sum`, `mean`, `max` or `min`, but received %s"
% pool_type)
if in_dygraph_mode():
out, tmp = core.ops.graph_send_recv(x, src_index, dst_index,
'pool_type', pool_type.upper())
return out
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"),
"graph_send_recv")
check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"),
"graph_send_recv")
check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"),
"graph_send_recv")
helper = LayerHelper("graph_send_recv", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
dst_count = helper.create_variable_for_type_inference(
dtype="int32", stop_gradient=True)
helper.append_op(
type="graph_send_recv",
inputs={"X": x,
"Src_index": src_index,
"Dst_index": dst_index},
outputs={"Out": out,
"Dst_count": dst_count},
attrs={"pool_type": pool_type.upper()})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册