/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/match_matrix_tensor_op.h"

#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>

#include "paddle/fluid/operators/search_compute.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "match_matrix_tensor");
  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "match_matrix_tensor");
  OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "match_matrix_tensor");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "match_matrix_tensor");
  OP_INOUT_CHECK(ctx->HasOutput("Tmp"), "Output", "Tmp", "match_matrix_tensor");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The dimensions of Input(X) should be equal to 2, "
                        "but received %d.",
                        x_dims.size()));

  auto y_dims = ctx->GetInputDim("Y");
  PADDLE_ENFORCE_EQ(y_dims.size(),
                    2,
                    platform::errors::InvalidArgument(
                        "The dimensions of Input(Y) should be equal to 2, "
                        "but received %d.",
                        y_dims.size()));

  auto w_dims = ctx->GetInputDim("W");
  PADDLE_ENFORCE_EQ(w_dims.size(),
                    3,
                    platform::errors::InvalidArgument(
                        "The dimensions of Input(W) should be equal to 3, "
                        "but received %d.",
                        w_dims.size()));

  int dim_t = ctx->Attrs().Get<int>("dim_t");
  PADDLE_ENFORCE_EQ(
      w_dims[0],
      x_dims[1],
      platform::errors::InvalidArgument(
          "The first dimension of Input(W) should be equal to the second "
          "dimension of Input(X). But received the first dimension of Input(W) "
          "is %d, the second dimension of Input(X) is %d.",
          w_dims[0],
          x_dims[1]));
  PADDLE_ENFORCE_EQ(
      w_dims[1],
      dim_t,
      platform::errors::InvalidArgument(
          "The second dimension of Input(W) should be equal to 'dim_t', but "
          "received the second dimension of Input(W) is %d, 'dim_t' is %d.",
          w_dims[1],
          dim_t));
  PADDLE_ENFORCE_EQ(
      w_dims[2],
      y_dims[1],
      platform::errors::InvalidArgument(
          "The last dimension of Input(W) should be equal to "
          "the second dimension of Input(Y). But received the last dimension "
          "of Input(W) is %d, the second dimension of Input(Y) is %d.",
          w_dims[2],
          y_dims[1]));

  int64_t out_dim_0 = -1;
  int64_t tmp_dim_0 = -1;
  if (ctx->IsRuntime()) {
    framework::Variable* x_var =
        PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
    const auto& x_lod = x_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE_EQ(x_lod.empty(),
                      false,
                      platform::errors::InvalidArgument(
                          "The Input(X) should hold LoD information, but "
                          "received Input(X).lod() is empty."));
    const auto& x_lod_0 = x_lod[0];
    PADDLE_ENFORCE_GE(x_lod_0.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "The size of Input(X)'s LoD data should be at "
                          "least 2, but received %d.",
                          x_lod_0.size()));
    PADDLE_ENFORCE_EQ(x_dims[0],
                      static_cast<int64_t>(x_lod_0.back()),
                      platform::errors::InvalidArgument(
                          "The last element of Input(X)'s LoD data should be "
                          "equal to the first dimension of Input(X). "
                          "But received the last element of Input(X)'s LoD "
                          "data is %d, the first dimension of Input(X) is %d.",
                          x_lod_0.back(),
                          x_dims[0]));

    framework::Variable* y_var =
        PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("Y")[0]);
    const auto& y_lod = y_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE_EQ(y_lod.empty(),
                      false,
                      platform::errors::InvalidArgument(
                          "The Input(Y) should hold LoD information, but "
                          "received Input(Y).lod() is empty."));
    const auto& y_lod_0 = y_lod[0];
    PADDLE_ENFORCE_GE(y_lod_0.size(),
                      2,
                      platform::errors::InvalidArgument(
                          "The size of Input(Y)'s LoD data should be at "
                          "least 2, but received %d.",
                          y_lod_0.size()));
    PADDLE_ENFORCE_EQ(y_dims[0],
                      static_cast<int64_t>(y_lod_0.back()),
                      platform::errors::InvalidArgument(
                          "The last element of Input(Y)'s LoD data should be "
                          "equal to the first dimension of Input(Y). "
                          "But received the last element of Input(Y)'s LoD "
                          "data is %d, the first dimension of Input(Y) is %d.",
                          y_lod_0.back(),
                          y_dims[0]));

    PADDLE_ENFORCE_EQ(x_lod_0.size(),
                      y_lod_0.size(),
                      platform::errors::InvalidArgument(
                          "The dimensions of Input(X)'s and Input(Y)'s LoD "
                          "data should be equal. "
                          "But received the dimensions of Input(X)'s LoD is "
                          "%d, the dimensions of Input(Y)'s LoD is %d.",
                          x_lod_0.size(),
                          y_lod_0.size()));

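    // At runtime the output length is known exactly: each sequence pair
    // contributes dim_t * len_x * len_y match scores, so Out has
    // dim_t * sum(len_x * len_y) rows, while Tmp caches X * W with
    // x_dims[0] * dim_t * x_dims[1] elements.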
    out_dim_0 = 0;
    for (size_t i = 1; i < x_lod_0.size(); i++) {
      int64_t x_len = x_lod_0[i] - x_lod_0[i - 1];
      int64_t y_len = y_lod_0[i] - y_lod_0[i - 1];
      out_dim_0 += (x_len * y_len);
    }
    out_dim_0 *= dim_t;

    tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
  } else {
    // Compile time: the sequence lengths are unknown, so the first
    // dimensions of Out and Tmp stay -1 and only the LoD is shared from X.
    framework::VarDesc* x_desc =
        PADDLE_GET(framework::VarDesc*, ctx->GetInputVarPtrs("X")[0]);
    PADDLE_ENFORCE_GE(
        x_desc->GetLoDLevel(),
        1,
        platform::errors::InvalidArgument("The LoD level of Input(X) should be "
                                          "greater than or equal to 1, but "
                                          "received %d.",
                                          x_desc->GetLoDLevel()));
    framework::VarDesc* y_desc =
        PADDLE_GET(framework::VarDesc*, ctx->GetInputVarPtrs("Y")[0]);
    PADDLE_ENFORCE_GE(
        y_desc->GetLoDLevel(),
        1,
        platform::errors::InvalidArgument("The LoD level of Input(Y) should be "
                                          "greater than or equal to 1, but "
                                          "received %d.",
                                          y_desc->GetLoDLevel()));
    ctx->ShareLoD("X", "Out");
  }

  std::vector<int64_t> out_dims_vec{out_dim_0};
  out_dims_vec.push_back(1);
  std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
  tmp_dims_vec.push_back(1);
  ctx->SetOutputDim("Out", phi::make_ddim(out_dims_vec));
  ctx->SetOutputDim("Tmp", phi::make_ddim(tmp_dims_vec));
}

void MatchMatrixTensorOpGrad::InferShape(
    framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "match_matrix_tensor_grad");
  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "match_matrix_tensor_grad");
  OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "match_matrix_tensor_grad");
  OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                 "Input",
                 "Out@GRAD",
                 "match_matrix_tensor_grad");

  if (ctx->HasOutput(framework::GradVarName("X"))) {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
  }
  if (ctx->HasOutput(framework::GradVarName("Y"))) {
    ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
    ctx->ShareLoD("Y", /*->*/ framework::GradVarName("Y"));
  }
  if (ctx->HasOutput(framework::GradVarName("W"))) {
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
  }
}

void MatchMatrixTensorOpMaker::Make() {
  AddInput("X",
           "X (LoDTensor, default LoDTensor<float>) Input variable which "
           "should contain lod information.");
  AddInput("Y",
           "Y (LoDTensor, default LoDTensor<float>) Input variable which "
           "should contain lod information.");
  AddInput("W", "W (Tensor), The weight of X and Y.");
  AddAttr<int>("dim_t", "The second dimension of W.").SetDefault(1);
  AddOutput("Out",
            "(LoDTensor, default LoDTensor<float>) Output variable which "
            "is X * W * Y");
  AddOutput("Tmp",
            "(LoDTensor, default LoDTensor<float>) tmp variable which is "
            "used for X * W");
  AddComment(R"DOC(
      Match Matrix Tensor Operator

      This operator calculates X * W * Y, and only supports 2-D inputs for
      X and Y. For each pair of input sequences (x_i, y_i) it produces dim_t
      score matrices of shape [len(x_i), len(y_i)], which are flattened into
      the output.
      The output is a level-1 LoDTensor:
        level_0: dim_t

      NOTE: only the 'float32' data type is supported now.

    )DOC");
}

template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* y = ctx.Input<LoDTensor>("Y");
    auto* w = ctx.Input<Tensor>("W");
    auto* out = ctx.Output<LoDTensor>("Out");
    auto* tmp = ctx.Output<LoDTensor>("Tmp");

    int dim_t = ctx.Attr<int>("dim_t");
    int64_t dim_in = x->dims()[1];

    const auto& offset_l = x->lod()[0];
    const auto& offset_r = y->lod()[0];

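    // Per-sequence output offsets: sequence b contributes dim_t * len_l *
    // len_r match scores, and the running prefix sum later becomes Out's LoD.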
    std::vector<size_t> top_offset;
    size_t top_size = 0;
    top_offset.push_back(top_size);
    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      size_t len_l = offset_l[b + 1] - offset_l[b];
      size_t len_r = offset_r[b + 1] - offset_r[b];
      top_size += dim_t * len_l * len_r;
      top_offset.push_back(top_size);
    }
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
    memset(out_data, 0.0, out->dims()[0] * out->dims()[1] * sizeof(T));

    auto* bottom_l_data = x->data<T>();
    auto* bottom_r_data = y->data<T>();
    auto* t_data = w->data<T>();
    auto* bottom_l_trans_data = tmp->mutable_data<T>(ctx.GetPlace());
    memset(
        bottom_l_trans_data, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);

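    // Project the left input: Tmp = X * W, with W flattened to a
    // [dim_in, dim_t * dim_in] matrix, so each row of X yields dim_t
    // projected vectors of length dim_in.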
    call_gemm(blas,
              CblasNoTrans,
              CblasNoTrans,
              x->dims()[0],
              dim_t * dim_in,
              dim_in,
              1.0f,
              bottom_l_data,
              t_data,
              0.0f,
              bottom_l_trans_data);

    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      for (int t = 0; t < dim_t; t++) {
        size_t len_l = offset_l[b + 1] - offset_l[b];
        size_t len_r = offset_r[b + 1] - offset_r[b];
        auto* top_data = out_data + top_offset[b] + t * len_l * len_r;
        const auto* l_t_data =
            bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
        const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
        auto blas_2 = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
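        // Score matrix for this (sequence, channel) pair:
        // top[len_l x len_r] = l_t * r^T; the lda of dim_t * dim_in steps
        // over the channel-interleaved rows of Tmp.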
        call_gemm_with_lda(blas_2,
                           CblasNoTrans,
                           CblasTrans,
                           len_l,
                           len_r,
                           dim_in,
                           1.0f,
                           l_t_data,
                           r_data,
                           0.0f,
                           top_data,
                           dim_t * dim_in);
      }
    }

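    // The accumulated per-sequence offsets become the level-0 LoD of Out.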
    framework::LoD out_lod;
    out_lod.push_back(top_offset);

    out->set_lod(out_lod);
  }
};

template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* y = ctx.Input<LoDTensor>("Y");
    auto* w = ctx.Input<Tensor>("W");
    auto* tmp = ctx.Input<LoDTensor>("Tmp");

    int dim_t = ctx.Attr<int>("dim_t");
    int64_t dim_in = x->dims()[1];

    const auto& offset_l = x->lod()[0];
    const auto& offset_r = y->lod()[0];
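    // Rebuild the forward pass's per-sequence output offsets so the output
    // gradient can be indexed by (sequence, channel) pair.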
    std::vector<size_t> top_offset;
    size_t top_size = 0;
    top_offset.push_back(top_size);
    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      size_t len_l = offset_l[b + 1] - offset_l[b];
      size_t len_r = offset_r[b + 1] - offset_r[b];
      top_size += dim_t * len_l * len_r;
      top_offset.push_back(top_size);
    }

    auto* bottom_l_data = x->data<T>();
    auto* bottom_r_data = y->data<T>();
    auto* bottom_l_trans_data = tmp->data<T>();

    auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto* d_y = ctx.Output<LoDTensor>(framework::GradVarName("Y"));

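    // Scratch buffer for the gradient of the cached projection Tmp; it is
    // filled by the loops below and then propagated into d_X and d_W via
    // the two GEMMs at the end.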
    Tensor tmp_grad;
    tmp_grad.Resize(tmp->dims());
    auto* d_tmp_data = tmp_grad.mutable_data<T>(ctx.GetPlace());
    auto* top_diff = d_out->data<T>();
    auto* bottom_l_diff = d_x->mutable_data<T>(ctx.GetPlace());
    auto* bottom_r_diff = d_y->mutable_data<T>(ctx.GetPlace());
    auto* bottom_l_trans_diff = const_cast<T*>(d_tmp_data);
    memset(bottom_l_diff, 0.0, x->dims()[0] * x->dims()[1] * sizeof(T));
    memset(bottom_r_diff, 0.0, y->dims()[0] * y->dims()[1] * sizeof(T));
    memset(
        bottom_l_trans_diff, 0.0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));

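    // Backpropagate through each pairwise dot product: the scalar gradient
    // `diff` scatters into the projected-left row and the right row
    // (l_trans_diff += diff * r_data, r_diff += diff * l_trans_data).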
    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      for (int t = 0; t < dim_t; t++) {
        size_t len_l = offset_l[b + 1] - offset_l[b];
        size_t len_r = offset_r[b + 1] - offset_r[b];

        for (size_t i = 0; i < len_l; i++) {
          for (size_t j = 0; j < len_r; j++) {
            auto diff =
                top_diff[top_offset[b] + t * len_l * len_r + i * len_r + j];
            auto* l_trans_data = bottom_l_trans_data +
                                 (offset_l[b] + i) * dim_in * dim_t +
                                 t * dim_in;
            auto* l_trans_diff = bottom_l_trans_diff +
                                 (offset_l[b] + i) * dim_in * dim_t +
                                 t * dim_in;
            auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
            auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
            if (diff != 0.0) {
              axpy(r_data, l_trans_diff, dim_in, diff);
              axpy(l_trans_data, r_diff, dim_in, diff);
            }
          }
        }
      }
    }

    auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);

    auto* t_data = w->data<T>();
    auto* d_w = ctx.Output<Tensor>(framework::GradVarName("W"));
    auto* t_diff = d_w->mutable_data<T>(ctx.GetPlace());
    memset(t_diff, 0.0, w->dims()[0] * w->dims()[1] * w->dims()[2] * sizeof(T));
    // bottom_diff: d_X += d_Tmp * W^T, with W viewed as a
    // [dim_in, dim_t * dim_in] matrix.
    call_gemm(blas,
              CblasNoTrans,
              CblasTrans,
              x->dims()[0],
              dim_in,
              dim_t * dim_in,
              1.0f,
              bottom_l_trans_diff,
              t_data,
              1.0f,
              bottom_l_diff);

    // t_diff: d_W += X^T * d_Tmp, accumulated into the flattened
    // [dim_in, dim_t * dim_in] view of W's gradient.
    call_gemm(blas,
              CblasTrans,
              CblasNoTrans,
              dim_in,
              dim_t * dim_in,
              x->dims()[0],
              1.0f,
              bottom_l_data,
              bottom_l_trans_diff,
              1.0f,
              t_diff);
  }
};

template <typename T>
class MatchMatrixTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
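    // The backward op consumes the forward inputs plus the cached
    // projection Tmp, so the X * W GEMM is not recomputed in backward.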
    grad_op->SetType("match_matrix_tensor_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput("Y", this->Input("Y"));
    grad_op->SetInput("W", this->Input("W"));
    grad_op->SetInput("Tmp", this->Output("Tmp"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    grad_op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    match_matrix_tensor,
    ops::MatchMatrixTensorOP,
    ops::MatchMatrixTensorOpMaker,
    ops::MatchMatrixTensorGradOpMaker<paddle::framework::OpDesc>,
    ops::MatchMatrixTensorGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad);

REGISTER_OP_CPU_KERNEL(
    match_matrix_tensor,
    ops::CPUMatchMatrixTensorOPKernel<phi::CPUContext, float>);

REGISTER_OP_CPU_KERNEL(
    match_matrix_tensor_grad,
    ops::CPUMatchMatrixTensorOPGradKernel<phi::CPUContext, float>);