/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

void FusionSquaredMatSubOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Y"),
                 "Input(Y) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredX"),
      "Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredY"),
      "Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(
      ctx->HasOutput("SquaredXY"),
      "Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSquaredMatSubOp should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  auto y_dims = ctx->GetInputDim("Y");
  PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                    "Input tensors should have the same number of dimensions.");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be 2-D matrices.");
  PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0],
                    "The second dimension of Input(X) should be equal to the "
                    "first dimension of Input(Y).");

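  // X is [m, k] and Y is [k, n]; SquaredXY and Out are both [m, n].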
  ctx->SetOutputDim("SquaredX", x_dims);
  ctx->SetOutputDim("SquaredY", y_dims);
  ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
  ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]});
}

framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
                                 ctx.GetPlace());
}

void FusionSquaredMatSubOpMaker::Make() {
  AddInput("X", "(Tensor) Input Mat A of this operator.");
  AddInput("Y", "(Tensor) Input Mat B of this operator.");
  AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate();
  AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate();
  AddOutput("SquaredXY", "(Tensor) Squared X*Y.").AsIntermediate();
  AddOutput("Out", "(Tensor) Output tensor of concat operator.");
  AddAttr<float>("scalar", "The scalar on output matrix.").SetDefault(1.f);
  AddComment(R"DOC(
    Fusion Squared Matrix and Subtract operator.

    ( (X * Y).^2 - (X.^2 * Y.^2) ) .* scalar
)DOC");
}

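// CPU kernel computing ((X * Y).^2 - (X.^2 * Y.^2)) .* scalar using JIT
// kernels for matmul, element-wise square, subtract, and scale.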
template <typename T>
class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto x = ctx.Input<Tensor>("X");
    auto y = ctx.Input<Tensor>("Y");
    auto* squared_x = ctx.Output<Tensor>("SquaredX");
    auto* squared_y = ctx.Output<Tensor>("SquaredY");
    auto* squared_xy = ctx.Output<Tensor>("SquaredXY");
    auto* out = ctx.Output<Tensor>("Out");
    auto place = ctx.GetPlace();
    T scalar = static_cast<T>(ctx.Attr<float>("scalar"));

    auto x_dims = x->dims();
    auto y_dims = y->dims();
    jit::matmul_attr_t attr;
    attr.m = x_dims[0];
    attr.k = x_dims[1];
    attr.n = y_dims[1];
    int o_numel = attr.m * attr.n;

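    // Fetch JIT kernels specialized for these sizes: element-wise square
    // (kVSquare), element-wise subtract (kVSub), scaling (kVScal) and matrix
    // multiplication (kMatMul).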
    auto vsquare_x =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
                                                                       attr.k);
    auto vsquare_y =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
                                                                       attr.n);
    auto vsquare_xy =
        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
    auto vsub =
        jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
    auto vscal =
        jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
    auto matmul =
        jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);

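    // Raw input pointers; output buffers are allocated on the execution place.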
    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* squared_x_data = squared_x->mutable_data<T>(place);
    T* squared_y_data = squared_y->mutable_data<T>(place);
    T* squared_xy_data = squared_xy->mutable_data<T>(place);
    T* o_data = out->mutable_data<T>(place);

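    // SquaredXY = (X * Y).^2: matmul into SquaredXY, then square it in place.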
    matmul(x_data, y_data, squared_xy_data, &attr);
    vsquare_xy(squared_xy_data, squared_xy_data, o_numel);

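    // Out = X.^2 * Y.^2: square each input element-wise, then matmul.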
    vsquare_x(x_data, squared_x_data, attr.m * attr.k);
    vsquare_y(y_data, squared_y_data, attr.k * attr.n);
    matmul(squared_x_data, squared_y_data, o_data, &attr);

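    // Out = ((X * Y).^2 - X.^2 * Y.^2) .* scalar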
    vsub(squared_xy_data, o_data, o_data, o_numel);
    vscal(&scalar, o_data, o_data, o_numel);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp,
                  ops::FusionSquaredMatSubOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub,
                       ops::FusionSquaredMatSubKernel<float>,
                       ops::FusionSquaredMatSubKernel<double>);