// fill_diagonal_tensor_op.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_diagonal_tensor_op.h"

namespace paddle {
namespace operators {

// Computes the addressing information needed to walk the diagonal selected by
// (dim1, dim2, offset) in every matrix of a contiguous tensor of shape
// `out_dims`. Every dimension other than dim1/dim2 is a "free" (batch)
// dimension; each combination of free indices identifies one matrix.
//
// Parameters:
//   out_dims  - shape of the tensor whose diagonals are addressed.
//   dim1/dim2 - the two dimensions spanning each matrix.
//               NOTE(review): assumed non-negative, in range and distinct;
//               not checked here — confirm callers normalize them.
//   offset    - in:  diagonal offset; >= 0 shifts along dim2, < 0 along dim1.
//               out: the same offset converted to a flat element offset.
//   new_dims  - out: new_dims[0] = number of matrices (product of the free
//               dimensions), new_dims[1] = length of the selected diagonal.
//   strides   - out: strides[0] = element stride of dim2,
//               strides[1] = element stride of dim1 (row-major layout).
//   matoffset - out: flat start offset of every matrix. Room for new_dims[0]
//               entries is required. When there are no free dimensions the
//               table is not written, so the single entry must be
//               pre-initialized to 0 by the caller.
void CalMatDims(framework::DDim out_dims, int dim1, int dim2, int64_t *offset,
                int64_t *new_dims, int64_t *strides, int64_t *matoffset) {
  const int rank = out_dims.size();

  // Number of matrices: product of all free dimensions.
  int64_t batchdim = 1;
  for (int i = 0; i < rank; i++) {
    if (i != dim1 && i != dim2) {
      batchdim *= out_dims[i];
    }
  }
  // With an empty free dimension there are no matrices at all, and the
  // caller's table may have no room; building it anyway would write past
  // the end of the buffer.
  const bool build_table = batchdim > 0;

  int64_t dimprod = 1;    // element stride of dimension i (row-major)
  int64_t matoffidx = 0;  // number of table entries written so far
  for (int i = rank - 1; i >= 0; i--) {
    if (i == dim2) {
      strides[0] = dimprod;
    } else if (i == dim1) {
      strides[1] = dimprod;
    } else if (build_table) {
      if (matoffidx == 0) {
        // Innermost free dimension: seed the table with its positions.
        for (int64_t j = 0; j < out_dims[i]; j++) {
          matoffset[matoffidx++] = dimprod * j;
        }
      } else {
        // Each outer free dimension replicates the current table once per
        // additional index, shifted by that dimension's stride.
        const int64_t size = matoffidx;
        for (int64_t j = 1; j < out_dims[i]; j++) {
          for (int64_t k = 0; k < size; k++) {
            matoffset[matoffidx++] = matoffset[k] + dimprod * j;
          }
        }
      }
    }
    dimprod *= out_dims[i];
  }

  // Diagonal length, kept in int64_t (dimension sizes are 64-bit).
  int64_t diagdim = 0;
  if (*offset >= 0) {
    diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset);
    *offset *= strides[0];
  } else {
    diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]);
    *offset *= -strides[1];
  }
  new_dims[0] = batchdim;
  new_dims[1] = diagdim;
}

// Declares the inputs, output, attributes and user-facing documentation of
// the fill_diagonal_tensor operator.
class FillDiagonalTensorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddComment(R"DOC(Fill replace operator
                Fill the diagonal of a tensor with `Y` Tensor.
                )DOC");
    AddInput("X", "(Tensor) The input tensor.");
    AddInput("Y", "(Tensor) The input tensor to fill in.");
    AddOutput("Out",
              "Tensor, the output tensor, with the same shape and data type "
              "as input(x)");
    AddAttr<int>("dim1", "the first dim to figure out the diagonal")
        .SetDefault(0);
    AddAttr<int>("dim2", "the second dim to figure out the diagonal")
        .SetDefault(1);
    AddAttr<int64_t>("offset",
                     "offset of diagonal, zero means no offset, positive means "
                     "offset to up-right corner; negative means offset to "
                     "bottom-left corner")
        .SetDefault(0);
  }
};

class FillDiagonalTensorOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillDiagonalTensor");
    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out",
                   "FillDiagonalTensor");
    auto x_dims = context->GetInputDim("X");
    context->SetOutputDim("Out", x_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

// Out mirrors X: it takes X's variable type and X's data type.
class FillDiagonalTensorOpVarTypeInference
    : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    const auto x_var_type = ctx->GetInputType("X", 0);
    const auto x_data_type = ctx->GetInputDataType("X", 0);
    ctx->SetOutputType("Out", x_var_type, framework::ALL_ELEMENTS);
    ctx->SetOutputDataType("Out", x_data_type, framework::ALL_ELEMENTS);
  }
};

// CPU kernel of fill_diagonal_tensor: Out = X, then the (dim1, dim2, offset)
// diagonal of every matrix is overwritten with the matching values of Y.
// Y, flattened to 2-D (all leading dims as rows, last dim as columns), must be
// (number of matrices) x (diagonal length).
template <typename T>
class FillDiagonalTensorKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    auto *out = ctx.Output<framework::Tensor>("Out");
    auto *srctensor = ctx.Input<framework::Tensor>("Y");
    auto dim1 = ctx.Attr<int>("dim1");
    auto dim2 = ctx.Attr<int>("dim2");
    auto offset = ctx.Attr<int64_t>("offset");
    auto *xin = ctx.Input<framework::Tensor>("X");

    // Copy X into Out first and only then fetch the output pointer:
    // TensorCopy may (re)allocate Out's buffer.
    framework::TensorCopy(*xin, ctx.GetPlace(), out);
    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    const T *fill_data = srctensor->data<T>();

    auto out_dims = out->dims();
    auto matdims = srctensor->dims();
    auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1);

    // Size the per-matrix offset table from Out's own shape. CalMatDims
    // writes one entry per matrix of Out *before* Y's shape is validated
    // below. Sizing it from Y (as before) let a mis-shaped Y under-size the
    // buffer and cause a heap overflow.
    int64_t matrows = 1;
    for (int i = 0; i < out_dims.size(); i++) {
      if (i != dim1 && i != dim2) {
        matrows *= out_dims[i];
      }
    }

    int64_t new_dims[2], strides[2];
    std::vector<int64_t> matdim(matrows);  // zero-initialized
    CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());
    PADDLE_ENFORCE_EQ(
        new_dims[0], fill_dims[0],
        platform::errors::InvalidArgument("The dims should be %d x %d, but get "
                                          "%d x %d in fill tensor Y",
                                          new_dims[0], new_dims[1],
                                          fill_dims[0], fill_dims[1]));
    PADDLE_ENFORCE_EQ(
        new_dims[1], fill_dims[1],
        platform::errors::InvalidArgument("The dims should be %d x %d, but get "
                                          "%d x %d in fill tensor Y",
                                          new_dims[0], new_dims[1],
                                          fill_dims[0], fill_dims[1]));

    // Moving one step along the diagonal advances both dim1 and dim2.
    const auto size = out->numel();
    const int64_t diag_step = strides[0] + strides[1];
    for (int64_t i = 0; i < fill_dims[0]; i++) {
      const int64_t sumoff = matdim[i] + offset;
      for (int64_t j = 0; j < fill_dims[1]; j++) {
        const int64_t fill_index = j * diag_step + sumoff;
        if (fill_index < size) {
          out_data[fill_index] = fill_data[i * fill_dims[1] + j];
        }
      }
    }
  }
};

// Backward operator: X@GRAD has the same shape and data type as Out@GRAD.
class FillDiagonalTensorGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires Out@GRAD and, when X@GRAD is requested, gives it Out@GRAD's
  // shape.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "FillDiagonalTensorGrad");
    auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

  // The kernel data type is taken from Out@GRAD, the only input this grad op
  // receives.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto dtype =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type();
    return framework::OpKernelType(framework::TransToProtoVarType(dtype),
                                   ctx.GetPlace());
  }
};

// Describes fill_diagonal_tensor_grad: it consumes Out@GRAD, produces X@GRAD
// and carries over the forward op's attributes.
template <typename T>
class FillDiagonalTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");
    grad_op->SetType("fill_diagonal_tensor_grad");
    grad_op->SetInput(dout_name, this->OutputGrad("Out"));
    grad_op->SetOutput(dx_name, this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

// CPU kernel of fill_diagonal_tensor_grad: X@GRAD = Out@GRAD, with the
// (dim1, dim2, offset) diagonal of every matrix set to zero. The forward
// kernel overwrote those positions of Out with Y, so they carry no gradient
// back to X.
template <typename T>
class FillDiagonalTensorGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    auto dim1 = ctx.Attr<int>("dim1");
    auto dim2 = ctx.Attr<int>("dim2");
    auto offset = ctx.Attr<int64_t>("offset");

    // Nothing to compute when X@GRAD is not requested.
    if (dx == nullptr) {
      return;
    }

    // Start from a copy of Out@GRAD; fetch the data pointer only afterwards,
    // since TensorCopy may (re)allocate dx.
    framework::TensorCopy(*dout, ctx.GetPlace(), dx);
    auto *data = dx->mutable_data<T>(ctx.GetPlace());
    auto dx_dims = dx->dims();

    // Number of matrices = product of every dimension except dim1/dim2.
    // Accumulated in int64_t: the previous `auto matrows = 1` was an int and
    // could overflow on large tensors.
    int64_t matrows = 1;
    for (int i = 0; i < dx_dims.size(); i++) {
      if (i != dim1 && i != dim2) {
        matrows *= dx_dims[i];
      }
    }

    int64_t new_dims[2], strides[2];
    std::vector<int64_t> matdim(matrows);  // zero-initialized
    CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim.data());

    // Moving one step along the diagonal advances both dim1 and dim2.
    const auto size = dx->numel();
    const int64_t diag_step = strides[0] + strides[1];
    for (int64_t i = 0; i < new_dims[0]; i++) {
      const int64_t sumoff = matdim[i] + offset;
      for (int64_t j = 0; j < new_dims[1]; j++) {
        const int64_t fill_index = j * diag_step + sumoff;
        if (fill_index < size) {
          data[fill_index] = 0;
        }
      }
    }
  }
};

// In-place pairing for the forward op: input X with output Out.
// NOTE(review): presumably this lets the framework let Out reuse X's buffer;
// whether reuse actually happens is decided outside this file.
DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorOpInplaceInferer, {"X", "Out"});
// In-place pairing for the backward op: input Out@GRAD with output X@GRAD.
DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

// Forward op registration. It wires together:
//   - the operator class and its proto/doc maker,
//   - the var-type inference (Out takes X's variable and data type),
//   - the grad-op maker, instantiated for framework::OpDesc and for
//     imperative::OpBase,
//   - the X -> Out in-place inferer.
REGISTER_OPERATOR(
    fill_diagonal_tensor, ops::FillDiagonalTensorOp,
    ops::FillDiagonalTensorOpMaker, ops::FillDiagonalTensorOpVarTypeInference,
    ops::FillDiagonalTensorGradOpMaker<paddle::framework::OpDesc>,
    ops::FillDiagonalTensorGradOpMaker<paddle::imperative::OpBase>,
    ops::FillDiagonalTensorOpInplaceInferer);

// Backward op registration. It has no grad-op maker of its own and uses the
// Out@GRAD -> X@GRAD in-place inferer.
REGISTER_OPERATOR(fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradOp,
                  ops::FillDiagonalTensorGradOpInplaceInferer);

// CPU forward kernels, one instantiation per supported element type.
REGISTER_OP_CPU_KERNEL(
    fill_diagonal_tensor, ops::FillDiagonalTensorKernel<float>,
    ops::FillDiagonalTensorKernel<double>,
    ops::FillDiagonalTensorKernel<int64_t>, ops::FillDiagonalTensorKernel<int>,
    ops::FillDiagonalTensorKernel<int8_t>,
    ops::FillDiagonalTensorKernel<uint8_t>,
    ops::FillDiagonalTensorKernel<paddle::platform::float16>,
    ops::FillDiagonalTensorKernel<paddle::platform::complex<float>>,
    ops::FillDiagonalTensorKernel<paddle::platform::complex<double>>,
    ops::FillDiagonalTensorKernel<bool>);

// CPU backward kernels; the same element types as the forward list above.
REGISTER_OP_CPU_KERNEL(
    fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradKernel<float>,
    ops::FillDiagonalTensorGradKernel<double>,
    ops::FillDiagonalTensorGradKernel<int64_t>,
    ops::FillDiagonalTensorGradKernel<int>,
    ops::FillDiagonalTensorGradKernel<int8_t>,
    ops::FillDiagonalTensorGradKernel<uint8_t>,
    ops::FillDiagonalTensorGradKernel<paddle::platform::float16>,
    ops::FillDiagonalTensorGradKernel<paddle::platform::complex<float>>,
    ops::FillDiagonalTensorGradKernel<paddle::platform::complex<double>>,
    ops::FillDiagonalTensorGradKernel<bool>);