diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index a9ae1396db8d7dab0364779e506d5c0a3e2ff6ed..5fee30e146077acd1a72580dc8377464c5955775 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -30,21 +30,36 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
   auto w_dims = ctx->GetInputDim("W");
   std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
 
+  // Bias is optional; when supplied it must have shape [1, dim].
+  if (ctx->HasInput("Bias")) {
+    auto bias_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
+    PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0],
+                      "The shape of Bias must be [1, dim].");
+  }
   PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
                  "Fully Connected input should be 2-D or 4-D tensor.");
 
   PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
                  "Fully Connected input should be 2-D or 4-D tensor.");
 
+  PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0],
+                    framework::product(in_dims) / in_dims[0],
+                    "Fully Connected input and weight size do not match.");
+
   ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   ctx->ShareLoD("Input", "Out");
 }
 
 framework::OpKernelType FCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library{framework::LibraryType::kMKLDNN};
-  framework::DataLayout layout{framework::DataLayout::kMKLDNN};
-
+  // Default to the plain kernel; switch to MKL-DNN only when requested.
+  framework::LibraryType library = framework::LibraryType::kPlain;
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  if (ctx.Attr<bool>("use_mkldnn")) {
+    library = framework::LibraryType::kMKLDNN;
+    layout = framework::DataLayout::kMKLDNN;
+  }
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
       layout, library);
@@ -60,13 +75,25 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   if (ctx->HasOutput(framework::GradVarName("W"))) {
     ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
   }
+
+  // The Bias gradient is given the same dims as the Bias input.
+  if (ctx->HasInput("Bias")) {
+    auto bias_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
+                   "Bias gradient output should not be null when Bias is given.");
+    ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
+  }
 }
 
 framework::OpKernelType FCOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library{framework::LibraryType::kMKLDNN};
-  framework::DataLayout layout{framework::DataLayout::kMKLDNN};
-
+  // Default to the plain kernel; switch to MKL-DNN only when requested.
+  framework::LibraryType library = framework::LibraryType::kPlain;
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  if (ctx.Attr<bool>("use_mkldnn")) {
+    library = framework::LibraryType::kMKLDNN;
+    layout = framework::DataLayout::kMKLDNN;
+  }
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
       layout, library);
@@ -75,12 +102,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
 void FCOpMaker::Make() {
   AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
   AddInput("W", "(Tensor), The second input tensor of fc op.");
+  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D)")
+      .AsDispensable();
   AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
-  AddAttr<bool>("bias_attr", "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false);
   AddComment(R"DOC(
   Fully Connected Operator.
 
@@ -94,9 +121,43 @@ void FCOpMaker::Make() {
 )DOC");
 }
 
+// Plain CPU kernel for the fc operator; Compute() requires a CPU place and
+// reads the Input, W and optional Bias tensors.
+template <typename T>
+class FCOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
+    auto input = ctx.Input<Tensor>("Input");
+    auto w = ctx.Input<Tensor>("W");
+    auto bias = ctx.Input<Tensor>("Bias");
+
+    const T* input_data = input->data<T>();
+    const T* w_data = w->data<T>();
+    auto output = ctx.Output<Tensor>("Out");
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();
+    std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+
+    // NOTE(review): no GEMM or bias addition is performed yet; the buffers
+    // above are only fetched, and Out's buffer is requested but never written.
+    if (bias) {
+      const T* bias_data = bias->data<T>();
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker,
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad);
+REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
+REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>,
+                       ops::FCOpKernel<double>);