From a65c728e5d297d832b9a406510643695555dc5b0 Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Wed, 11 Sep 2019 16:32:29 +0800
Subject: [PATCH] Implement the GPU kernel of fc operator (#19687)

* Refine the code related to fc op.

* Add GPU implementation for fc functor.

* Apply fc_fuse_pass in GPU inference. test=develop

* Change the cmake for fc op.

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.

* Add an attribute to set the activation type in fc_op.

* Enhance the unittest of fc_op. test=develop

* Remove the declaration of FCOpGrad back to the header file. test=develop

* Set default value for newly added arguments in test_fc_op. test=develop
---
 cmake/operators.cmake                              |   3 -
 .../framework/ir/embedding_fc_lstm_fuse_pass.cc    |   1 -
 paddle/fluid/inference/api/paddle_pass_builder.cc  |   3 +-
 paddle/fluid/operators/CMakeLists.txt              |   2 +-
 paddle/fluid/operators/attention_lstm_op.cc        |  14 +-
 paddle/fluid/operators/fc_op.cc                    | 200 ++++++++----------
 paddle/fluid/operators/fc_op.cu.cc                 |  20 ++
 paddle/fluid/operators/fc_op.h                     |  52 +++--
 .../operators/fused/fused_embedding_fc_lstm_op.cc  |   1 -
 paddle/fluid/operators/fused/fusion_gru_op.cc      |  22 +-
 paddle/fluid/operators/fused/fusion_lstm_op.cc     |  17 +-
 .../fused/fusion_seqconv_eltadd_relu_op.cc         |   8 +-
 .../fused/fusion_seqexpand_concat_fc_op.cc         |   9 +-
 paddle/fluid/operators/math/CMakeLists.txt         |   1 +
 paddle/fluid/operators/math/fc.cc                  |  62 ++++++
 paddle/fluid/operators/math/fc.cu                  |  73 +++++++
 paddle/fluid/operators/math/fc.h                   |  34 +++
 paddle/fluid/operators/math/fc_compute.h           |  55 -----
 python/paddle/fluid/tests/unittests/test_fc_op.py  |  90 ++++----
 19 files changed, 415 insertions(+), 252 deletions(-)
 create mode 100644 paddle/fluid/operators/fc_op.cu.cc
 create mode 100644 paddle/fluid/operators/math/fc.cc
 create mode 100644 paddle/fluid/operators/math/fc.cu
 create mode 100644 paddle/fluid/operators/math/fc.h
 delete mode 100644 paddle/fluid/operators/math/fc_compute.h

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 134c894392..f43d284ad0 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -191,9 +191,6 @@ function(op_library TARGET)
       file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
     elseif(${TARGET} STREQUAL "tensorrt_engine_op")
       message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
-    elseif(${TARGET} STREQUAL "fc")
-      # HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
-      file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
     else()
       file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
     endif()
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index 6462e7bf4c..b29b37992d 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -21,7 +21,6 @@
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1b58243aaa..83d10abd5f 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -103,9 +103,10 @@ const std::vector<std::string> kAnakinSubgraphPasses({
 
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
-      //   "identity_scale_op_clean_pass",           //
       "is_test_pass",                               //
"simplify_with_basic_ops_pass", // + "fc_fuse_pass", // "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // "conv_bn_fuse_pass", // diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 66e5522ae8..8b45a0b031 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -90,7 +90,7 @@ endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 275855cbb6..c6d98f1f9a 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" namespace paddle { @@ -339,10 +339,13 @@ class AttentionLSTMKernel : public framework::OpKernel { T* lstm_x_data = lstm_x->mutable_data(ctx.GetPlace()); T* lstm_out_data = lstm_out->mutable_data(ctx.GetPlace()); + auto blas = math::GetBlas(ctx); + // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 - auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, 1, M, x_data, atten_w_data, - atted_x_data, atten_b_data); + auto& dev_ctx = ctx.template device_context(); + math::FCFunctor fc; + fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data, + atten_b_data); const T* cur_atten_x_data = atted_x_data; const T* cur_x_data = x_data; @@ -369,8 +372,7 @@ class AttentionLSTMKernel : public framework::OpKernel { // 1d. softmax vec_softmax(seq_len, fc_out_data, fc_out_data); // mul x(seq_len*M) and sum pool - math::FCCompute(blas, 1, M, seq_len, fc_out_data, - cur_x_data, lstm_x_data); + fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data); /// 2. compute LSTM step // lstm weight : concat[forget , input , output , tilde] diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 242f5390b8..bc0edd780c 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -14,65 +14,76 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fc_op.h" #include -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/fc_compute.h" namespace paddle { namespace operators { -void FCOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "X(Input) of Fully Connected should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Out(Output) of Fully Connected should not be null."); - PADDLE_ENFORCE(ctx->HasInput("W"), - "W(Input) of Fully Connected should not be null."); - - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); +class FCOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + "X(Input) of Fully Connected should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Out(Output) of Fully Connected should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, + "W(Input) of Fully Connected should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + if (bias_dims.size() == 2) { + PADDLE_ENFORCE_EQ(bias_dims[0], 1, + "The shape of Bias must be [1, dim]."); + PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], + "The shape of Bias must be [1, dim]."); + } else if (bias_dims.size() == 1) { + PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1], + "The shape of Bias must be [1, dim]."); + } + } - if (ctx->HasInput("Bias")) { - auto bias_dims = ctx->GetInputDim("Bias"); - if (bias_dims.size() == 2) { - PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); - PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], - "The shape of Bias must be [1, dim]."); - } else if (bias_dims.size() == 1) { - PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1], - "The shape of Bias must be [1, dim]."); + auto& activation_type = ctx->Attrs().Get("activation_type"); + if (!activation_type.empty()) { + PADDLE_ENFORCE_EQ(activation_type, "relu", + "Activation %s is not supportetd in fc now.", + activation_type.c_str()); } - } + if (ctx->Attrs().Get("use_mkldnn")) { + PADDLE_ENFORCE_EQ(in_dims.size() == 2 || in_dims.size() == 4, true, + "Fully Connected input should be 2-D or 4-D tensor."); + } + PADDLE_ENFORCE_EQ(w_dims.size(), 2, + "Fully Connected input should be 2-D tensor."); + int in_num_col_dims = ctx->Attrs().Get("in_num_col_dims"); + PADDLE_ENFORCE_GT( + in_dims.size(), in_num_col_dims, + "The input tensor Input's rank of FCOp should be larger than " + "in_num_col_dims."); + + std::vector output_dims; + FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims); - if (ctx->Attrs().Get("use_mkldnn")) { - PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, - "Fully Connected input should be 2-D or 4-D tensor."); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + ctx->ShareLoD("Input", "Out"); } - PADDLE_ENFORCE_EQ(w_dims.size(), 2, - "Fully Connected input should be 2-D tensor."); - int in_num_col_dims = ctx->Attrs().Get("in_num_col_dims"); - PADDLE_ENFORCE_GT( - in_dims.size(), in_num_col_dims, - "The input tensor Input's rank of FCOp should be larger than " - "in_num_col_dims."); - - std::vector output_dims; - FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims); - - ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); - ctx->ShareLoD("Input", "Out"); -} 
 
-framework::OpKernelType FCOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn")) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    if (ctx.Attr<bool>("use_mkldnn")) {
+      library = framework::LibraryType::kMKLDNN;
+      layout = framework::DataLayout::kMKLDNN;
+    }
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.GetPlace(), layout, library);
   }
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout, library);
-}
+};
 
 void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   auto in_dims = ctx->GetInputDim("Input");
@@ -86,8 +97,8 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   }
 
   if (ctx->HasInput("Bias")) {
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
-                   "Should have bias grad");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
+                      "Should have bias grad");
     auto bias_dims = ctx->GetInputDim("Bias");
     ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
   }
@@ -105,61 +116,36 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
                                  ctx.GetPlace(), layout, library);
 }
 
-void FCOpMaker::Make() {
-  AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
-  AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
-  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
-      .AsDispensable();
-  AddAttr<int>("in_num_col_dims",
-               "(int, default 1), The fc op can take tensors with more than "
-               "two dimensions as its inputs.")
-      .SetDefault(1)
-      .EqualGreaterThan(1);
-  AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
-  AddAttr<bool>("use_mkldnn",
-                "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false);
-  AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
-                "Skip calling InferShape() function in the runtime.")
-      .SetDefault(true);
-  AddComment(R"DOC(
-  Fully Connected Operator.
-
-  The fully connected operation calculates the output based on the input, weights and bias.
-  The size of each dimension of the parameters checked in the infer-shape.
-)DOC"); -} - -template -class FCOpKernel : public framework::OpKernel { +class FCOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - auto input = ctx.Input("Input"); - auto w = ctx.Input("W"); - auto bias = ctx.Input("Bias"); - auto output = ctx.Output("Out"); - int in_num_col_dims = ctx.Attr("in_num_col_dims"); - auto w_dims = w->dims(); - - std::vector output_dims; - FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims); - output->Resize(framework::make_ddim(output_dims)); - output->set_lod(input->lod()); - - auto out_dims = output->dims(); - int M = framework::product(out_dims) / w_dims[1]; - - const T* input_data = input->data(); - const T* w_data = w->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - auto blas = math::GetBlas(ctx); - math::FCCompute( - blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data, - bias ? bias->data() : NULL); - - // TODO(TJ): fuse act + void Make() override { + AddInput("Input", + "(Tensor), The input tensor of fully connected operator."); + AddInput("W", "(Tensor), The weight fc op with shape (I, O)."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O") + .AsDispensable(); + AddOutput("Out", + "(Tensor) The output tensor of fully connected operator. "); + AddAttr("in_num_col_dims", + "(int, default 1), The fc op can take tensors with more than " + "two dimensions as its inputs.") + .SetDefault(1) + .EqualGreaterThan(1); + AddAttr("activation_type", + "Avctivation type used in fully connected operator.") + .SetDefault(""); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr(framework::kAllKernelsMustComputeRuntimeShape, + "Skip calling InferShape() function in the runtime.") + .SetDefault(true); + AddComment(R"DOC( +Fully Connected Operator. + +The fully connected operation calculates the output based on the input, weights and bias. +The size of each dimension of the parameters checked in the infer-shape. +)DOC"); } }; @@ -170,4 +156,6 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); -REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel, ops::FCOpKernel); +REGISTER_OP_CPU_KERNEL( + fc, ops::FCOpKernel, + ops::FCOpKernel); diff --git a/paddle/fluid/operators/fc_op.cu.cc b/paddle/fluid/operators/fc_op.cu.cc new file mode 100644 index 0000000000..2fd33aeb12 --- /dev/null +++ b/paddle/fluid/operators/fc_op.cu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/paddle/fluid/operators/fc_op.cu.cc b/paddle/fluid/operators/fc_op.cu.cc
new file mode 100644
index 0000000000..2fd33aeb12
--- /dev/null
+++ b/paddle/fluid/operators/fc_op.cu.cc
@@ -0,0 +1,20 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fc_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    fc, ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h
index b82a63cd83..bf08e6ba68 100644
--- a/paddle/fluid/operators/fc_op.h
+++ b/paddle/fluid/operators/fc_op.h
@@ -14,24 +14,16 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/fc.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-class FCOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override;
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override;
-};
-
 class FCOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -43,11 +35,6 @@ class FCOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override;
 };
 
-class FCOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override;
-};
-
 inline void FCOutputSize(const framework::DDim& in_dims,
                          const framework::DDim& w_dims,
                          std::vector<int64_t>& out_dims,  // NOLINT
@@ -64,5 +51,38 @@ inline void FCOutputSize(const framework::DDim& in_dims,
   out_dims.push_back(w_dims[1]);
 }
 
+template <typename DeviceContext, typename T>
+class FCOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::LoDTensor>("Input");
+    auto* w = ctx.Input<Tensor>("W");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* output = ctx.Output<framework::LoDTensor>("Out");
+    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
+    bool with_relu =
+        (ctx.Attr<std::string>("activation_type") == "relu") ? true : false;
+
+    auto w_dims = w->dims();
+
+    std::vector<int64_t> output_dims;
+    FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
+    output->Resize(framework::make_ddim(output_dims));
+    output->set_lod(input->lod());
+
+    auto out_dims = output->dims();
+    int M = framework::product(out_dims) / w_dims[1];
+
+    const T* input_data = input->data<T>();
+    const T* w_data = w->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
+       bias ? bias->data<T>() : NULL, with_relu);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
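The kernel above derives the GEMM sizes from the tensor shapes: Input is flattened to an M x K matrix, where M is the product of the first in_num_col_dims dimensions and K (= w_dims[0]) is the product of the rest; N is w_dims[1]. A standalone sketch of that shape arithmetic, with hypothetical values in the comments:

    // Mirrors the logic of FCOutputSize above; not part of the patch.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> FCOutputSizeSketch(const std::vector<int64_t>& in_dims,
                                            const std::vector<int64_t>& w_dims,
                                            int in_num_col_dims) {
      int64_t rows = 1, cols = 1;  // rows -> M, cols -> K
      for (int i = 0; i < in_num_col_dims; ++i) rows *= in_dims[i];
      for (size_t i = in_num_col_dims; i < in_dims.size(); ++i) {
        cols *= in_dims[i];
      }
      assert(cols == w_dims[0]);  // K must match the weight's first dimension
      std::vector<int64_t> out(in_dims.begin(),
                               in_dims.begin() + in_num_col_dims);
      out.push_back(w_dims[1]);  // append N
      // e.g. in_dims [2, 8, 3, 3], in_num_col_dims 1, w_dims [72, 16]
      //      -> out [2, 16], so M = 2, N = 16, K = 72
      return out;
    }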
*/ #include #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { @@ -219,9 +219,11 @@ class FusionGRUKernel : public framework::OpKernel { const T* wh_state_data = wh_data + D * D2; T* hidden_out_data = hidden_out->mutable_data(place); auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, D3, M, x_data, wx_data, - xx_data, - bias ? bias->data() : nullptr); + + auto& dev_ctx = ctx.template device_context(); + math::FCFunctor fc; + fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data, + bias ? bias->data() : nullptr); int xx_offset = D3; int gate_offset = D; @@ -290,17 +292,17 @@ class FusionGRUKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; + + math::FCFunctor fc; if (M > D3) { - math::FCCompute(blas, total_T, D3, M, x_data, wx_data, - xx_data, - bias ? bias->data() : nullptr); + fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data, + bias ? bias->data() : nullptr); to_batch(dev_ctx, *xx, batched_input, true, is_reverse); } else { to_batch(dev_ctx, *x, xx, true, is_reverse); batched_input->set_lod(xx->lod()); - math::FCCompute(blas, total_T, D3, M, xx_data, wx_data, - batched_input_data, - bias ? bias->data() : nullptr); + fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data, + bias ? bias->data() : nullptr); } auto batched_lod = batched_input->lod(); diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index f04aa017e3..32f0e37a64 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { @@ -281,8 +281,10 @@ class FuisonLSTMKernel : public framework::OpKernel { T* h_out_data = hidden_out->mutable_data(place); T* c_out_data = cell_out->mutable_data(place); auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, D4, M, x_data, wx_data, - xx_data, bias->data()); + + auto& dev_ctx = ctx.template device_context(); + math::FCFunctor fc; + fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); int xx_offset = D4; int gate_offset = D; @@ -359,16 +361,15 @@ class FuisonLSTMKernel : public framework::OpKernel { math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); + math::FCFunctor fc; if (M > D4) { - math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, - xx_data, bias->data()); + fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); to_batch(dev_ctx, *xx, batched_input, true, is_reverse); } else { to_batch(dev_ctx, *x, xx, true, is_reverse); batched_input->set_lod(xx->lod()); - math::FCCompute(blas, x_dims[0], D4, M, xx_data, - wx_data, batched_input_data, - bias->data()); + fc(dev_ctx, x_dims[0], D4, M, xx_data, wx_data, batched_input_data, + bias->data()); } auto batched_lod = batched_input->lod(); diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index 4a45177201..519670cc6a 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include // for min, max #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/fc.h" namespace paddle { namespace operators { @@ -209,9 +209,9 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel { } } auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); - math::FCCompute(blas, x_dims[0], w_dims[1], w_dims[0], - col_data, w_data, y_data, b_data, true); + math::FCFunctor fc; + fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data, + b_data, true); } }; diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 46632c1e9a..95a08d3b0f 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" namespace paddle { @@ -165,8 +165,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { T* fc_out_data = fc_out->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, D, M0, ref_in_data, w_data, - out_data, b ? b->data() : NULL); + + auto& dev_ctx = ctx.template device_context(); + math::FCFunctor fc; + fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data, + b ? 
 
     w_data = w_data + M0 * D;
     // first write on
     blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 88e750214e..ca0c92b4fb 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -56,6 +56,7 @@ math_library(sequence_pooling DEPS math_function jit_kernel_helper)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function jit_kernel_helper)
 math_library(beam_search DEPS math_function)
+math_library(fc DEPS blas)
 
 math_library(matrix_bit_code)
 
diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc
new file mode 100644
index 0000000000..b5479a1b43
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.cc
@@ -0,0 +1,62 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/fc.h"
+#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+class FCFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context, const int M,
+                  const int N, const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false) {
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    blas.MatMul(M, N, K, X, W, Y);
+    if (B == NULL) {
+      return;
+    }
+    if (relu) {
+      auto compute =
+          jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
+              .At(N);
+      for (int i = 0; i < M; i++) {
+        T* dst = Y + i * N;
+        compute(B, dst, dst, N);
+      }
+    } else {
+      auto compute =
+          jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
+              N);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+      for (int i = 0; i < M; i++) {
+        T* dst = Y + i * N;
+        compute(B, dst, dst, N);
+      }
+    }
+  }
+};
+
+template class FCFunctor<platform::CPUDeviceContext, float>;
+template class FCFunctor<platform::CPUDeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
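The CPU specialization keeps the exact semantics of the removed FCCompute header: one MatMul, then a per-row bias broadcast through the JIT-generated VAddRelu or VAdd kernels (the plain add is OpenMP-parallel under MKLML). For readers who do not want to trace the JIT indirection, the epilogue is equivalent to this plain loop (illustrative only; the patch itself uses the JIT kernels):

    // Equivalent of the bias/ReLU epilogue in FCFunctor<CPUDeviceContext, T>.
    void AddBiasMaybeRelu(float* Y, const float* B, int M, int N, bool relu) {
      for (int i = 0; i < M; ++i) {
        float* dst = Y + i * N;  // row i of the M x N output
        for (int j = 0; j < N; ++j) {
          float v = dst[j] + B[j];
          dst[j] = relu ? (v > 0.f ? v : 0.f) : v;
        }
      }
    }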
diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu
new file mode 100644
index 0000000000..1b22b81039
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.cu
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/fc.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T, bool DoRelu>
+__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
+  for (int i = blockIdx.x; i < M; i += gridDim.x) {
+    int index = i * N + threadIdx.x;
+    for (int j = threadIdx.x; j < N; j += blockDim.x) {
+      T tmp = data[index] + bias[j];
+      if (DoRelu) {
+        data[index] = (tmp > 0) ? tmp : 0;
+      } else {
+        data[index] = tmp;
+      }
+      index += blockDim.x;
+    }
+  }
+}
+
+template <typename T>
+class FCFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context, const int M,
+                  const int N, const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false) {
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
+    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
+              static_cast<T>(0.0), Y, N);
+    if (B == NULL) {
+      return;
+    }
+
+    const int kThreadsPerBlock = 1024;
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
+    int num_blocks = std::max(max_threads / num_threads, 1);
+    if (relu) {
+      InplaceAddReluKernel<
+          T, true><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
+                                                                     N);
+    } else {
+      InplaceAddReluKernel<
+          T, false><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
+                                                                      N);
+    }
+  }
+};
+
+template class FCFunctor<platform::CUDADeviceContext, float>;
+template class FCFunctor<platform::CUDADeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc.h b/paddle/fluid/operators/math/fc.h
new file mode 100644
index 0000000000..9bef496fb9
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename DeviceContext, typename T>
+class FCFunctor {
+ public:
+  void operator()(const DeviceContext& context, const int M, const int N,
+                  const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
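In the CUDA path, InplaceAddReluKernel lets blocks stride over the M rows while threads stride over the N columns, so a single launch covers any output size. The launch shape rounds N up to a warp multiple (32), caps it at 1024 threads per block, and derives the block count from the device's physical thread count. The arithmetic, worked standalone (the device limit below is an assumed value):

    // Reproduces the launch-size arithmetic from fc.cu; compiles as plain C++.
    #include <algorithm>
    #include <cstdio>

    int main() {
      const int kThreadsPerBlock = 1024;
      int max_threads = 4096;  // stand-in for context.GetMaxPhysicalThreadCount()
      int N = 1000;
      int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
      int num_blocks = std::max(max_threads / num_threads, 1);
      std::printf("num_threads=%d num_blocks=%d\n", num_threads, num_blocks);
      // prints num_threads=1024 num_blocks=4
      return 0;
    }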
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
deleted file mode 100644
index 66ce57594a..0000000000
--- a/paddle/fluid/operators/math/fc_compute.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename DeviceContext, typename T>
-inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
-                      const int N, const int K, const T* X, const T* W, T* Y,
-                      const T* B = NULL, bool relu = false) {
-  blas.MatMul(M, N, K, X, W, Y);
-  if (B == NULL) {
-    return;
-  }
-  if (relu) {
-    auto compute =
-        jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
-            N);
-    for (int i = 0; i < M; i++) {
-      T* dst = Y + i * N;
-      compute(B, dst, dst, N);
-    }
-  } else {
-    auto compute =
-        jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-    for (int i = 0; i < M; i++) {
-      T* dst = Y + i * N;
-      compute(B, dst, dst, N);
-    }
-  }
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py
index ff417ad2f1..6c2088af3d 100644
--- a/python/paddle/fluid/tests/unittests/test_fc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fc_op.py
@@ -17,7 +17,7 @@ import numpy as np
 from op_test import OpTest
 
 
-def fc_refer(matrix, with_bias):
+def fc_refer(matrix, with_bias, with_relu=False):
     in_n, in_c, in_h, in_w = matrix.input.shape
     w_i, w_o = matrix.weights.shape
 
@@ -31,22 +31,32 @@ def fc_refer(matrix, with_bias):
     else:
         result = np.dot(x_data, w_data)
 
-    return result
+    if with_relu:
+        return np.maximum(result, 0)
+    else:
+        return result
 
 
 class MatrixGenerate:
-    def __init__(self, mb, ic, oc, h, w):
+    def __init__(self, mb, ic, oc, h, w, bias_dims=2):
         self.input = np.random.random((mb, ic, h, w)).astype("float32")
         self.weights = np.random.random((ic * h * w, oc)).astype("float32")
-        self.bias = np.random.random((1, oc)).astype("float32")
+        if bias_dims == 2:
+            self.bias = np.random.random((1, oc)).astype("float32")
+        else:
+            self.bias = np.random.random((oc)).astype("float32")
 
 
 class TestFCOp(OpTest):
+    def config(self):
+        self.with_bias = True
+        self.with_relu = True
+        self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2)
+
     def setUp(self):
         self.op_type = "fc"
-        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
+        self.config()
 
-        self.with_bias = True
         if self.with_bias:
             self.inputs = {
                 'Input': self.matrix.input,
@@ -56,54 +66,60 @@ class TestFCOp(OpTest):
         else:
            self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
 
-        self.attrs = {'use_mkldnn': False}
+        if self.with_relu:
+            activation_type = "relu"
+        else:
+            activation_type = ""
+        self.attrs = {'use_mkldnn': False, 'activation_type': activation_type}
 
-        self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)}
+        self.outputs = {
+            'Out': fc_refer(self.matrix, self.with_bias, self.with_relu)
+        }
 
     def test_check_output(self):
         self.check_output()
 
 
-class TestFCOpNoBias(TestFCOp):
-    def init_shapes(self, mb, ic, oc, h, w):
+class TestFCOpNoBias1(TestFCOp):
+    def config(self):
         self.with_bias = False
-        self.matrix = MatrixGenerate(mb, ic, oc, h, w)
+        self.with_relu = False
+        self.matrix = MatrixGenerate(2, 8, 10, 1, 1, 2)
 
 
-class TestFCOpWithBias(TestFCOp):
-    def init_shapes(self, mb, ic, oc, h, w):
-        self.with_bias = True
-        self.matrix = MatrixGenerate(mb, ic, oc, h, w)
-
-
-class TestFCOp1(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(2, 8, 10, 1, 1)
-
-
-class TestFCOp2(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(4, 5, 6, 2, 2)
+class TestFCOpNoBias2(TestFCOp):
+    def config(self):
+        self.with_bias = False
+        self.with_relu = False
+        self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
 
 
-class TestFCOp4(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(1, 32, 64, 3, 3)
+class TestFCOpNoBias4(TestFCOp):
+    def config(self):
+        self.with_bias = False
+        self.with_relu = False
+        self.matrix = MatrixGenerate(1, 32, 64, 3, 3, 1)
 
 
-class TestFCOpWithBias1(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(3, 8, 10, 2, 1)
+class TestFCOpWithBias1(TestFCOp):
+    def config(self):
+        self.with_bias = True
+        self.with_relu = False
+        self.matrix = MatrixGenerate(3, 8, 10, 2, 1, 2)
 
 
-class TestFCOpWithBias2(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(4, 5, 6, 2, 2)
+class TestFCOpWithBias2(TestFCOp):
+    def config(self):
+        self.with_bias = True
+        self.with_relu = True
+        self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
 
 
-class TestFCOpWithBias3(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(1, 64, 32, 3, 3)
+class TestFCOpWithBias3(TestFCOp):
+    def config(self):
+        self.with_bias = True
+        self.with_relu = True
+        self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)
 
 
 if __name__ == "__main__":
     unittest.main()
--
GitLab
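Because fc_fuse_pass is now part of GpuPassStrategy, GPU inference picks up both the fusion and the new CUDA kernel without any user-side change. A hypothetical minimal setup with the Paddle 1.x C++ inference API, shown for orientation only (the model path is a placeholder):

    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("/path/to/model");  // placeholder model directory
      config.EnableUseGpu(100 /*initial pool, MB*/, 0 /*device id*/);
      // GpuPassStrategy now includes "fc_fuse_pass", so mul + elementwise_add
      // (+ relu) subgraphs are fused into fc and run the CUDA FCFunctor.
      auto predictor = paddle::CreatePaddlePredictor(config);
      return predictor != nullptr ? 0 : 1;
    }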