From e133df60373b92d1e35b2f34144e7067dbb9752b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 13 Aug 2018 23:40:58 +0800 Subject: [PATCH] enable native fc forward --- paddle/fluid/operators/fc_mkldnn_op.cc | 1 + paddle/fluid/operators/fc_op.cc | 55 +++++++++++++++----------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 99fa659a3..68a47dd6a 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -128,6 +128,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); + // TODO(intel): the src weight is io and mkldnn weight need be transposed ! PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 5fee30e14..e71f63c13 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -15,6 +15,8 @@ limitations under the License. */ #include "paddle/fluid/operators/fc_op.h" #include +DECLARE_int32(paddle_num_threads); + namespace paddle { namespace operators { @@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { "Out(Output) of Fully Connected should not be null."); PADDLE_ENFORCE(ctx->HasInput("W"), "W(Input) of Fully Connected should not be null."); - + // NCHW auto in_dims = ctx->GetInputDim("Input"); + // IO, I=C*H*W auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); - PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0], + PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], "The shape of Bias must be [1, dim]."); } PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, - "Fully Connected input should be 2-D or 4-D tensor."); - - PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0], - framework::product(in_dims) / in_dims[0], + PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, + "Fully Connected input should be 2-D tensor."); + PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0], "Fully Connected input and weigth size do not match."); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); @@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { } if (ctx->HasInput("Bias")) { + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Should have bias grad"); auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")); ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims); } } @@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; - if (ctx.Attr("use_mkldnn");) { + if (ctx.Attr("use_mkldnn")) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( } void FCOpMaker::Make() { - AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); - AddInput("W", "(Tensor), The second input tensor of fc op."); - AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D") + AddInput("Input", + "(Tensor), The input tensor of fully connected operator with format " + "(NCHW). "); + AddInput("W", "(Tensor), The weight fc op with shape (I, O)."); + AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O") .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", @@ -120,25 +123,32 @@ template class FCOpKernel : public framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); auto input = ctx.Input("Input"); auto w = ctx.Input("W"); auto b = ctx.Input("Bias"); + auto output = ctx.Output("Out"); + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); const T* input_data = input->data(); const T* w_data = w->data(); - auto output = ctx.Output("Out"); T* output_data = output->mutable_data(ctx.GetPlace()); - auto in_dims = ctx->GetInputDim("Input"); - auto w_dims = ctx->GetInputDim("W"); - std::vector output_shape({in_dims[0], w_dims[1]}); + blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0], + static_cast(1), input_data, w_data, static_cast(0), + output_data); if (bias) { const T* bias_data = bias->data(); +#pragma omp parallel for if (FLAGS_paddle_num_threads > 1) + for (int bs = 0; bs < in_dims[0]; bs++) { + blas.AXPY(w_dims[1], static_cast(1), bias_data, + output_data + bs * w_dimws[1]); + } } } }; @@ -150,5 +160,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad); -REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel, - ops::FCMKLDNNOpKernel); +REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel, ops::FCOpKernel); -- GitLab