// einsum_grad_impl.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/kernels/tile_kernel.h"
#include "paddle/utils/string/string_helper.h"

namespace phi {
23

24 25 26 27 28 29
template <typename T, typename Context>
DenseTensor PerformTileAndReduction(const Context& dev_ctx,
                                    const LabelMap& label2type,
                                    const LabelMap& label2shape,
                                    const std::vector<int>& broadcast_dims,
                                    const std::vector<int>& ellipsis_dims,
30 31 32 33 34 35 36
                                    std::string equ,   // value pass
                                    DenseTensor& t) {  // NOLINT
  auto tmp_label = equ;
  ReplaceEllipsis(tmp_label);
  auto tmp_union = unique_labels(tmp_label);
  auto op_label = std::string(tmp_union.begin(), tmp_union.end());
  VLOG(5) << "Start PerformTileAndReduction" << equ;
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
  DenseTensor ret;
  std::vector<int> repeat_times;
  std::vector<int> resize_dims;
  std::vector<int> recover_shape;
  for (int c : op_label) {
    if (label2type[c] == LabelType::Reduction) {
      // '.' can't be Reduction, so we don't deal '.' here.
      repeat_times.push_back(label2shape[c]);
      resize_dims.push_back(1);
      recover_shape.push_back(label2shape[c]);
    } else {
      if (c != '.') {
        resize_dims.push_back(label2shape[c]);
        repeat_times.push_back(1);
        recover_shape.push_back(label2shape[c]);
      } else {
        int n_dims = broadcast_dims.size();
        resize_dims.insert(
            resize_dims.end(), broadcast_dims.begin(), broadcast_dims.end());
        recover_shape.insert(
            recover_shape.end(), ellipsis_dims.begin(), ellipsis_dims.end());
        while (n_dims--) repeat_times.push_back(1);
      }
    }
  }
  t.Resize(make_ddim(resize_dims));
  DenseTensor after_tile;
64 65 66 67 68
  if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) {
        return x == 1;
      })) {
    after_tile = t;
  } else {
69 70
    VLOG(4) << "do TileKernel with repeat_times="
            << paddle::string::join_strings(repeat_times, ",");
71 72
    TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
  }
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
  size_t n_ellipsis_idx = op_label.find(".", 0);
  if (n_ellipsis_idx != std::string::npos) {
    // may be we need reduce. broadcast_dims is not equal to ellipsis dims.
    std::vector<int64_t> to_reduce;
    for (size_t i = 0; i < broadcast_dims.size() - ellipsis_dims.size(); ++i)
      to_reduce.push_back(i + n_ellipsis_idx);

    int new_offset =
        n_ellipsis_idx + broadcast_dims.size() - ellipsis_dims.size();
    for (size_t i = 0; i < ellipsis_dims.size(); ++i)
      if (ellipsis_dims[i] == 1) to_reduce.push_back(i + new_offset);

    VLOG(5) << "PermformTileAndReduction: reduce sum axis: "
            << paddle::string::join_strings(to_reduce, ",");
    if (to_reduce.size() != 0) {
      ret = Sum<T, Context>(dev_ctx,
                            after_tile,
90
                            phi::IntArray(to_reduce),
91 92 93 94 95 96 97 98 99 100 101
                            after_tile.dtype(),
                            false);  // not keep dim.
    } else {
      ret = after_tile;
    }
  } else {
    ret = after_tile;
  }
  VLOG(5) << "PermformTileAndReduction: recover shape: "
          << paddle::string::join_strings(recover_shape, ",");
  ret.Resize(make_ddim(recover_shape));
102 103 104
  // undiagonalize by einsum equation. only contain undiagonal operations.
  DenseTensor out;
  VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ;
105
  EinsumInferKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
106
  return out;
107 108 109 110 111
}

template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx,
                      const std::vector<const DenseTensor*>& x,
112
                      const std::vector<const DenseTensor*>& inner_cache,
113 114 115
                      const DenseTensor& out_grad,
                      const std::string& equation,
                      std::vector<DenseTensor*> x_grad) {
116
  VLOG(5) << "Start EinsumGradKernel:";
117 118 119 120 121 122 123 124 125 126 127 128
  LabelMap labelshape(0);
  LabelMap labeltype(LabelType::Reduction);
  std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
  std::vector<char> all_labels;  // order: ABO, AO, BO, AB, Reduce
  std::vector<std::vector<int>> ellipsis_dims(2);
  std::vector<int> broadcast_dims;
  std::vector<int> output_dims;

  std::vector<DDim> input_dims;
  for (auto& i : x) {
    input_dims.push_back(i->dims());
  }
129
  std::vector<std::string> input_strs;
130 131 132 133 134 135 136 137 138 139
  std::string right;
  ParseEinsumEquation(equation,
                      input_dims,
                      &labelshape,
                      &labeltype,
                      &all_labels,
                      &label2perms,
                      &ellipsis_dims,
                      &broadcast_dims,
                      &output_dims,
140 141
                      &right,
                      &input_strs);
142 143 144 145 146

  auto gather_labels_except_reduction = [&labeltype](std::string all) {
    std::string res("");
    for (auto c : all)
      if (labeltype[static_cast<int>(c)] != LabelType::Reduction) res += c;
147 148
    auto tmp_unique = unique_labels(res);
    return std::string(tmp_unique.begin(), tmp_unique.end());
149 150 151 152 153 154 155 156 157
  };
  if (x.size() == 1) {  // Unary
    auto splits = paddle::string::split_string(equation, "->");
    auto left = splits[0];
    right = splits[1].substr(1);
    auto new_equation = right + "->" + gather_labels_except_reduction(left);
    auto new_operands = std::vector<const DenseTensor*>();
    new_operands.push_back(&out_grad);
    DenseTensor before_tile;
158
    VLOG(5) << "new_equation is " << new_equation;
159 160
    EinsumInferKernel<T, Context>(
        dev_ctx, new_operands, new_equation, &before_tile);
161 162 163 164 165 166 167 168 169 170 171 172 173 174
    *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
                                                       labeltype,
                                                       labelshape,
                                                       broadcast_dims,
                                                       ellipsis_dims[0],
                                                       left,
                                                       before_tile);
  } else {
    auto splits = paddle::string::split_string(equation, "->");
    auto left = splits[0];
    auto ops = paddle::string::split_string(left, ",");
    right = splits[1].substr(1);

    auto equation_for_A =
175
        ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]);
176 177 178 179 180
    auto equation_for_B =
        right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
    auto operands_for_A = std::vector<const DenseTensor*>();
    auto operands_for_B = std::vector<const DenseTensor*>();
    DenseTensor dA, dB;
181
    auto out_grad_conj = Conj<T, Context>(dev_ctx, out_grad);
182
    // dA = einsum(B, dC)
183
    operands_for_A.push_back(x[1]);
184
    operands_for_A.push_back(&out_grad_conj);
185
    // dB = einsum(dC, A)
186
    operands_for_B.push_back(&out_grad_conj);
187 188 189
    operands_for_B.push_back(x[0]);

    DenseTensor before_tile;
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215

    std::vector<DenseTensor> cache(3);  // set empty; TA, TB, TdC
    if (inner_cache.size() >
        0) {  // for compatibility,  we can load and run v2.3 EinsumOp.
      cache[0].ShareBufferWith(*(inner_cache[0]));
      cache[1].ShareBufferWith(*(inner_cache[1]));
    }
    EinsumKernelImpl<T, Context>(dev_ctx,
                                 all_labels,
                                 operands_for_A,
                                 equation_for_A,
                                 &dA,
                                 {&cache[1], &cache[2]},
                                 false);

    EinsumKernelImpl<T, Context>(dev_ctx,
                                 all_labels,
                                 operands_for_B,
                                 equation_for_B,
                                 &dB,
                                 {&cache[2], &cache[0]},
                                 false);

    // release the cache tensor dTC to save memory right now. they are useless
    // now.
    cache.clear();
216 217 218 219 220 221 222 223
    if (x_grad[0]) {
      *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
                                                         labeltype,
                                                         labelshape,
                                                         broadcast_dims,
                                                         ellipsis_dims[0],
                                                         ops[0],
                                                         dA);
224
      *(x_grad[0]) = Conj<T, Context>(dev_ctx, *x_grad[0]);
225 226 227 228 229 230 231 232 233
    }
    if (x_grad[1]) {
      *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
                                                         labeltype,
                                                         labelshape,
                                                         broadcast_dims,
                                                         ellipsis_dims[1],
                                                         ops[1],
                                                         dB);
234
      *(x_grad[1]) = Conj<T, Context>(dev_ctx, *x_grad[1]);
235
    }
236 237 238
  }
}
}  // namespace phi