// graph_send_recv_grad_kernel.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"

#include <algorithm>
#include <vector>

#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename IndexT, typename Functor>
26
void GraphSendRecvCpuGradLoop(const int& index_size,
27 28 29
                              const IndexT* s_index,
                              const IndexT* d_index,
                              const DenseTensor& src,
30
                              const DenseTensor& input,
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
                              DenseTensor* dst,
                              const std::string& pool_type,
                              const int* dst_count = nullptr,
                              const DenseTensor* output = nullptr) {
  if (pool_type == "SUM") {
    Functor functor;
    for (int i = 0; i < index_size; ++i) {
      const IndexT& src_idx = s_index[i];
      const IndexT& dst_idx = d_index[i];
      ElementwiseInnerOperation<T, IndexT, Functor>(
          src, dst, src_idx, dst_idx, false, functor);
    }
  } else if (pool_type == "MEAN") {
    for (int i = 0; i < index_size; ++i) {
      const IndexT& src_idx = s_index[i];
      const IndexT& dst_idx = d_index[i];
      auto src_slice = src.Slice(src_idx, src_idx + 1);
      auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
      auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
      auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
      eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
    }
  } else if (pool_type == "MIN" || pool_type == "MAX") {
    for (int i = 0; i < index_size; ++i) {
      const IndexT& forward_src_idx = d_index[i];
      const IndexT& forward_dst_idx = s_index[i];
57
      auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1);
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
      auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
      auto eigen_input = phi::EigenVector<T>::Flatten(input_slice);
      auto eigen_output = phi::EigenVector<T>::Flatten(output_slice);

      auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
      auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
      auto eigen_src = phi::EigenVector<T>::Flatten(src_slice);
      auto eigen_dst = phi::EigenVector<T>::Flatten(dst_slice);
      eigen_dst += eigen_src * (eigen_output == eigen_input);
    }
  }
}

template <typename Context, typename T, typename IndexT>
void GraphSendRecvGradOpKernelLaunchHelper(
    const Context& ctx,
    const DenseTensor& out_grad,
75
    const DenseTensor& x,
76 77 78 79 80 81 82 83 84 85
    const DenseTensor& src_index,
    const DenseTensor& dst_index,
    const std::string& pool_type,
    DenseTensor* x_grad,
    const DenseTensor* dst_count = nullptr,
    const DenseTensor* out = nullptr) {
  const int& index_size = dst_index.dims()[0];

  ctx.template Alloc<T>(x_grad);
  T* p_output = x_grad->data<T>();
86
  const auto& src_dims = x.dims();
87 88 89 90 91 92 93 94 95 96 97 98
  int64_t memset_size = 1;
  for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
  const size_t& memset_bytes = memset_size * sizeof(T);
  memset(p_output, 0, memset_bytes);

  if (index_size == 0) return;

  const IndexT* s_index = src_index.data<IndexT>();
  const IndexT* d_index = dst_index.data<IndexT>();

  if (pool_type == "SUM") {
    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
99
        index_size, d_index, s_index, out_grad, x, x_grad, pool_type);
100 101 102
  } else if (pool_type == "MEAN") {
    const int* s_count = dst_count->data<int>();
    // Functor not used here.
103 104
    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
        index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count);
105 106
  } else if (pool_type == "MIN" || pool_type == "MAX") {
    // Functor not used here.
107
    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size,
108 109 110
                                                                    d_index,
                                                                    s_index,
                                                                    out_grad,
111
                                                                    x,
112 113 114 115 116 117 118 119 120 121
                                                                    x_grad,
                                                                    pool_type,
                                                                    nullptr,
                                                                    out);
  }
}

template <typename T, typename Context>
void GraphSendRecvGradKernel(const Context& ctx,
                             const DenseTensor& out_grad,
122
                             const DenseTensor& x,
123 124 125 126 127 128 129 130 131 132 133
                             paddle::optional<const DenseTensor&> out,
                             const DenseTensor& src_index,
                             const DenseTensor& dst_index,
                             paddle::optional<const DenseTensor&> dst_count,
                             const std::string& pool_type,
                             DenseTensor* x_grad) {
  auto index_type = src_index.dtype();
  if (index_type == phi::DataType::INT32) {
    GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
        ctx,
        out_grad,
134
        x,
135 136 137 138 139 140 141 142 143 144
        src_index,
        dst_index,
        pool_type,
        x_grad,
        dst_count.get_ptr(),
        out.get_ptr());
  } else if (index_type == phi::DataType::INT64) {
    GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>(
        ctx,
        out_grad,
145
        x,
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
        src_index,
        dst_index,
        pool_type,
        x_grad,
        dst_count.get_ptr(),
        out.get_ptr());
  }
}

}  // namespace phi

// Registers phi::GraphSendRecvGradKernel as `graph_send_recv_grad` with the
// arguments CPU and ALL_LAYOUT, listing float, double, int and int64_t as the
// element types. The trailing `{}` body is intentionally empty.
PD_REGISTER_KERNEL(graph_send_recv_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::GraphSendRecvGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}