diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9063bd327c65336f7bb177a733912b91aa5a8b7 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -0,0 +1,137 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h" +#include +#include +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace operators { + +void FusionSquaredMatSubOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredX"), + "Output(SquaredX) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredY"), + "Output(SquaredY) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("SquaredXY"), + "Output(SquaredXY) of FusionSquaredMatSubOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FusionSquaredMatSubOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), + "Input tensors dims size should be equal."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input tensors should be a Matrix."); + PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply."); + + ctx->SetOutputDim("SquaredX", x_dims); + ctx->SetOutputDim("SquaredY", y_dims); + ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]}); + ctx->SetOutputDim("Out", {x_dims[0], y_dims[1]}); +} + +framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), + ctx.GetPlace()); +} + +void FusionSquaredMatSubOpMaker::Make() { + AddInput("X", "(Tensor) Input Mat A of this operator."); + AddInput("Y", "(Tensor) Input Mat B of this operator."); + AddOutput("SquaredX", "(Tensor) Squared X.").AsIntermediate(); + AddOutput("SquaredY", "(Tensor) Squared Y.").AsIntermediate(); + AddOutput("SquaredXY", "(Tensor) Squared X*Y.").AsIntermediate(); + AddOutput("Out", "(Tensor) Output tensor of concat operator."); + AddAttr("scalar", "The scalar on output matrix.").SetDefault(1.f); + AddComment(R"DOC( + Fusion Squared Matrix and substrct operator. + + ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar +)DOC"); +} + +template +class FusionSquaredMatSubKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto x = ctx.Input("X"); + auto y = ctx.Input("Y"); + auto* squared_x = ctx.Output("SquaredX"); + auto* squared_y = ctx.Output("SquaredY"); + auto* squared_xy = ctx.Output("SquaredXY"); + auto* out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + T scalar = static_cast(ctx.Attr("scalar")); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + int m = x_dims[0]; + int k = x_dims[1]; + int n = y_dims[1]; + int o_numel = m * n; + + auto vsquare_x = + jit::Get, platform::CPUPlace>(m * k); + auto vsquare_y = + jit::Get, platform::CPUPlace>(k * n); + auto vsquare_xy = + jit::Get, platform::CPUPlace>(o_numel); + auto vsub = + jit::Get, platform::CPUPlace>(o_numel); + auto vscal = + jit::Get, platform::CPUPlace>(o_numel); + auto matmul = + jit::Get, platform::CPUPlace>(k); + + const T* x_data = x->data(); + const T* y_data = y->data(); + T* squared_x_data = squared_x->mutable_data(place); + T* squared_y_data = squared_y->mutable_data(place); + T* squared_xy_data = squared_xy->mutable_data(place); + T* o_data = out->mutable_data(place); + + vsquare_x(x_data, squared_x_data, m * k); + vsquare_y(y_data, squared_y_data, k * n); + + matmul(x_data, y_data, o_data, m, n, k); + vsquare_xy(o_data, squared_xy_data, o_numel); + + matmul(squared_x_data, squared_y_data, o_data, m, n, k); + vsub(o_data, squared_xy_data, o_data, o_numel); + vscal(&scalar, o_data, o_data, o_numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_squared_mat_sub, ops::FusionSquaredMatSubOp, + ops::FusionSquaredMatSubOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fusion_squared_mat_sub, + ops::FusionSquaredMatSubKernel, + ops::FusionSquaredMatSubKernel); diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0ab2c2bb10a15cc6d9a472142416bd363e65944f --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +// ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar +class FusionSquaredMatSubOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionSquaredMatSubOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle