// fill_diagonal_op.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_diagonal_op.h"

namespace paddle {
namespace operators {

// Returns the sum of the per-axis element strides of a tensor with shape
// `dim`, computed as if the last axis varies fastest:
//   1 + dim[rank-1] + dim[rank-1] * dim[rank-2] + ...
// The kernels in this file use the result as the flat-index step from one
// diagonal element to the next. The result is 0 when the rank is 0.
int64_t CalStride(framework::DDim dim) {
  int64_t diagonal_step = 0;
  int64_t axis_stride = 1;  // element stride of the axis visited next
  int axis = dim.size();
  // Visit the axes from the innermost (last) to the outermost (first).
  while (axis-- > 0) {
    diagonal_step += axis_stride;
    axis_stride *= dim[axis];
  }
  return diagonal_step;
}

// Describes the fill_diagonal operator's proto: one input ("X"), one output
// ("Out") and the attributes "value", "wrap" and "offset".
class FillIDiagonalOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Tensors.
    AddInput("X", "(Tensor) The input tensor.");
    AddOutput("Out",
              "Tensor, the output tensor, with the same shape and data type "
              "as input(x)");

    // Attributes, each with its default value.
    AddAttr<float>(
        "value",
        "The float values of tensor, whose dim is one, and no need of grad")
        .SetDefault(0.0f);
    AddAttr<bool>("wrap",
                  "the diagonal 'wrapped' after N columns for tall matrices")
        .SetDefault(false);
    AddAttr<int>("offset",
                 "offset of diagonal, zero means no offset, positive means "
                 "offset to up-right corner; negtive means offset to "
                 "bottom-left corner")
        .SetDefault(0);

    // Operator-level description.
    AddComment(R"DOC(Fill replace operator
                Fill the diagonal of an tensor with 'value'.
                )DOC");
  }
};

// Forward operator definition for fill_diagonal.
class FillIDiagonalOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires input "X" and output "Out"; Out takes exactly the shape of X.
  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillIDiagonal");
    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillIDiagonal");
    context->SetOutputDim("Out", context->GetInputDim("X"));
  }

 protected:
  // The kernel is selected from the data type of X and the current place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto input_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(input_dtype, ctx.GetPlace());
  }
};

// Propagates the variable type and data type of input "X" (index 0) to every
// element of output "Out".
class FillIDiagonalOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    // Read both properties of X first, then apply them to Out.
    const auto in_data_type = ctx->GetInputDataType("X", 0);
    const auto in_var_type = ctx->GetInputType("X", 0);
    ctx->SetOutputType("Out", in_var_type, framework::ALL_ELEMENTS);
    ctx->SetOutputDataType("Out", in_data_type, framework::ALL_ELEMENTS);
  }
};

// CPU kernel for fill_diagonal: copies X into Out, then overwrites the
// elements on the (optionally offset) diagonal with the "value" attribute
// converted to the element type T.
template <typename T>
class FillIDiagonalKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    const auto fill_val = ctx.template Attr<float>("value");
    const auto offset = ctx.Attr<int>("offset");
    const auto wrap = ctx.Attr<bool>("wrap");

    auto *xin = ctx.Input<framework::Tensor>("X");
    auto *out = ctx.Output<framework::Tensor>("Out");

    // The attribute is stored as float; convert it once to the element type.
    const T temp_var = static_cast<T>(fill_val);

    // Allocate Out on the current place and start from a copy of X.
    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    framework::TensorCopy(*xin, ctx.GetPlace(), out);

    auto out_dims = out->dims();
    // NOTE(review): out_dims[1] is read unconditionally, so this code assumes
    // the tensor has rank >= 2 -- presumably validated by the caller; confirm.
    const auto cols = out_dims[1];
    // Flat-index distance between two consecutive diagonal elements.
    const auto strides = CalStride(out_dims);

    // Without wrap, only the leading cols x cols region is visited, so the
    // diagonal of a tall matrix stops after `cols` rows. With wrap, the whole
    // buffer is walked and the diagonal continues in cycles.
    // NOTE(review): for rank > 2 with all extents equal to n, the non-wrap
    // limit is n * n, which is smaller than the diagonal step, so only flat
    // index 0 is visited -- confirm that this is the intended behaviour.
    auto size = out->numel();
    if (!wrap) {
      size = std::min(size, cols * cols);
    }

    for (int64_t i = 0; i < size; i += strides) {
      // Apply the offset only when the shifted position stays inside the same
      // innermost row; an offset must never spill into a neighbouring row.
      const auto shifted_col = i % cols + offset;
      if (shifted_col >= 0 && shifted_col < cols) {
        out_data[i + offset] = temp_var;
      }
    }
  }
};

// Backward operator definition for fill_diagonal.
class FillIDiagonalGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires the gradient of Out; when the gradient of X is requested as an
  // output it receives the same shape as the gradient of Out.
  void InferShape(framework::InferShapeContext *ctx) const override {
    // Report this operator's own name on failure (the message previously
    // named an unrelated operator, "mul").
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   "Out@GRAD",
                   "FillIDiagonalGrad");
    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, out_grad_dims);
    }
  }

  // The kernel data type follows the gradient of Out, which is the only input
  // this operator checks for.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto dtype = framework::TransToProtoVarType(
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type());
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};

// Builds the description of the "fill_diagonal_grad" op from the forward op:
// the gradient of Out is its input, the gradient of X is its output, and the
// forward attributes are passed through unchanged.
template <typename T>
class FillIDiagonalGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    retv->SetType("fill_diagonal_grad");
    retv->SetAttrMap(this->Attrs());
    retv->SetInput(out_grad_name, this->OutputGrad("Out"));
    retv->SetOutput(x_grad_name, this->InputGrad("X"));
  }
};

// CPU kernel for fill_diagonal_grad: the gradient of X is a copy of the
// gradient of Out with zeros written at the positions the forward kernel
// overwrites, since those outputs do not depend on X.
template <typename T>
class FillIDiagonalGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    const auto offset = ctx.Attr<int>("offset");
    const auto wrap = ctx.Attr<bool>("wrap");

    // Nothing to compute when the gradient of X is not requested.
    if (!dx) {
      return;
    }

    // Allocate dX on the current place and start from a copy of dOut.
    auto *data = dx->mutable_data<T>(ctx.GetPlace());
    framework::TensorCopy(*dout, ctx.GetPlace(), dx);

    auto dx_dims = dx->dims();
    // NOTE(review): dx_dims[1] is read unconditionally, so this code assumes
    // the tensor has rank >= 2 -- presumably validated by the caller; confirm.
    const auto cols = dx_dims[1];
    // Flat-index distance between two consecutive diagonal elements.
    const auto strides = CalStride(dx_dims);
    const auto size = dx->numel();

    // With wrap the whole buffer is walked; otherwise only the leading
    // cols x cols region is visited, matching the forward kernel.
    // NOTE(review): for rank > 2 with all extents equal to n, the non-wrap
    // limit is n * n, which is smaller than the diagonal step, so only flat
    // index 0 is visited -- confirm that this is the intended behaviour.
    const auto wrapsize = wrap ? size : std::min(size, cols * cols);

    for (int64_t i = 0; i < wrapsize; i += strides) {
      // Zero the shifted position only when it stays inside the same row.
      const auto shifted_col = i % cols + offset;
      if (shifted_col >= 0 && shifted_col < cols) {
        data[i + offset] = T(0);
      }
    }
  }
};

// In-place inferers: the forward one pairs input X with output Out, the
// backward one pairs the gradient of Out with the gradient of X.
DECLARE_INPLACE_OP_INFERER(FillIDiagonalOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(
    FillIDiagonalGradOpInplaceInferer,
    {framework::GradVarName("Out"), framework::GradVarName("X")});

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

207 208
REGISTER_OPERATOR(fill_diagonal,
                  ops::FillIDiagonalOp,
209 210 211 212 213 214
                  ops::FillIDiagonalOpMaker,
                  ops::FillIDiagonalOpVarTypeInference,
                  ops::FillIDiagonalGradOpMaker<paddle::framework::OpDesc>,
                  ops::FillIDiagonalGradOpMaker<paddle::imperative::OpBase>,
                  ops::FillIDiagonalOpInplaceInferer);

215 216
REGISTER_OPERATOR(fill_diagonal_grad,
                  ops::FillIDiagonalGradOp,
217 218
                  ops::FillIDiagonalGradOpInplaceInferer);

219 220
REGISTER_OP_CPU_KERNEL(fill_diagonal,
                       ops::FillIDiagonalKernel<float>,
221 222 223 224 225 226
                       ops::FillIDiagonalKernel<double>,
                       ops::FillIDiagonalKernel<int64_t>,
                       ops::FillIDiagonalKernel<int>,
                       ops::FillIDiagonalKernel<paddle::platform::float16>,
                       ops::FillIDiagonalKernel<bool>);

227 228
REGISTER_OP_CPU_KERNEL(fill_diagonal_grad,
                       ops::FillIDiagonalGradKernel<float>,
229 230 231 232 233
                       ops::FillIDiagonalGradKernel<double>,
                       ops::FillIDiagonalGradKernel<int64_t>,
                       ops::FillIDiagonalGradKernel<int>,
                       ops::FillIDiagonalGradKernel<paddle::platform::float16>,
                       ops::FillIDiagonalGradKernel<bool>);