/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_
#define PADDLE_FLUID_OPERATORS_BMM_OP_H_

#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {

29
using Tensor = phi::DenseTensor;
30 31

static void ReshapeTensorIntoMatrixSequence(
32
    phi::DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) {
33 34 35 36 37 38 39 40 41 42
  int64_t h, w;
  h = descriptor.height_;
  w = descriptor.width_;
  if (descriptor.trans_) {
    std::swap(w, h);
  }

  x->Resize({descriptor.batch_size_, h, w});
}

43 44 45
static void ReshapeXYOutIntoMatrixSequence(phi::DenseTensor *x,
                                           phi::DenseTensor *y,
                                           phi::DenseTensor *out,
46
                                           bool trans_x,
47 48 49
                                           bool trans_y) {
  auto x_dim = x->dims();
  auto y_dim = y->dims();
50 51
  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, false);
  auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, false);
52 53

  out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
54 55
               mat_dim_x.height_,
               mat_dim_y.width_});
56 57 58 59 60 61 62 63

  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}

}  // namespace operators
}  // namespace paddle
#endif  // PADDLE_FLUID_OPERATORS_BMM_OP_H_