未验证 提交 615b15a3 编写于 作者: S Siming Dai 提交者: GitHub

[geometric]Add paddle.geometric.send_ue_recv API (#43174)

* add init file

* add op definition and infermeta

* add kernel definition funcs

* add broadcast infer shape

* add gpu forward kernel

* delete SUB and DIV

* add x_grad

* add template

* add e_grad for min and max

* fix small bug

* temp commit

* temp commit

* add e_grad for sum and mean

* fix some compile bug

* fix compile bugs

* fix compile problem

* add sum forward unittest

* fix broadcast error, add kernel sig, register e_grad, change unit test

* fix grad

* add temp grad fix

* temp commit

* add min max unittest

* add max, min unittest, fix mul bug

* add cpu forward sum and mean

* add forward min max, fix mean unittest

* add cpu backward min max

* fix code-style

* add backward sum mean

* fix rocm ci

* set uniitest timeout

* fix bug of x broadcast to e, gpu grad

* fix bug of x broadcast to e, cpu grad

* rename BOOST_GET_CONST macro

* fix rocm ci

* mv graph_send_e_recv to graph_send_ue_recv

* move out_size to IntArray

* add eager op test

* fix max pool type bug, add unittest for api

* revise api doc

* add fp16 for atomic min and max, add unittest

* add unittest

* add fp16 support for graph_send_recv

* fix unittest fp16 bug

* change OutSizeTensor to Out_size

* move E to Y

* add copyright, fix comment

* review code

* fix thread block size

* fix thread block size

* change api attribute name: pool_type to reduce_op, compute_type to message_op

* change api attribute name, move pool_type to reduce_op, move compute_type to message_op
上级 9b35f035
...@@ -64,9 +64,9 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,9 +64,9 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable(); .AsDispensable();
AddOutput("Out", "Output tensor of graph_send_recv op."); AddOutput("Out", "Output tensor of graph_send_recv op.");
AddOutput("Dst_count", AddOutput("Dst_count",
"Count tensor of Dst_index, mainly for MEAN pool_type.") "Count tensor of Dst_index, mainly for MEAN reduce_op.")
.AsIntermediate(); .AsIntermediate();
AddAttr<std::string>("pool_type", AddAttr<std::string>("reduce_op",
"(string, default 'SUM')" "(string, default 'SUM')"
"Define different pool types to receive the result " "Define different pool types to receive the result "
"tensors of Dst_index.") "tensors of Dst_index.")
...@@ -81,7 +81,7 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,7 +81,7 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Graph Learning Send_Recv combine operator. Graph Learning Send_Recv combine operator.
$Out = Recv(Send(X, Src_index), Dst_index, pool_type)$ $Out = Recv(Send(X, Src_index), Dst_index, reduce_op)$
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce This operator is mainly used in Graph Learning domain, and the main purpose is to reduce
intermediate memory consumption in the process of message passing. intermediate memory consumption in the process of message passing.
...@@ -105,12 +105,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -105,12 +105,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Dst_index", this->Input("Dst_index")); op->SetInput("Dst_index", this->Input("Dst_index"));
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") {
op->SetInput("Dst_count", this->Output("Dst_count")); op->SetInput("Dst_count", this->Output("Dst_count"));
} }
if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" ||
PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") {
op->SetInput("Out", this->Output("Out")); op->SetInput("Out", this->Output("Out"));
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class GraphSendUERecvOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class GraphSendUERecvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
auto y_dims = ctx->GetInputDim("Y");
ctx->SetOutputDim(framework::GradVarName("Y"), y_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class GraphSendUERecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor with data type float32, float64, int32, int64.");
AddInput("Y",
"The input edge weight tensor, data type should be same with X");
AddInput("Src_index", "The source index tensor.");
AddInput("Dst_index", "The destination index tensor.");
AddInput("Out_size",
"(Tensor<int>, optional). The 0th dimension of the output."
"It has a higher priority than Attr(out_size).")
.AsDispensable();
AddOutput("Out", "Output tensor of graph_send_ue_recv op.");
AddOutput("Dst_count",
"Count tensor of Dst_index, mainly for MEAN reduce_op.")
.AsIntermediate();
AddAttr<std::string>("message_op",
"(string, default 'ADD')"
"Define differenct computation types between X and E.")
.SetDefault("ADD")
.InEnum({"ADD", "MUL"});
AddAttr<std::string>("reduce_op",
"(string, default 'SUM')"
"Define different pool types to receive the result "
"tensors of Dst_index.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddAttr<std::vector<int64_t>>(
"out_size",
"(vector<int64_t>, default {0})"
"Define the first dimension of Output tensor."
"If set default {0}, then the shape of Out is the same with X.")
.SetDefault({0});
AddComment(R"DOC(
Graph Learning Send_UE_Recv combine operator.
$Out = Recv(Compute(Send(X, Src_index), Y, message_op), Dst_index, reduce_op)$
This operator is mainly used in Graph Learning domain, and the main purpose is to reduce
intermediate memory consumption in the process of message passing.
Take `X` as the input tensor, we first use `src_index` to gather corresponding data.
Then the gather data should compute with `Y` in different message_ops, like add, sub, mul, and div,
and get the computation result. Then, use `dst_index` to update the corresponding position of output
tensor in different pooling types, like sum, mean, max, or min.
)DOC");
}
};
template <typename T>
class GraphSendUERecvGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("graph_send_ue_recv_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("Src_index", this->Input("Src_index"));
op->SetInput("Dst_index", this->Input("Dst_index"));
if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") {
op->SetInput("Dst_count", this->Output("Dst_count"));
}
if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" ||
PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") {
op->SetInput("Out", this->Output("Out"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_send_ue_recv,
GraphSendUERecvInferShapeFunctor,
PD_INFER_META(phi::GraphSendUERecvInferMeta));
REGISTER_OPERATOR(graph_send_ue_recv,
ops::GraphSendUERecvOP,
ops::GraphSendUERecvOpMaker,
ops::GraphSendUERecvGradOpMaker<paddle::framework::OpDesc>,
ops::GraphSendUERecvGradOpMaker<paddle::imperative::OpBase>,
GraphSendUERecvInferShapeFunctor);
REGISTER_OPERATOR(graph_send_ue_recv_grad, ops::GraphSendUERecvGradOp);
...@@ -419,6 +419,55 @@ CUDA_ATOMIC_WRAPPER(Max, double) { ...@@ -419,6 +419,55 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
return __longlong_as_double(old); return __longlong_as_double(old);
} }
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) {
float16 low_half;
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(max(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) {
float16 high_half;
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(max(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Max, float16) {
if (*address >= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
// For atomicMin // For atomicMin
USE_CUDA_ATOMIC(Min, int); USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int); USE_CUDA_ATOMIC(Min, unsigned int);
...@@ -503,5 +552,54 @@ CUDA_ATOMIC_WRAPPER(Min, double) { ...@@ -503,5 +552,54 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
return __longlong_as_double(old); return __longlong_as_double(old);
} }
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) {
float16 low_half;
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(min(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) {
float16 high_half;
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(min(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Min, float16) {
if (*address <= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -226,6 +226,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -226,6 +226,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"Mean3", "Mean3",
"Var3"}}, "Var3"}},
{"graph_send_recv", {"X", "Src_index", "Dst_index", "Out_size"}}, {"graph_send_recv", {"X", "Src_index", "Dst_index", "Out_size"}},
{"graph_send_ue_recv", {"X", "Y", "Src_index", "Dst_index", "Out_size"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
......
...@@ -1082,7 +1082,7 @@ ...@@ -1082,7 +1082,7 @@
func : generate_proposals_v2 func : generate_proposals_v2
- api : graph_send_recv - api : graph_send_recv
args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0})
output : Tensor(out), Tensor(dst_count) output : Tensor(out), Tensor(dst_count)
infer_meta : infer_meta :
func : GraphSendRecvInferMeta func : GraphSendRecvInferMeta
...@@ -1092,6 +1092,17 @@ ...@@ -1092,6 +1092,17 @@
intermediate : dst_count intermediate : dst_count
backward : graph_send_recv_grad backward : graph_send_recv_grad
- api : graph_send_ue_recv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size)
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : GraphSendUERecvInferMeta
kernel :
func : graph_send_ue_recv
data_type : x
intermediate : dst_count
backward : graph_send_ue_recv_grad
- api : greater_equal - api : greater_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor
......
...@@ -941,8 +941,8 @@ ...@@ -941,8 +941,8 @@
func : gelu_grad func : gelu_grad
- backward_api : graph_send_recv_grad - backward_api : graph_send_recv_grad
forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM") args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM")
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
func : GeneralUnaryGradInferMeta func : GeneralUnaryGradInferMeta
...@@ -952,6 +952,18 @@ ...@@ -952,6 +952,18 @@
data_type : out_grad data_type : out_grad
optional: out, dst_count optional: out, dst_count
- backward_api : graph_send_ue_recv_grad
forward : graph_send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str message_op, str reduce_op)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_ue_recv_grad
data_type : out_grad
optional: out, dst_count
# grid sample # grid sample
- backward_api : grid_sample_grad - backward_api : grid_sample_grad
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out) forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi { namespace phi {
...@@ -2598,6 +2599,94 @@ void Yolov3LossInferMeta(const MetaTensor& x, ...@@ -2598,6 +2599,94 @@ void Yolov3LossInferMeta(const MetaTensor& x,
gt_match_mask->set_dtype(x.dtype()); gt_match_mask->set_dtype(x.dtype());
} }
void GraphSendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}
auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}
PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
y_dims[0],
src_index_dims[0],
phi::errors::InvalidArgument(
"Expect Input Y to have size %d as Src_index on the first dimension, "
"but we get %d",
src_index_dims[0],
y_dims[0]));
auto x_dims = x.dims();
if (reduce_op == "MEAN") {
dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32);
}
// Infer out's shape according to x and e(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), -1);
out->set_dims(phi::make_ddim(out_dims_array));
}
} // namespace phi } // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta); PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
namespace phi { namespace phi {
...@@ -465,4 +466,14 @@ void Yolov3LossInferMeta(const MetaTensor& x, ...@@ -465,4 +466,14 @@ void Yolov3LossInferMeta(const MetaTensor& x,
MetaTensor* objectness_mask, MetaTensor* objectness_mask,
MetaTensor* gt_match_mask); MetaTensor* gt_match_mask);
void GraphSendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);
} // namespace phi } // namespace phi
...@@ -411,7 +411,7 @@ void InstanceNormInferMeta(const MetaTensor& x, ...@@ -411,7 +411,7 @@ void InstanceNormInferMeta(const MetaTensor& x,
void GraphSendRecvInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index, const MetaTensor& src_index,
const MetaTensor& dst_index, const MetaTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
const IntArray& out_size, const IntArray& out_size,
MetaTensor* out, MetaTensor* out,
MetaTensor* dst_count) { MetaTensor* dst_count) {
...@@ -460,7 +460,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, ...@@ -460,7 +460,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(dims_)); out->set_dims(phi::make_ddim(dims_));
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
if (pool_type == "MEAN") { if (reduce_op == "MEAN") {
dst_count->set_dims({-1}); dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32); dst_count->set_dtype(DataType::INT32);
} }
......
...@@ -75,7 +75,7 @@ void InstanceNormInferMeta(const MetaTensor& x, ...@@ -75,7 +75,7 @@ void InstanceNormInferMeta(const MetaTensor& x,
void GraphSendRecvInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index, const MetaTensor& src_index,
const MetaTensor& dst_index, const MetaTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
const IntArray& out_size, const IntArray& out_size,
MetaTensor* out, MetaTensor* out,
MetaTensor* dst_count); MetaTensor* dst_count);
......
...@@ -29,10 +29,10 @@ void GraphSendRecvCpuGradLoop(const int& index_size, ...@@ -29,10 +29,10 @@ void GraphSendRecvCpuGradLoop(const int& index_size,
const DenseTensor& src, const DenseTensor& src,
const DenseTensor& input, const DenseTensor& input,
DenseTensor* dst, DenseTensor* dst,
const std::string& pool_type, const std::string& reduce_op,
const int* dst_count = nullptr, const int* dst_count = nullptr,
const DenseTensor* output = nullptr) { const DenseTensor* output = nullptr) {
if (pool_type == "SUM") { if (reduce_op == "SUM") {
Functor functor; Functor functor;
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i]; const IndexT& src_idx = s_index[i];
...@@ -40,7 +40,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, ...@@ -40,7 +40,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size,
ElementwiseInnerOperation<T, IndexT, Functor>( ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor); src, dst, src_idx, dst_idx, false, functor);
} }
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i]; const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i]; const IndexT& dst_idx = d_index[i];
...@@ -50,7 +50,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, ...@@ -50,7 +50,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size,
auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice); auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx])); eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
} }
} else if (pool_type == "MIN" || pool_type == "MAX") { } else if (reduce_op == "MIN" || reduce_op == "MAX") {
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_src_idx = d_index[i];
const IndexT& forward_dst_idx = s_index[i]; const IndexT& forward_dst_idx = s_index[i];
...@@ -75,7 +75,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( ...@@ -75,7 +75,7 @@ void GraphSendRecvGradOpKernelLaunchHelper(
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
DenseTensor* x_grad, DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr, const DenseTensor* dst_count = nullptr,
const DenseTensor* out = nullptr) { const DenseTensor* out = nullptr) {
...@@ -94,15 +94,15 @@ void GraphSendRecvGradOpKernelLaunchHelper( ...@@ -94,15 +94,15 @@ void GraphSendRecvGradOpKernelLaunchHelper(
const IndexT* s_index = src_index.data<IndexT>(); const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>(); const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM") { if (reduce_op == "SUM") {
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>( GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
index_size, d_index, s_index, out_grad, x, x_grad, pool_type); index_size, d_index, s_index, out_grad, x, x_grad, reduce_op);
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
const int* s_count = dst_count->data<int>(); const int* s_count = dst_count->data<int>();
// Functor not used here. // Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>( GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count); index_size, d_index, s_index, out_grad, x, x_grad, reduce_op, s_count);
} else if (pool_type == "MIN" || pool_type == "MAX") { } else if (reduce_op == "MIN" || reduce_op == "MAX") {
// Functor not used here. // Functor not used here.
GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size, GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size,
d_index, d_index,
...@@ -110,7 +110,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( ...@@ -110,7 +110,7 @@ void GraphSendRecvGradOpKernelLaunchHelper(
out_grad, out_grad,
x, x,
x_grad, x_grad,
pool_type, reduce_op,
nullptr, nullptr,
out); out);
} }
...@@ -124,7 +124,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -124,7 +124,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
const paddle::optional<DenseTensor>& out, const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count, const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::string& pool_type, const std::string& reduce_op,
DenseTensor* x_grad) { DenseTensor* x_grad) {
auto index_type = src_index.dtype(); auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
...@@ -134,7 +134,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -134,7 +134,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
out.get_ptr()); out.get_ptr());
...@@ -145,7 +145,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -145,7 +145,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
out.get_ptr()); out.get_ptr());
......
...@@ -32,17 +32,17 @@ void GraphSendRecvCpuLoop(const int& input_size, ...@@ -32,17 +32,17 @@ void GraphSendRecvCpuLoop(const int& input_size,
const IndexT* d_index, const IndexT* d_index,
const DenseTensor& src, const DenseTensor& src,
DenseTensor* dst, DenseTensor* dst,
const std::string& pool_type, const std::string& reduce_op,
int* dst_count = nullptr) { int* dst_count = nullptr) {
Functor functor; Functor functor;
if (pool_type == "SUM") { if (reduce_op == "SUM") {
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i]; const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i]; const IndexT& dst_idx = d_index[i];
ElementwiseInnerOperation<T, IndexT, Functor>( ElementwiseInnerOperation<T, IndexT, Functor>(
src, dst, src_idx, dst_idx, false, functor); src, dst, src_idx, dst_idx, false, functor);
} }
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i]; const IndexT& src_idx = s_index[i];
const IndexT& dst_idx = d_index[i]; const IndexT& dst_idx = d_index[i];
...@@ -59,7 +59,7 @@ void GraphSendRecvCpuLoop(const int& input_size, ...@@ -59,7 +59,7 @@ void GraphSendRecvCpuLoop(const int& input_size,
auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice); auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i)); eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i));
} }
} else if (pool_type == "MIN" || pool_type == "MAX") { } else if (reduce_op == "MIN" || reduce_op == "MAX") {
std::set<IndexT> existed_dst; std::set<IndexT> existed_dst;
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& src_idx = s_index[i]; const IndexT& src_idx = s_index[i];
...@@ -82,7 +82,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, ...@@ -82,7 +82,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
int64_t out_size, int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count = nullptr) { DenseTensor* dst_count = nullptr) {
...@@ -117,16 +117,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, ...@@ -117,16 +117,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
const IndexT* s_index = src_index.data<IndexT>(); const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>(); const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM") { if (reduce_op == "SUM") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>( GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type); src_dims[0], index_size, s_index, d_index, x, out, reduce_op);
} else if (pool_type == "MIN") { } else if (reduce_op == "MIN") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMinFunctor<T>>( GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type); src_dims[0], index_size, s_index, d_index, x, out, reduce_op);
} else if (pool_type == "MAX") { } else if (reduce_op == "MAX") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>( GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type); src_dims[0], index_size, s_index, d_index, x, out, reduce_op);
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
dst_count->Resize({input_size}); dst_count->Resize({input_size});
ctx.template Alloc<int>(dst_count); ctx.template Alloc<int>(dst_count);
...@@ -138,7 +138,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, ...@@ -138,7 +138,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
d_index, d_index,
x, x,
out, out,
pool_type, reduce_op,
p_dst_count); p_dst_count);
} }
} }
...@@ -148,7 +148,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -148,7 +148,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
const IntArray& out_size, const IntArray& out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count) { DenseTensor* dst_count) {
...@@ -159,7 +159,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -159,7 +159,7 @@ void GraphSendRecvKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
out_size_data[0], out_size_data[0],
out, out,
dst_count); dst_count);
...@@ -168,7 +168,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -168,7 +168,7 @@ void GraphSendRecvKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
out_size_data[0], out_size_data[0],
out, out,
dst_count); dst_count);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
struct GraphAddFunctor {
inline T operator()(const T a, const T b) const { return a + b; }
};
template <typename T>
struct GraphMulFunctor {
inline T operator()(const T a, const T b) const { return a * b; }
};
template <typename T>
struct GraphMaxFunctor {
inline T operator()(const T a, const T b) const { return a < b ? b : a; }
};
template <typename T>
struct GraphMinFunctor {
inline T operator()(const T a, const T b) const { return a < b ? a : b; }
};
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void CalculateXGrad(const Context& ctx,
const T* out_grad,
const T* x_data,
const T* e_data,
const phi::DDim& out_grad_dims,
const phi::DDim& x_dims,
const phi::DDim& e_dims,
const IndexT* s_index,
const IndexT* d_index,
const std::string& message_op,
const std::string& reduce_op,
int64_t index_size,
T* x_grad,
const DenseTensor& out_grad_tensor,
DenseTensor* x_grad_tensor,
const DenseTensor* dst_count = nullptr,
const DenseTensor* out = nullptr) {
std::vector<int64_t> reduce_idx;
bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx);
if (reduce_op == "SUM") {
if (message_op == "ADD") {
GraphSendRecvSumFunctor<T> sum_functor;
if (!reduce) {
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
ElementwiseInnerOperation<T, IndexT, GraphSendRecvSumFunctor<T>>(
out_grad_tensor, x_grad_tensor, src, dst, false, sum_functor);
}
} else {
DenseTensor x_grad_v2 =
phi::EmptyLike<T, Context>(ctx, out_grad_tensor);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
ElementwiseInnerOperation<T, IndexT, GraphSendRecvSumFunctor<T>>(
out_grad_tensor, &x_grad_v2, src, dst, false, sum_functor);
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
} else if (message_op == "MUL") {
const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims);
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad + dst * bcast.out_len;
const T* out_grad_off = out_grad + src * bcast.l_len;
const T* e_off = e_data + i * bcast.r_len;
for (int j = 0; j < bcast.out_len; j++) {
int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = out_grad_off[o_add] * e_off[e_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
} else {
DenseTensor x_grad_v2 =
phi::EmptyLike<T, Context>(ctx, out_grad_tensor);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad_v2_data + dst * bcast.out_len;
const T* out_grad_off = out_grad + src * bcast.l_len;
const T* e_off = e_data + i * bcast.r_len;
for (int j = 0; j < bcast.out_len; j++) {
int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = out_grad_off[o_add] * e_off[e_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
}
} else if (reduce_op == "MEAN") {
const int* s_count = dst_count->data<int>();
if (message_op == "ADD") {
if (!reduce) {
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
auto out_grad_slice = out_grad_tensor.Slice(src, src + 1);
auto x_grad_slice = x_grad_tensor->Slice(dst, dst + 1);
auto eigen_out_grad = phi::EigenVector<T>::Flatten(out_grad_slice);
auto eigen_x_grad = phi::EigenVector<T>::Flatten(x_grad_slice);
eigen_x_grad += (eigen_out_grad / static_cast<T>(s_count[src]));
}
} else {
DenseTensor x_grad_v2 =
phi::EmptyLike<T, Context>(ctx, out_grad_tensor);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
auto out_grad_slice = out_grad_tensor.Slice(src, src + 1);
auto x_grad_slice = x_grad_v2.Slice(dst, dst + 1);
auto eigen_out_grad = phi::EigenVector<T>::Flatten(out_grad_slice);
auto eigen_x_grad = phi::EigenVector<T>::Flatten(x_grad_slice);
eigen_x_grad += (eigen_out_grad / static_cast<T>(s_count[src]));
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
} else if (message_op == "MUL") {
const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims);
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
const T* out_grad_off = out_grad + src * bcast.l_len;
const T* e_off = e_data + i * bcast.r_len;
T* x_grad_off = x_grad + dst * bcast.out_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = out_grad_off[o_add] * e_off[e_add];
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += (val / s_count[src]);
}
}
} else {
DenseTensor x_grad_v2 =
phi::EmptyLike<T, Context>(ctx, out_grad_tensor);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
const T* out_grad_off = out_grad + src * bcast.l_len;
const T* e_off = e_data + i * bcast.r_len;
T* x_grad_off = x_grad_v2_data + dst * bcast.out_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = out_grad_off[o_add] * e_off[e_add];
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += (val / s_count[src]);
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
}
}
}
template <typename T, typename IndexT>
void CalculateEGrad(const T* out_grad_data,
const T* x_data,
const T* e_data,
const phi::DDim& x_dims,
const phi::DDim& e_dims,
const IndexT* s_index,
const IndexT* d_index,
const std::string& message_op,
const std::string& reduce_op,
int64_t index_size,
T* e_grad,
const DenseTensor* dst_count = nullptr) {
const auto& bcast = phi::CalcBCastInfo(x_dims, e_dims);
if (reduce_op == "SUM") {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
const T* x_off = x_data + src * bcast.l_len;
const T* out_grad_off = out_grad_data + dst * bcast.out_len;
T* e_grad_off = e_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
if (message_op == "ADD") {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
e_grad_off[e_add] += out_grad_off[j];
} else if (message_op == "MUL") {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
e_grad_off[e_add] += (out_grad_off[j] * x_off[x_add]);
}
}
}
} else if (reduce_op == "MEAN") {
const int* s_count = dst_count->data<int>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
const T* x_off = x_data + src * bcast.l_len;
const T* out_grad_off = out_grad_data + dst * bcast.out_len;
T* e_grad_off = e_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
if (message_op == "ADD") {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
e_grad_off[e_add] += (out_grad_off[j] / s_count[dst]);
} else if (message_op == "MUL") {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
e_grad_off[e_add] += (out_grad_off[j] * x_off[x_add] / s_count[dst]);
}
}
}
}
}
template <typename T, typename IndexT>
void CalculateXEGradForMinMax(const T* out_grad,
const T* x_data,
const T* e_data,
const phi::DDim& x_dims,
const phi::DDim& e_dims,
const IndexT* s_index,
const IndexT* d_index,
const std::string& message_op,
const std::string& reduce_op,
int64_t index_size,
T* x_grad,
T* e_grad,
const DenseTensor* out = nullptr) {
const T* out_data = out->data<T>();
const auto& bcast = phi::CalcBCastInfo(x_dims, e_dims);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
const T* x_off = x_data + dst * bcast.l_len;
const T* e_off = e_data + i * bcast.r_len;
const T* out_off = out_data + src * bcast.out_len;
const T* out_grad_off = out_grad + src * bcast.out_len;
T* x_grad_off = x_grad + dst * bcast.l_len;
T* e_grad_off = e_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
if (message_op == "ADD") {
T val = x_off[x_add] + e_off[e_add];
#ifdef PADDLE_WITH_MKLML
#pragma omp critical
#endif
x_grad_off[x_add] += (out_grad_off[j] * (val == out_off[j]));
e_grad_off[e_add] += (out_grad_off[j] * (val == out_off[j]));
} else if (message_op == "MUL") {
T val = x_off[x_add] * e_off[e_add];
#ifdef PADDLE_WITH_MKLML
#pragma omp critical
#endif
x_grad_off[x_add] +=
(out_grad_off[j] * (val == out_off[j]) * e_off[e_add]);
e_grad_off[e_add] +=
(out_grad_off[j] * (val == out_off[j]) * x_off[x_add]);
}
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvGradOpKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad,
const DenseTensor* dst_count = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];
ctx.template Alloc<T>(x_grad);
T* x_grad_data = x_grad->data<T>();
ctx.template Alloc<T>(y_grad);
T* y_grad_data = y_grad->data<T>();
const auto& x_dims = x.dims();
const auto& y_dims = y.dims();
int64_t memset_size_x = 1, memset_size_y = 1;
int64_t slice_size = 1;
for (int i = 0; i < x_dims.size(); i++) {
memset_size_x *= x_dims[i];
if (i > 0) slice_size *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); i++) {
memset_size_y *= y_dims[i];
}
const size_t& memset_bytes_x = memset_size_x * sizeof(T);
const size_t& memset_bytes_y = memset_size_y * sizeof(T);
memset(x_grad_data, 0, memset_bytes_x);
memset(y_grad_data, 0, memset_bytes_y);
if (index_size == 0) return;
const T* out_grad_data = out_grad.data<T>();
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (reduce_op == "SUM" || reduce_op == "MEAN") {
CalculateXGrad<Context, T, IndexT>(ctx,
out_grad_data,
x_data,
y_data,
out_grad.dims(),
x_dims,
y_dims,
d_index,
s_index,
message_op,
reduce_op,
index_size,
x_grad_data,
out_grad,
x_grad,
dst_count,
out);
CalculateEGrad<T, IndexT>(out_grad_data,
x_data,
y_data,
x_dims,
y_dims,
s_index,
d_index,
message_op,
reduce_op,
index_size,
y_grad_data,
dst_count);
} else if (reduce_op == "MIN" || reduce_op == "MAX") {
CalculateXEGradForMinMax<T, IndexT>(out_grad_data,
x_data,
y_data,
x_dims,
y_dims,
d_index,
s_index,
message_op,
reduce_op,
index_size,
x_grad_data,
y_grad_data,
out);
}
}
template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
x_grad,
y_grad,
dst_count.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
x_grad,
y_grad,
dst_count.get_ptr(),
out.get_ptr());
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendUERecvGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
#include <algorithm>
#include <set>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
namespace phi {
template <typename T, typename IndexT, typename ComputeFunctor>
void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast,
const T* x_data,
const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
int64_t index_size,
ComputeFunctor cfunctor) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = src_indices[i];
IndexT dst = dst_indices[i];
T* out_off = output + dst * bcast.out_len;
const T* x_off = x_data + src * bcast.l_len;
const T* y_off = y_data + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = cfunctor(x_off[x_add], y_off[y_add]);
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
out_off[j] += val;
}
}
}
}
template <typename T,
typename IndexT,
typename ComputeFunctor,
typename CmpFunctor>
void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast,
const T* x_data,
const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
int64_t index_size,
ComputeFunctor cfunctor,
CmpFunctor pfunctor) {
std::set<IndexT> existed_dst;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = src_indices[i];
IndexT dst = dst_indices[i];
T* out_off = output + dst * bcast.out_len;
const T* x_off = x_data + src * bcast.l_len;
const T* y_off = y_data + i * bcast.r_len;
bool in_set = existed_dst.find(dst) != existed_dst.end();
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = cfunctor(x_off[x_add], y_off[y_add]);
#ifdef PADDLE_WITH_MKLML
#pragma omp critical
#endif
if (!in_set) {
out_off[j] = val;
} else {
out_off[j] = pfunctor(out_off[j], val);
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp critical
#endif
if (!in_set) {
existed_dst.emplace(dst);
}
}
}
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
auto out_dims = out->dims();
int64_t memset_size = 1;
std::vector<int64_t> dims_ = phi::vectorize(out_dims);
if (out_size <= 0) {
dims_[0] = x.dims()[0];
} else {
dims_[0] = out_size;
}
out->Resize(phi::make_ddim(dims_));
for (size_t i = 0; i < dims_.size(); i++) {
memset_size *= dims_[i];
}
ctx.template Alloc<T>(out);
T* out_data = out->data<T>();
const size_t& memset_bytes = memset_size * sizeof(T);
memset(out_data, 0, memset_bytes);
if (index_size == 0) return;
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (reduce_op == "SUM" || reduce_op == "MEAN") {
if (message_op == "ADD") {
GraphAddFunctor<T> add_functor;
GraphSendUERecvSumCpuKernel<T, IndexT, GraphAddFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
add_functor);
} else if (message_op == "MUL") {
GraphMulFunctor<T> mul_functor;
GraphSendUERecvSumCpuKernel<T, IndexT, GraphMulFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
mul_functor);
}
if (reduce_op == "MEAN") {
int64_t input_size = out_size <= 0 ? x.dims()[0] : out_size;
dst_count->Resize({input_size});
int* dst_count_data = ctx.template Alloc<int>(dst_count);
memset(dst_count_data, 0, input_size * sizeof(int));
for (int i = 0; i < index_size; i++) {
IndexT dst_idx = d_index[i];
dst_count_data[dst_idx] += 1;
}
for (int i = 0; i < input_size; i++) {
if (dst_count_data[i] == 0) continue;
auto out_slice = out->Slice(i, i + 1);
auto eigen_out = phi::EigenVector<T>::Flatten(out_slice);
eigen_out = eigen_out / static_cast<T>(dst_count_data[i]);
}
}
} else if (reduce_op == "MIN") {
GraphMinFunctor<T> min_functor;
if (message_op == "ADD") {
GraphAddFunctor<T> add_functor;
GraphSendUERecvMinMaxCpuKernel<T,
IndexT,
GraphAddFunctor<T>,
GraphMinFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
add_functor,
min_functor);
} else if (message_op == "MUL") {
GraphMulFunctor<T> mul_functor;
GraphSendUERecvMinMaxCpuKernel<T,
IndexT,
GraphMulFunctor<T>,
GraphMinFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
mul_functor,
min_functor);
}
} else if (reduce_op == "MAX") {
GraphMaxFunctor<T> max_functor;
if (message_op == "ADD") {
GraphAddFunctor<T> add_functor;
GraphSendUERecvMinMaxCpuKernel<T,
IndexT,
GraphAddFunctor<T>,
GraphMaxFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
add_functor,
max_functor);
} else if (message_op == "MUL") {
GraphMulFunctor<T> mul_functor;
GraphSendUERecvMinMaxCpuKernel<T,
IndexT,
GraphMulFunctor<T>,
GraphMaxFunctor<T>>(bcast_info,
x_data,
y_data,
s_index,
d_index,
out_data,
index_size,
mul_functor,
max_functor);
}
}
}
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvOpKernelLaunchHelper<Context, T, int32_t>(ctx,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
out_size_data[0],
out,
dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendUERecvOpKernelLaunchHelper<Context, T, int64_t>(ctx,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
out_size_data[0],
out,
dst_count);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv,
CPU,
ALL_LAYOUT,
phi::GraphSendUERecvKernel,
float,
double,
int,
int64_t) {}
...@@ -119,7 +119,7 @@ __global__ void ManipulateMeanCUDAKernel(T* output, ...@@ -119,7 +119,7 @@ __global__ void ManipulateMeanCUDAKernel(T* output,
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
int64_t c_index = i / slice_size; int64_t c_index = i / slice_size;
if (*(count + c_index) > 1) { if (*(count + c_index) > 1) {
*(output + i) = *(output + i) / *(count + c_index); *(output + i) = *(output + i) / static_cast<T>(*(count + c_index));
} }
} }
} }
...@@ -140,8 +140,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params, ...@@ -140,8 +140,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params,
IndexT dst_i = dst_indices[indices_i]; IndexT dst_i = dst_indices[indices_i];
int64_t in_i = src_i * slice_size + slice_i; int64_t in_i = src_i * slice_size + slice_i;
int64_t out_i = dst_i * slice_size + slice_i; int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd(output + out_i, paddle::platform::CudaAtomicAdd(
*(params + in_i) / dst_count[src_i]); output + out_i, *(params + in_i) / static_cast<T>(dst_count[src_i]));
} }
} }
...@@ -164,7 +164,8 @@ __global__ void ManipulateMinMaxGradCUDAKernel(const T* params, ...@@ -164,7 +164,8 @@ __global__ void ManipulateMinMaxGradCUDAKernel(const T* params,
int64_t out_i = dst_i * slice_size + slice_i; int64_t out_i = dst_i * slice_size + slice_i;
paddle::platform::CudaAtomicAdd( paddle::platform::CudaAtomicAdd(
output + out_i, output + out_i,
*(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i))); *(params + in_i) *
static_cast<T>(*(ptr_input + out_i) == *(ptr_output + in_i)));
} }
} }
......
...@@ -31,7 +31,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ...@@ -31,7 +31,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
DenseTensor* x_grad, DenseTensor* x_grad,
const DenseTensor* dst_count = nullptr, const DenseTensor* dst_count = nullptr,
const DenseTensor* out = nullptr) { const DenseTensor* out = nullptr) {
...@@ -73,16 +73,16 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ...@@ -73,16 +73,16 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
int64_t grid_tmp = (n + block - 1) / block; int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = src_dims[0]; int64_t input_size = src_dims[0];
if (pool_type == "SUM") { if (reduce_op == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor; GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>> GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, functor); p_src, d_index, s_index, p_output, index_size, slice_size, functor);
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
const int32_t* s_count = dst_count->data<int32_t>(); const int32_t* s_count = dst_count->data<int32_t>();
ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>( ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, d_index, s_index, p_output, index_size, slice_size, s_count); p_src, d_index, s_index, p_output, index_size, slice_size, s_count);
} else if (pool_type == "MAX" || pool_type == "MIN") { } else if (reduce_op == "MAX" || reduce_op == "MIN") {
const T* ptr_input = x.data<T>(); const T* ptr_input = x.data<T>();
const T* ptr_output = out->data<T>(); const T* ptr_output = out->data<T>();
ManipulateMinMaxGradCUDAKernel<T, IndexT> ManipulateMinMaxGradCUDAKernel<T, IndexT>
...@@ -105,7 +105,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -105,7 +105,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
const paddle::optional<DenseTensor>& out, const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count, const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::string& pool_type, const std::string& reduce_op,
DenseTensor* x_grad) { DenseTensor* x_grad) {
auto index_type = src_index.dtype(); auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
...@@ -115,7 +115,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -115,7 +115,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
out.get_ptr()); out.get_ptr());
...@@ -126,7 +126,7 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -126,7 +126,7 @@ void GraphSendRecvGradKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
x_grad, x_grad,
dst_count.get_ptr(), dst_count.get_ptr(),
out.get_ptr()); out.get_ptr());
...@@ -142,4 +142,5 @@ PD_REGISTER_KERNEL(graph_send_recv_grad, ...@@ -142,4 +142,5 @@ PD_REGISTER_KERNEL(graph_send_recv_grad,
float, float,
double, double,
int, int,
int64_t) {} int64_t,
phi::dtype::float16) {}
...@@ -32,7 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -32,7 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
int64_t out_size, int64_t out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count = nullptr) { DenseTensor* dst_count = nullptr) {
...@@ -59,19 +59,19 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -59,19 +59,19 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
ctx.template Alloc<T>(out); ctx.template Alloc<T>(out);
T* p_output = out->data<T>(); T* p_output = out->data<T>();
const size_t& memset_bytes = memset_size * sizeof(T); const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") { if (reduce_op == "SUM" || reduce_op == "MEAN") {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hipMemset(p_output, 0, memset_bytes); hipMemset(p_output, 0, memset_bytes);
#else #else
cudaMemset(p_output, 0, memset_bytes); cudaMemset(p_output, 0, memset_bytes);
#endif #endif
} else if (pool_type == "MAX") { } else if (reduce_op == "MAX") {
thrust::device_ptr<T> p_output_ptr(p_output); thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, thrust::fill(thrust::device,
p_output_ptr, p_output_ptr,
p_output_ptr + memset_size, p_output_ptr + memset_size,
std::numeric_limits<T>::lowest()); std::numeric_limits<T>::lowest());
} else if (pool_type == "MIN") { } else if (reduce_op == "MIN") {
thrust::device_ptr<T> p_output_ptr(p_output); thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device, thrust::fill(thrust::device,
p_output_ptr, p_output_ptr,
...@@ -99,12 +99,12 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -99,12 +99,12 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
int64_t grid_tmp = (n + block - 1) / block; int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
if (pool_type == "SUM") { if (reduce_op == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor; GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>> GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor); p_src, s_index, d_index, p_output, index_size, slice_size, functor);
} else if (pool_type == "MAX") { } else if (reduce_op == "MAX") {
GraphSendRecvMaxCUDAFunctor<T, IndexT> functor; GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvMaxCUDAFunctor<T, IndexT>> GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvMaxCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
...@@ -115,7 +115,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -115,7 +115,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
InputResetMaxCUDAKernel<T><<<grid_max, block, 0, ctx.stream()>>>( InputResetMaxCUDAKernel<T><<<grid_max, block, 0, ctx.stream()>>>(
p_output, input_size, slice_size); p_output, input_size, slice_size);
} else if (pool_type == "MIN") { } else if (reduce_op == "MIN") {
GraphSendRecvMinCUDAFunctor<T, IndexT> functor; GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvMinCUDAFunctor<T, IndexT>> GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvMinCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
...@@ -126,7 +126,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -126,7 +126,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
InputResetMinCUDAKernel<T><<<grid_min, block, 0, ctx.stream()>>>( InputResetMinCUDAKernel<T><<<grid_min, block, 0, ctx.stream()>>>(
p_output, input_size, slice_size); p_output, input_size, slice_size);
} else if (pool_type == "MEAN") { } else if (reduce_op == "MEAN") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor; GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>> GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
...@@ -158,7 +158,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -158,7 +158,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
const IntArray& out_size, const IntArray& out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count) { DenseTensor* dst_count) {
...@@ -169,7 +169,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -169,7 +169,7 @@ void GraphSendRecvKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
out_size_data[0], out_size_data[0],
out, out,
dst_count); dst_count);
...@@ -178,7 +178,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -178,7 +178,7 @@ void GraphSendRecvKernel(const Context& ctx,
x, x,
src_index, src_index,
dst_index, dst_index,
pool_type, reduce_op,
out_size_data[0], out_size_data[0],
out, out,
dst_count); dst_count);
...@@ -194,4 +194,5 @@ PD_REGISTER_KERNEL(graph_send_recv, ...@@ -194,4 +194,5 @@ PD_REGISTER_KERNEL(graph_send_recv,
float, float,
double, double,
int, int,
int64_t) {} int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
namespace phi {
inline void CopyBCastOff(const BroadCastInfo& bcast_info,
thrust::device_vector<int64_t>& l_bcastoff,
thrust::device_vector<int64_t>& r_bcastoff) {
l_bcastoff.resize(bcast_info.out_len);
r_bcastoff.resize(bcast_info.out_len);
#ifdef PADDLE_WITH_HIP
hipMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()),
bcast_info.l_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
hipMemcpyHostToDevice);
hipMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()),
bcast_info.r_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
hipMemcpyHostToDevice);
#else
cudaMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()),
bcast_info.l_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
cudaMemcpyHostToDevice);
cudaMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()),
bcast_info.r_offset.data(),
sizeof(int64_t) * bcast_info.out_len,
cudaMemcpyHostToDevice);
#endif
}
inline int FindNumThreads(int dim, int max_num_threads) {
PADDLE_ENFORCE_GE(dim,
0,
phi::errors::PreconditionNotMet(
"Required dim >= 0, but received dim = %d", dim));
int res = max_num_threads;
if (dim == 0) res = 1;
while (res > dim) {
res = res >> 1;
}
res = res <= 32 ? 32 : res;
return res;
}
template <typename T>
struct GraphSendUERecvSumCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicAdd(output, val);
}
};
template <typename T>
struct GraphSendUERecvMaxCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicMax(output, val);
}
};
template <typename T>
struct GraphSendUERecvMinCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
paddle::platform::CudaAtomicMin(output, val);
}
};
template <typename T,
typename IndexT,
typename ReduceFunctor,
typename ComputeFunctor>
__global__ void GraphSendUERecvCUDAKernel(const T* x_data,
const T* e_data,
const IndexT* src_indices,
const IndexT* dst_indices,
const int64_t* xbcast_off,
const int64_t* ebcast_off,
T* output,
int64_t index_size,
int64_t x_len,
int64_t e_len,
int64_t out_len,
bool use_bcast,
ComputeFunctor cfunctor,
ReduceFunctor rfunctor) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + src * x_len;
const T* e_off = e_data + ty * e_len;
T* out_off = output + dst * out_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
T val = cfunctor(x_off[x_add], e_off[e_add]);
rfunctor(out_off + tx, val);
tx += stride_x;
}
ty += stride_y;
}
}
// x_grad: for backward mean with mul.
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernelForMulX(const T* out_grad_data,
const T* e_data,
const IndexT* src_indices,
const IndexT* dst_indices,
const int* dst_count,
const int64_t* l_bcastoff,
const int64_t* r_bcastoff,
T* x_grad,
int64_t index_size,
int64_t l_len,
int64_t r_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* out_grad_off = out_grad_data + src * l_len;
const T* e_off = e_data + ty * r_len;
T* x_grad_off = x_grad + dst * out_len;
while (tx < out_len) {
int64_t o_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
T val = out_grad_off[o_add] * e_off[e_add];
paddle::platform::CudaAtomicAdd(x_grad_off + tx,
val / static_cast<T>(dst_count[src]));
tx += stride_x;
}
ty += stride_y;
}
}
// e_grad: backward sum for add.
template <typename T, typename IndexT>
__global__ void ManipulateSumGradCUDAKernelForAddE(const T* out_grad_data,
const IndexT* dst_indices,
const int64_t* r_bcastoff,
T* e_grad,
int64_t index_size,
int64_t r_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
T* e_grad_off = e_grad + ty * r_len;
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(e_grad_off + e_add, out_grad_off[tx]);
tx += stride_x;
}
ty += stride_y;
}
}
// e_grad: backward sum for mul.
template <typename T, typename IndexT>
__global__ void ManipulateSumGradCUDAKernelForMulE(const T* x_data,
const T* out_grad_data,
const IndexT* src_indices,
const IndexT* dst_indices,
const int64_t* l_bcastoff,
const int64_t* r_bcastoff,
T* e_grad,
int64_t index_size,
int64_t l_len,
int64_t r_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + src * l_len;
T* e_grad_off = e_grad + ty * r_len;
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(e_grad_off + e_add,
out_grad_off[tx] * x_off[x_add]);
tx += stride_x;
}
ty += stride_y;
}
}
// e_grad: backward mean for add
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernelForAddE(const T* out_grad_data,
const IndexT* dst_indices,
const int* dst_count,
const int64_t* r_bcastoff,
T* e_grad,
int64_t index_size,
int64_t r_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
T* e_grad_off = e_grad + ty * r_len;
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] / static_cast<T>(dst_count[dst]));
tx += stride_x;
}
ty += stride_y;
}
}
// e_grad: backward mean for mul.
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernelForMulE(const T* x_data,
const T* out_grad_data,
const IndexT* src_indices,
const IndexT* dst_indices,
const int* dst_count,
const int64_t* l_bcastoff,
const int64_t* r_bcastoff,
T* e_grad,
int64_t index_size,
int64_t l_len,
int64_t r_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + src * l_len;
T* e_grad_off = e_grad + ty * r_len;
const T* out_grad_off = out_grad_data + dst * out_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? l_bcastoff[tx] : tx;
int64_t e_add = use_bcast ? r_bcastoff[tx] : tx;
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * x_off[x_add] / static_cast<T>(dst_count[dst]));
tx += stride_x;
}
ty += stride_y;
}
}
// x_grad, e_grad: backward min and max for add.
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernelForAdd(const T* x_data,
const T* e_data,
const T* out,
const T* out_grad,
const IndexT* src_indices,
const IndexT* dst_indices,
const int64_t* xbcast_off,
const int64_t* ebcast_off,
T* x_grad,
T* e_grad,
int64_t index_size,
int64_t x_len,
int64_t e_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + dst * x_len;
const T* e_off = e_data + ty * e_len;
const T* out_off = out + src * out_len;
const T* out_grad_off = out_grad + src * out_len;
T* x_grad_off = x_grad + dst * x_len;
T* e_grad_off = e_grad + ty * e_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
T val = x_off[x_add] + e_off[e_add];
paddle::platform::CudaAtomicAdd(
x_grad_off + x_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]));
tx += stride_x;
}
ty += stride_y;
}
}
// x_grad, e_grad: backward min and max for mul.
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernelForMul(const T* x_data,
const T* e_data,
const T* out,
const T* out_grad,
const IndexT* src_indices,
const IndexT* dst_indices,
const int64_t* xbcast_off,
const int64_t* ebcast_off,
T* x_grad,
T* e_grad,
int64_t index_size,
int64_t x_len,
int64_t e_len,
int64_t out_len,
bool use_bcast) {
IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
const IndexT stride_y = blockDim.y * gridDim.y;
while (ty < index_size) {
IndexT src = src_indices[ty];
IndexT dst = dst_indices[ty];
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride_x = blockDim.x * gridDim.x;
const T* x_off = x_data + dst * x_len;
const T* e_off = e_data + ty * e_len;
const T* out_off = out + src * out_len;
const T* out_grad_off = out_grad + src * out_len;
T* x_grad_off = x_grad + dst * x_len;
T* e_grad_off = e_grad + ty * e_len;
while (tx < out_len) {
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
T val = x_off[x_add] * e_off[e_add];
paddle::platform::CudaAtomicAdd(
x_grad_off + x_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]) * e_off[e_add]);
paddle::platform::CudaAtomicAdd(
e_grad_off + e_add,
out_grad_off[tx] * static_cast<T>(val == out_off[tx]) * x_off[x_add]);
tx += stride_x;
}
ty += stride_y;
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace phi {
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& e,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
int64_t out_size,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
auto out_dims = out->dims();
int64_t memset_size = 1;
std::vector<int64_t> dims_ = phi::vectorize(out_dims);
if (out_size <= 0) {
dims_[0] = x.dims()[0];
} else {
dims_[0] = out_size;
}
out->Resize(phi::make_ddim(dims_));
for (size_t i = 0; i < dims_.size(); i++) {
memset_size *= dims_[i];
}
ctx.template Alloc<T>(out);
T* out_data = out->data<T>();
const size_t& memset_bytes = memset_size * sizeof(T);
if (reduce_op == "SUM" || reduce_op == "MEAN") {
#ifdef PADDLE_WITH_HIP
hipMemset(out_data, 0, memset_bytes);
#else
cudaMemset(out_data, 0, memset_bytes);
#endif
} else if (reduce_op == "MAX") {
thrust::device_ptr<T> out_data_ptr(out_data);
thrust::fill(thrust::device,
out_data_ptr,
out_data_ptr + memset_size,
std::numeric_limits<T>::lowest());
} else if (reduce_op == "MIN") {
thrust::device_ptr<T> out_data_ptr(out_data);
thrust::fill(thrust::device,
out_data_ptr,
out_data_ptr + memset_size,
std::numeric_limits<T>::max());
}
if (index_size == 0) return;
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims());
const T* x_data = x.data<T>();
const T* e_data = e.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
thrust::device_vector<int64_t> x_bcastoff, e_bcastoff;
if (bcast_info.use_bcast) {
CopyBCastOff(bcast_info, x_bcastoff, e_bcastoff);
}
int64_t out_len = bcast_info.out_len;
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty;
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
int64_t input_size = x.dims()[0];
#ifdef PADDLE_WITH_HIP
int block_ = 256;
#else
int block_ = 1024;
#endif
if (reduce_op == "SUM" || reduce_op == "MEAN") {
GraphSendUERecvSumCUDAFunctor<T> sum_functor;
if (message_op == "ADD") {
funcs::AddFunctor<T> add_funtor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvSumCUDAFunctor<T>,
funcs::AddFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
add_funtor,
sum_functor);
} else if (message_op == "MUL") {
funcs::MultiplyFunctor<T> mul_functor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvSumCUDAFunctor<T>,
funcs::MultiplyFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor,
sum_functor);
}
if (reduce_op == "MEAN") {
input_size = out_size <= 0 ? x.dims()[0] : out_size;
dst_count->Resize({input_size});
ctx.template Alloc<int>(dst_count);
int* dst_count_data = dst_count->data<int>();
#ifdef PADDLE_WITH_HIP
hipMemset(dst_count_data, 0, input_size * sizeof(int));
#else
cudaMemset(dst_count_data, 0, input_size * sizeof(int));
#endif
int64_t grid_count = (index_size + block_ - 1) / block_;
ComputeCountCUDAKernel<T, IndexT>
<<<grid_count, block_, 0, ctx.stream()>>>(
dst_count_data, d_index, index_size);
int64_t grid_mean = (input_size * out_len + block_ - 1) / block_;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_mean_ =
grid_mean < max_grid_dimx ? grid_mean : max_grid_dimx;
ManipulateMeanCUDAKernel<T><<<grid_mean_, block_, 0, ctx.stream()>>>(
out_data, dst_count_data, input_size, out_len);
}
} else if (reduce_op == "MAX") {
GraphSendUERecvMaxCUDAFunctor<T> max_functor;
if (message_op == "ADD") {
funcs::AddFunctor<T> add_funtor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvMaxCUDAFunctor<T>,
funcs::AddFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
add_funtor,
max_functor);
} else if (message_op == "MUL") {
funcs::MultiplyFunctor<T> mul_functor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvMaxCUDAFunctor<T>,
funcs::MultiplyFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor,
max_functor);
}
if (out_size > 0) {
input_size = out_size;
}
int64_t grid_max = (input_size * out_len + block_ - 1) / block_;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_max_ = grid_max < max_grid_dimx ? grid_max : max_grid_dimx;
InputResetMaxCUDAKernel<T>
<<<grid_max_, block_, 0, ctx.stream()>>>(out_data, input_size, out_len);
} else if (reduce_op == "MIN") {
GraphSendUERecvMinCUDAFunctor<T> min_functor;
if (message_op == "ADD") {
funcs::AddFunctor<T> add_funtor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvMinCUDAFunctor<T>,
funcs::AddFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
add_funtor,
min_functor);
} else if (message_op == "MUL") {
funcs::MultiplyFunctor<T> mul_functor;
GraphSendUERecvCUDAKernel<T,
IndexT,
GraphSendUERecvMinCUDAFunctor<T>,
funcs::MultiplyFunctor<T>>
<<<grid, block, 0, ctx.stream()>>>(
x_data,
e_data,
s_index,
d_index,
thrust::raw_pointer_cast(x_bcastoff.data()),
thrust::raw_pointer_cast(e_bcastoff.data()),
out_data,
index_size,
bcast_info.l_len,
bcast_info.r_len,
out_len,
bcast_info.use_bcast,
mul_functor,
min_functor);
}
if (out_size > 0) {
input_size = out_size;
}
int64_t grid_min = (input_size * out_len + block_ - 1) / block_;
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_min_ = grid_min < max_grid_dimx ? grid_min : max_grid_dimx;
InputResetMinCUDAKernel<T>
<<<grid_min_, block_, 0, ctx.stream()>>>(out_data, input_size, out_len);
}
}
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
out_size_data[0],
out,
dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendUERecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx,
x,
y,
src_index,
dst_index,
message_op,
reduce_op,
out_size_data[0],
out,
dst_count);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_send_ue_recv,
GPU,
ALL_LAYOUT,
phi::GraphSendUERecvKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
...@@ -29,6 +29,6 @@ void GraphSendRecvGradKernel(const Context& ctx, ...@@ -29,6 +29,6 @@ void GraphSendRecvGradKernel(const Context& ctx,
const paddle::optional<DenseTensor>& out, const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count, const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const std::string& pool_type, const std::string& reduce_op,
DenseTensor* x_grad); DenseTensor* x_grad);
} // namespace phi } // namespace phi
...@@ -26,7 +26,7 @@ void GraphSendRecvKernel(const Context& ctx, ...@@ -26,7 +26,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& src_index, const DenseTensor& src_index,
const DenseTensor& dst_index, const DenseTensor& dst_index,
const std::string& pool_type, const std::string& reduce_op,
const IntArray& out_size, const IntArray& out_size,
DenseTensor* out, DenseTensor* out,
DenseTensor* dst_count); DenseTensor* dst_count);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
const paddle::optional<DenseTensor>& dst_count,
const DenseTensor& out_grad,
const std::string& message_op,
const std::string& reduce_op,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
const std::string& reduce_op,
const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright The DGL team.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
struct BroadCastInfo {
bool use_bcast;
// l_offset[i] indicates the start position of tensor lhs that required to
// compute the i-th element in output, so as r_offset[i].
std::vector<int64_t> l_offset, r_offset;
int64_t l_len, r_len, out_len, reduce_size;
};
inline bool UseBroadCast(const phi::DDim& l_dims, const phi::DDim& r_dims) {
if (l_dims.size() != r_dims.size()) {
return true;
}
for (int i = 1; i < l_dims.size(); i++) {
if (l_dims[i] != r_dims[i]) {
return true;
}
}
return false;
}
inline BroadCastInfo CalcBCastInfo(const phi::DDim& l_dims,
const phi::DDim& r_dims) {
BroadCastInfo binfo;
binfo.use_bcast = UseBroadCast(l_dims, r_dims);
binfo.l_len = 1;
binfo.r_len = 1;
for (int i = 1; i < l_dims.size(); i++) {
binfo.l_len *= l_dims[i];
}
for (int i = 1; i < r_dims.size(); i++) {
binfo.r_len *= r_dims[i];
}
// TODO(daisiming): Whether to add dot.
binfo.reduce_size = 1;
if (binfo.use_bcast) {
const int max_dim = std::max(l_dims.size(), r_dims.size()) - 1;
int stride_l = 1, stride_r = 1;
binfo.l_offset.emplace_back(0);
binfo.r_offset.emplace_back(0);
int out_len = 1;
for (int i = 0; i < max_dim; i++) {
// Iterate the axis from back to front.
const int dl =
(l_dims.size() - 1 - i < 1) ? 1 : l_dims[l_dims.size() - 1 - i];
const int dr =
(r_dims.size() - 1 - i < 1) ? 1 : r_dims[r_dims.size() - 1 - i];
for (int j = 1; j < std::max(dl, dr); j++) {
for (int k = 0; k < out_len; k++) {
binfo.l_offset.emplace_back(binfo.l_offset[k] +
j * (j < dl) * stride_l);
binfo.r_offset.emplace_back(binfo.r_offset[k] +
j * (j < dr) * stride_r);
}
}
out_len *= std::max(dl, dr);
stride_l *= dl;
stride_r *= dr;
}
binfo.out_len = out_len;
} else {
binfo.out_len = binfo.l_len;
}
return binfo;
}
inline std::vector<int> InferBroadcastShape(const phi::DDim& x_dims,
const phi::DDim& e_dims,
const std::string& type = "x") {
auto x_dims1 = phi::vectorize<int>(x_dims);
auto e_dims1 = phi::vectorize<int>(e_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> e_dims2(e_dims1.begin() + 1, e_dims1.end());
int max_dim = std::max(x_dims2.size(), e_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - e_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> e_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(e_dims2),
x_dims_array.data(),
e_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
if (type == "x") {
out_dims_array.insert(out_dims_array.begin(), x_dims[0]);
} else {
out_dims_array.insert(out_dims_array.begin(), e_dims[0]);
}
return out_dims_array;
}
inline bool ReduceGrad(const phi::DDim& out_grad_dims,
const phi::DDim& x_dims,
std::vector<int64_t>& axis) {
// We must ensure the ndim of out_grad and x are the same.
bool reduce = false;
for (int i = 1; i < out_grad_dims.size(); i++) {
if (out_grad_dims[i] != x_dims[i]) {
reduce = true;
break;
}
}
if (!reduce) return false;
// Get reduce axis.
for (int i = 1; i < out_grad_dims.size(); i++) {
if (out_grad_dims[i] - x_dims[i] != 0) {
axis.emplace_back(i);
}
}
return true;
}
} // namespace phi
...@@ -21,12 +21,12 @@ KernelSignature GraphSendRecvOpArgumentMapping( ...@@ -21,12 +21,12 @@ KernelSignature GraphSendRecvOpArgumentMapping(
if (ctx.HasInput("Out_size")) { if (ctx.HasInput("Out_size")) {
return KernelSignature("graph_send_recv", return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"}, {"X", "Src_index", "Dst_index"},
{"pool_type", "Out_size"}, {"reduce_op", "Out_size"},
{"Out", "Dst_count"}); {"Out", "Dst_count"});
} else { } else {
return KernelSignature("graph_send_recv", return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"}, {"X", "Src_index", "Dst_index"},
{"pool_type", "out_size"}, {"reduce_op", "out_size"},
{"Out", "Dst_count"}); {"Out", "Dst_count"});
} }
} }
...@@ -36,7 +36,7 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( ...@@ -36,7 +36,7 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
return KernelSignature( return KernelSignature(
"graph_send_recv_grad", "graph_send_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"pool_type"}, {"reduce_op"},
{"X@GRAD"}); {"X@GRAD"});
} }
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GraphSendUERecvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Out_size")) {
return KernelSignature("graph_send_ue_recv",
{"X", "Y", "Src_index", "Dst_index"},
{"message_op", "reduce_op", "Out_size"},
{"Out", "Dst_count"});
} else {
return KernelSignature("graph_send_ue_recv",
{"X", "Y", "Src_index", "Dst_index"},
{"message_op", "reduce_op", "out_size"},
{"Out", "Dst_count"});
}
}
KernelSignature GraphSendUERecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_ue_recv_grad",
{"X", "Y", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"message_op", "reduce_op"},
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_send_ue_recv,
phi::GraphSendUERecvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(graph_send_ue_recv_grad,
phi::GraphSendUERecvGradOpArgumentMapping);
...@@ -1562,6 +1562,7 @@ set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) ...@@ -1562,6 +1562,7 @@ set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60)
if(WITH_DISTRIBUTE if(WITH_DISTRIBUTE
AND WITH_GPU AND WITH_GPU
AND WITH_NCCL) AND WITH_NCCL)
......
...@@ -25,11 +25,11 @@ from op_test import OpTest ...@@ -25,11 +25,11 @@ from op_test import OpTest
def graph_send_recv_wrapper(x, def graph_send_recv_wrapper(x,
src_index, src_index,
dst_index, dst_index,
pool_type="sum", reduce_op="sum",
out_size=None, out_size=None,
name=None): name=None):
return paddle.geometric.send_u_recv(x, src_index, dst_index, return paddle.geometric.send_u_recv(x, src_index, dst_index,
pool_type.lower(), out_size, name) reduce_op.lower(), out_size, name)
class TestGraphSendRecvMaxOp(OpTest): class TestGraphSendRecvMaxOp(OpTest):
...@@ -46,7 +46,7 @@ class TestGraphSendRecvMaxOp(OpTest): ...@@ -46,7 +46,7 @@ class TestGraphSendRecvMaxOp(OpTest):
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MAX'} self.attrs = {'reduce_op': 'MAX'}
out, self.gradient = compute_graph_send_recv_for_min_max( out, self.gradient = compute_graph_send_recv_for_min_max(
self.inputs, self.attrs) self.inputs, self.attrs)
...@@ -76,7 +76,7 @@ class TestGraphSendRecvMinOp(OpTest): ...@@ -76,7 +76,7 @@ class TestGraphSendRecvMinOp(OpTest):
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MIN'} self.attrs = {'reduce_op': 'MIN'}
out, self.gradient = compute_graph_send_recv_for_min_max( out, self.gradient = compute_graph_send_recv_for_min_max(
self.inputs, self.attrs) self.inputs, self.attrs)
...@@ -107,7 +107,7 @@ class TestGraphSendRecvSumOp(OpTest): ...@@ -107,7 +107,7 @@ class TestGraphSendRecvSumOp(OpTest):
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'SUM'} self.attrs = {'reduce_op': 'SUM'}
out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs) out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs)
...@@ -134,7 +134,7 @@ class TestGraphSendRecvMeanOp(OpTest): ...@@ -134,7 +134,7 @@ class TestGraphSendRecvMeanOp(OpTest):
self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
self.attrs = {'pool_type': 'MEAN'} self.attrs = {'reduce_op': 'MEAN'}
out, dst_count = compute_graph_send_recv_for_sum_mean( out, dst_count = compute_graph_send_recv_for_sum_mean(
self.inputs, self.attrs) self.inputs, self.attrs)
...@@ -153,15 +153,15 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): ...@@ -153,15 +153,15 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes):
src_index = inputs['Src_index'] src_index = inputs['Src_index']
dst_index = inputs['Dst_index'] dst_index = inputs['Dst_index']
pool_type = attributes['pool_type'] reduce_op = attributes['reduce_op']
gather_x = x[src_index] gather_x = x[src_index]
target_shape = list(x.shape) target_shape = list(x.shape)
results = np.zeros(target_shape, dtype=x.dtype) results = np.zeros(target_shape, dtype=x.dtype)
if pool_type == 'SUM': if reduce_op == 'SUM':
for index, s_id in enumerate(dst_index): for index, s_id in enumerate(dst_index):
results[s_id, :] += gather_x[index, :] results[s_id, :] += gather_x[index, :]
elif pool_type == 'MEAN': elif reduce_op == 'MEAN':
count = np.zeros(target_shape[0], dtype=np.int32) count = np.zeros(target_shape[0], dtype=np.int32)
for index, s_id in enumerate(dst_index): for index, s_id in enumerate(dst_index):
results[s_id, :] += gather_x[index, :] results[s_id, :] += gather_x[index, :]
...@@ -169,7 +169,7 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): ...@@ -169,7 +169,7 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes):
results = results / count.reshape([-1, 1]) results = results / count.reshape([-1, 1])
results[np.isnan(results)] = 0 results[np.isnan(results)] = 0
else: else:
raise ValueError("Invalid pool_type, only SUM, MEAN supported!") raise ValueError("Invalid reduce_op, only SUM, MEAN supported!")
count = np.zeros(target_shape[0], dtype=np.int32) count = np.zeros(target_shape[0], dtype=np.int32)
for index, s_id in enumerate(dst_index): for index, s_id in enumerate(dst_index):
...@@ -183,7 +183,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): ...@@ -183,7 +183,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes):
src_index = inputs['Src_index'] src_index = inputs['Src_index']
dst_index = inputs['Dst_index'] dst_index = inputs['Dst_index']
pool_type = attributes['pool_type'] reduce_op = attributes['reduce_op']
gather_x = x[src_index] gather_x = x[src_index]
target_shape = list(x.shape) target_shape = list(x.shape)
...@@ -191,7 +191,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): ...@@ -191,7 +191,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes):
gradient = np.zeros_like(x) gradient = np.zeros_like(x)
# Calculate forward output # Calculate forward output
if pool_type == "MAX": if reduce_op == "MAX":
first_set = set() first_set = set()
for index, s_id in enumerate(dst_index): for index, s_id in enumerate(dst_index):
if s_id not in first_set: if s_id not in first_set:
...@@ -200,7 +200,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): ...@@ -200,7 +200,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes):
else: else:
results[s_id, :] = np.maximum(results[s_id, :], results[s_id, :] = np.maximum(results[s_id, :],
gather_x[index, :]) gather_x[index, :])
elif pool_type == "MIN": elif reduce_op == "MIN":
first_set = set() first_set = set()
for index, s_id in enumerate(dst_index): for index, s_id in enumerate(dst_index):
if s_id not in first_set: if s_id not in first_set:
...@@ -210,7 +210,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): ...@@ -210,7 +210,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes):
results[s_id, :] = np.minimum(results[s_id, :], results[s_id, :] = np.minimum(results[s_id, :],
gather_x[index, :]) gather_x[index, :])
else: else:
raise ValueError("Invalid pool_type, only MAX, MIN supported!") raise ValueError("Invalid reduce_op, only MAX, MIN supported!")
# Calculate backward gradient # Calculate backward gradient
index_size = len(src_index) index_size = len(src_index)
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
from .message_passing import send_u_recv # noqa: F401 from .message_passing import send_u_recv # noqa: F401
from .message_passing import send_ue_recv # noqa: F401
__all__ = [ __all__ = [
'send_u_recv', 'send_u_recv',
'send_ue_recv',
] ]
...@@ -13,3 +13,4 @@ ...@@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from .send_recv import send_u_recv # noqa: F401 from .send_recv import send_u_recv # noqa: F401
from .send_recv import send_ue_recv # noqa: F401
...@@ -19,13 +19,13 @@ from paddle.fluid.framework import Variable ...@@ -19,13 +19,13 @@ from paddle.fluid.framework import Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from paddle import _C_ops from paddle import _C_ops
from .utils import convert_out_size_to_list, get_out_size_tensor_inputs from .utils import convert_out_size_to_list, get_out_size_tensor_inputs, reshape_lhs_rhs
def send_u_recv(x, def send_u_recv(x,
src_index, src_index,
dst_index, dst_index,
pool_type="sum", reduce_op="sum",
out_size=None, out_size=None,
name=None): name=None):
""" """
...@@ -35,13 +35,13 @@ def send_u_recv(x, ...@@ -35,13 +35,13 @@ def send_u_recv(x,
This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor
in different pooling types, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. in different reduce ops, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape.
.. code-block:: text .. code-block:: text
Given: Given:
X = [[0, 2, 3], x = [[0, 2, 3],
[1, 4, 5], [1, 4, 5],
[2, 6, 7]] [2, 6, 7]]
...@@ -49,22 +49,23 @@ def send_u_recv(x, ...@@ -49,22 +49,23 @@ def send_u_recv(x,
dst_index = [1, 2, 1, 0] dst_index = [1, 2, 1, 0]
pool_type = "sum" reduce_op = "sum"
out_size = None out_size = None
Then: Then:
Out = [[0, 2, 3], out = [[0, 2, 3],
[2, 8, 10], [2, 8, 10],
[1, 4, 5]] [1, 4, 5]]
Args: Args:
x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64. x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64.
And we support float16 in gpu version.
src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64.
dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`.
The available data type is int32, int64. The available data type is int32, int64.
pool_type (str): Different pooling types, including `sum`, `mean`, `max`, `min`. reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`.
Default value is `sum`. Default value is `sum`.
out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or
out_size is smaller or equal to 0, then this input will not be used. out_size is smaller or equal to 0, then this input will not be used.
...@@ -88,7 +89,7 @@ def send_u_recv(x, ...@@ -88,7 +89,7 @@ def send_u_recv(x,
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0] src_index = indexes[:, 0]
dst_index = indexes[:, 1] dst_index = indexes[:, 1]
out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
...@@ -96,39 +97,40 @@ def send_u_recv(x, ...@@ -96,39 +97,40 @@ def send_u_recv(x,
src_index = indexes[:, 0] src_index = indexes[:, 0]
dst_index = indexes[:, 1] dst_index = indexes[:, 1]
out_size = paddle.max(dst_index) + 1 out_size = paddle.max(dst_index) + 1
out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size) out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum", out_size=out_size)
# Outputs: [[0., 2., 3.], [[2., 8., 10.]]] # Outputs: [[0., 2., 3.], [[2., 8., 10.]]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0] src_index = indexes[:, 0]
dst_index = indexes[:, 1] dst_index = indexes[:, 1]
out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum")
# Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]]
""" """
if pool_type not in ["sum", "mean", "max", "min"]: if reduce_op not in ["sum", "mean", "max", "min"]:
raise ValueError( raise ValueError(
"pool_type should be `sum`, `mean`, `max` or `min`, but received %s" "reduce_op should be `sum`, `mean`, `max` or `min`, but received %s"
% pool_type) % reduce_op)
# TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1.
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out_size = convert_out_size_to_list(out_size) out_size = convert_out_size_to_list(out_size)
out, tmp = _C_ops.graph_send_recv(x, src_index, out, tmp = _C_ops.graph_send_recv(x, src_index,
dst_index, None, 'pool_type', dst_index, None, 'reduce_op',
pool_type.upper(), 'out_size', reduce_op.upper(), 'out_size',
out_size) out_size)
return out return out
if in_dygraph_mode(): if in_dygraph_mode():
out_size = convert_out_size_to_list(out_size) out_size = convert_out_size_to_list(out_size)
return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, return _C_ops.final_state_graph_send_recv(x, src_index, dst_index,
pool_type.upper(), out_size) reduce_op.upper(), out_size)
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), check_variable_and_dtype(
"graph_send_recv") x, "X", ("float32", "float64", "int32", "int64", "float16"),
"graph_send_recv")
check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"),
"graph_send_recv") "graph_send_recv")
check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"),
...@@ -146,7 +148,7 @@ def send_u_recv(x, ...@@ -146,7 +148,7 @@ def send_u_recv(x,
stop_gradient=True) stop_gradient=True)
inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index}
attrs = {"pool_type": pool_type.upper()} attrs = {"reduce_op": reduce_op.upper()}
get_out_size_tensor_inputs(inputs=inputs, get_out_size_tensor_inputs(inputs=inputs,
attrs=attrs, attrs=attrs,
out_size=out_size, out_size=out_size,
...@@ -160,3 +162,177 @@ def send_u_recv(x, ...@@ -160,3 +162,177 @@ def send_u_recv(x,
}, },
attrs=attrs) attrs=attrs)
return out return out
def send_ue_recv(x,
y,
src_index,
dst_index,
message_op="add",
reduce_op="sum",
out_size=None,
name=None):
"""
Graph Learning message passing api.
This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
to gather the corresponding data, after computing with `y` in different message ops like add/sub/mul/div, then use `dst_index` to
update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min.
Besides, we can use `out_size` to set necessary output shape.
.. code-block:: text
Given:
x = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
y = [1, 1, 1]
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]
message_op = "add"
reduce_op = "sum"
out_size = None
Then:
out = [[1, 3, 4],
[4, 10, 12],
[2, 5, 6]]
Args:
x (Tensor): The input node feature tensor, and the available data type is float32, float64, int32, int64.
And we support float16 in gpu version.
y (Tensor): The input edge feature tensor, and the available data type is float32, float64, int32, int64.
And we support float16 in gpu version.
src_index (Tensor): An 1-D tensor, and the available data type is int32, int64.
dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`.
The available data type is int32, int64.
message_op (str): Different message ops for x and e, including `add`, `sub`, `mul`, `div`.
reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`.
Default value is `sum`.
out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or
out_size is smaller or equal to 0, then this input will not be used.
Otherwise, `out_size` should be equal with or larger than
max(dst_index) + 1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`.
If `out_size` is set correctly, then it should have the same shape as `x` except
the 0th dimension.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([1, 1, 1, 1], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum")
# Outputs: [[1., 3., 4.], [4., 10., 12.], [2., 5., 6.]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([1, 1, 1], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out_size = paddle.max(dst_index) + 1
out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum", out_size=out_size)
# Outputs: [[1., 3., 4.], [[4., 10., 12.]]]
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([1, 1, 1], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum")
# Outputs: [[1., 3., 4.], [4., 10., 12.], [0., 0., 0.]]
"""
if message_op not in ["add", "sub", "mul", "div"]:
raise ValueError(
"message_op should be `add`, `sub`, `mul`, `div`, but received %s" %
message_op)
if reduce_op not in ["sum", "mean", "max", "min"]:
raise ValueError(
"reduce_op should be `sum`, `mean`, `max` or `min`, but received %s"
% reduce_op)
x, y = reshape_lhs_rhs(x, y)
if message_op == 'sub':
message_op = 'add'
y = -y
if message_op == "div":
message_op = 'mul'
y = 1. / y
# TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1.
if _in_legacy_dygraph():
out_size = convert_out_size_to_list(out_size)
out, tmp = _C_ops.graph_send_ue_recv(x, y, src_index, dst_index,
None, 'message_op',
message_op.upper(), 'reduce_op',
reduce_op.upper(), 'out_size',
out_size)
return out
if in_dygraph_mode():
out_size = convert_out_size_to_list(out_size)
return _C_ops.final_state_graph_send_ue_recv(x, y, src_index, dst_index,
message_op.upper(),
reduce_op.upper(),
out_size)
check_variable_and_dtype(
x, "X", ("float32", "float64", "int32", "int64", "float16"),
"graph_send_ue_recv")
check_variable_and_dtype(
y, "Y", ("float32", "float64", "int32", "int64", "float16"),
"graph_send_ue_recv")
check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"),
"graph_send_ue_recv")
check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"),
"graph_send_ue_recv")
if out_size:
check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable),
'graph_send_ue_recv')
if isinstance(out_size, Variable):
check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'],
'graph_send_ue_recv')
helper = LayerHelper("send_ue_recv", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
dst_count = helper.create_variable_for_type_inference(dtype="int32",
stop_gradient=True)
inputs = {"X": x, "Y": y, "Src_index": src_index, "Dst_index": dst_index}
attrs = {"message_op": message_op.upper(), "reduce_op": reduce_op.upper()}
get_out_size_tensor_inputs(inputs=inputs,
attrs=attrs,
out_size=out_size,
op_type='graph_send_ue_recv')
helper.append_op(type="graph_send_ue_recv",
inputs=inputs,
outputs={
"Out": out,
"Dst_count": dst_count
},
attrs=attrs)
return out
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import paddle
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.data_feeder import check_dtype, convert_dtype from paddle.fluid.data_feeder import check_dtype, convert_dtype
from paddle.fluid.layers.tensor import cast from paddle.fluid.layers.tensor import cast
...@@ -50,3 +51,35 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): ...@@ -50,3 +51,35 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type):
inputs["Out_size"] = out_size inputs["Out_size"] = out_size
else: else:
raise TypeError("Out_size only supports Variable or int.") raise TypeError("Out_size only supports Variable or int.")
def reshape_lhs_rhs(x, y):
"""
Expand dims to ensure there will be no broadcasting issues with different
number of dimensions.
"""
if len(x.shape) == 1:
x = paddle.reshape(x, [-1, 1])
if len(y.shape) == 1:
y = paddle.reshape(y, [-1, 1])
x_shape = paddle.shape(x)
y_shape = paddle.shape(y)
if len(x.shape) != len(y.shape):
max_ndims = max(len(x.shape), len(y.shape))
x_pad_ndims = max_ndims - len(x.shape)
y_pad_ndims = max_ndims - len(y.shape)
new_x_shape = [
x_shape[0],
] + [
1,
] * x_pad_ndims + list(x_shape[1:])
new_y_shape = [
y_shape[0],
] + [
1,
] * y_pad_ndims + list(y_shape[1:])
x = paddle.reshape(x, new_x_shape)
y = paddle.reshape(y, new_y_shape)
return x, y
...@@ -69,7 +69,7 @@ def graph_send_recv(x, ...@@ -69,7 +69,7 @@ def graph_send_recv(x,
src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64.
dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`.
The available data type is int32, int64. The available data type is int32, int64.
pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. pool_type (str): The pooling types of graph_send_recv, including `sum`, `mean`, `max`, `min`.
Default value is `sum`. Default value is `sum`.
out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or
out_size is smaller or equal to 0, then this input will not be used. out_size is smaller or equal to 0, then this input will not be used.
...@@ -123,7 +123,7 @@ def graph_send_recv(x, ...@@ -123,7 +123,7 @@ def graph_send_recv(x,
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out_size = convert_out_size_to_list(out_size) out_size = convert_out_size_to_list(out_size)
out, tmp = _C_ops.graph_send_recv(x, src_index, out, tmp = _C_ops.graph_send_recv(x, src_index,
dst_index, None, 'pool_type', dst_index, None, 'reduce_op',
pool_type.upper(), 'out_size', pool_type.upper(), 'out_size',
out_size) out_size)
return out return out
...@@ -151,7 +151,7 @@ def graph_send_recv(x, ...@@ -151,7 +151,7 @@ def graph_send_recv(x,
stop_gradient=True) stop_gradient=True)
inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index}
attrs = {"pool_type": pool_type.upper()} attrs = {"reduce_op": pool_type.upper()}
get_out_size_tensor_inputs(inputs=inputs, get_out_size_tensor_inputs(inputs=inputs,
attrs=attrs, attrs=attrs,
out_size=out_size, out_size=out_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册