diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4e9a5ec41262a12701fd46360d8b6551913705d9
--- /dev/null
+++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
@@ -0,0 +1,149 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/operators/jit/kernels.h"
+
+namespace paddle {
+namespace operators {
+
+void FusionRepeatedFCReluOp::InferShape(
+    framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "Input(X) of FusionRepeatedFCReluOp should not be null.");
+  auto sz = ctx->Inputs("W").size();
+  PADDLE_ENFORCE_GT(
+      sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should be larger than 1.");
+  PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
+                    "Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
+                    "equal to inputs size.");
+  PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
+                    "Size of output(ReluOut) of FusionRepeatedFCReluOp should "
+                    "be equal to inputs size - 1.");
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "Output(Out) of FusionRepeatedFCReluOp should not be null.");
+
+  auto i_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(i_dims.size(), 2UL, "Input shape size should be 2.");
+
+  auto w_dims = ctx->GetInputsDim("W");
+  auto b_dims = ctx->GetInputsDim("Bias");
+  PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
+                    "Shape size of weight and bias should be equal.");
+  PADDLE_ENFORCE_EQ(w_dims.size(), sz,
+                    "Shape size of weight should be equal to the number of "
+                    "inputs(W).");
+  PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
+                    "Input width should be equal to weight height.");
+
+  for (size_t i = 1; i < sz; ++i) {
+    PADDLE_ENFORCE_EQ(w_dims[i].size(), 2UL,
+                      "Every weight shape size should be 2.");
+    PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
+                      "The length of Bias must be equal to w_dims[1].");
+  }
+  ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
+  ctx->ShareLoD("X", /*->*/ "Out");
+}
+
+framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
+                                 ctx.GetPlace());
+}
+
+void FusionRepeatedFCReluOpMaker::Make() {
+  AddInput("X", "(LoDTensor) Input tensors of this operator.");
+  AddInput("W", "(Tensor) The weight tensors of this operator.").AsDuplicable();
+  AddInput("Bias", "(Tensor) The bias tensors of this operator.")
+      .AsDuplicable();
+  AddOutput("ReluOut", "(Tensor) The output tensor of each relu operator.")
+      .AsDuplicable()
+      .AsIntermediate();
+  AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
+  AddComment(R"DOC(
+  Fusion Repeated FC with Relu Operator.
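+
+  This operator fuses a chain of (FC + ReLU) stages: the first stage
+  computes relu(X * W_0 + Bias_0), each following stage applies its own
+  weight and bias to the previous stage's activation in the same way,
+  intermediate activations are kept in ReluOut, and the last stage
+  writes its result to Out.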
+)DOC"); +} + +template +static void fc_relu(const T* x, const T* w, const T* b, T* y, int m, int n, + int k) { + auto matmul = + jit::Get, platform::CPUPlace>(k); + auto addbias_relu = + jit::Get, platform::CPUPlace>(n); + matmul(x, w, y, m, n, k); + T* dst = y; + for (int i = 0; i < m; ++i) { + addbias_relu(b, dst, dst, n); + dst += n; + } +} + +template +class FusionRepeatedFCReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto in = ctx.Input("X"); + auto weights = ctx.MultiInput("W"); + auto biases = ctx.MultiInput("Bias"); + auto relus = ctx.MultiOutput("ReluOut"); + auto* out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + int weight_sz = static_cast(weights.size()); + + auto i_dims = in->dims(); + auto w_dims = weights[0]->dims(); + int m = i_dims[0]; + int n = w_dims[1]; + int k = w_dims[0]; + relus[0]->Resize({m, n}); + fc_relu(in->data(), weights[0]->data(), biases[0]->data(), + relus[0]->mutable_data(place), m, n, k); + + for (int i = 1; i < weight_sz - 1; ++i) { + auto i_dims = relus[i - 1]->dims(); + auto w_dims = weights[i]->dims(); + int m = i_dims[0]; + int n = w_dims[1]; + int k = w_dims[0]; + relus[i - 1]->Resize({m, n}); + fc_relu(relus[i - 1]->data(), weights[i]->data(), + biases[i]->data(), relus[i]->mutable_data(place), m, n, k); + } + + auto i_dims_last = relus[weight_sz - 2]->dims(); + auto w_dims_last = weights[weight_sz - 1]->dims(); + m = i_dims_last[0]; + n = w_dims_last[1]; + k = w_dims_last[0]; + fc_relu(relus[weight_sz - 2]->data(), weights[weight_sz - 1]->data(), + biases[weight_sz - 1]->data(), out->mutable_data(place), m, n, + k); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_repeated_fc_relu, ops::FusionRepeatedFCReluOp, + ops::FusionRepeatedFCReluOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL(fusion_repeated_fc_relu, + ops::FusionRepeatedFCReluKernel, + ops::FusionRepeatedFCReluKernel); diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cdcaf8b4833464100ed579a5962c60013edecdb0 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class FusionRepeatedFCReluOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class FusionRepeatedFCReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle
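
As a reading aid, not part of the patch: a minimal, dependency-free C++ sketch of the semantics the fused kernel implements (relu(x * W + b) per stage, stages chained, row-major layout assumed). The helper names naive_fc_relu and repeated_fc_relu_ref are illustrative only; the patched kernel itself dispatches to the jit kMatMul and kVAddRelu kernels instead of these loops.

#include <algorithm>
#include <vector>

// One FC + ReLU stage: x is m x k, w is k x n, b has n entries, y is m x n.
static void naive_fc_relu(const float* x, const float* w, const float* b,
                          float* y, int m, int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = b[j];
      for (int p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[p * n + j];
      }
      y[i * n + j] = std::max(acc, 0.f);  // bias add fused with ReLU
    }
  }
}

// Chain all stages, mirroring FusionRepeatedFCReluKernel: the first stage
// reads X, intermediates play the role of ReluOut, the last stage yields Out.
// ns[i] is the output width of stage i; each input width is derived from the
// current activation size.
static std::vector<float> repeated_fc_relu_ref(
    std::vector<float> x, int m,
    const std::vector<std::vector<float>>& weights,
    const std::vector<std::vector<float>>& biases,
    const std::vector<int>& ns) {
  for (size_t i = 0; i < weights.size(); ++i) {
    const int k = static_cast<int>(x.size()) / m;  // current input width
    std::vector<float> y(static_cast<size_t>(m) * ns[i]);
    naive_fc_relu(x.data(), weights[i].data(), biases[i].data(), y.data(), m,
                  ns[i], k);
    x = std::move(y);  // this stage's activation feeds the next stage
  }
  return x;
}

This also makes the InferShape checks easy to verify: the final output is m x ns.back(), matching SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]}).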