diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index d0023713b4e532a961458525bcc3748f6f0b2105..9c704a2949f7100e0812eafe1e58ef04bf71f840 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto input = ctx.Input("Input"); auto w = ctx.Input("W"); - PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2, + PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. NCHW"); - PADDLE_ENFORCE(w->dims().size() == 2, - "Weights must be with 2 dimensions, i.e. NC"); + PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, + "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); bool with_bias = ctx.Attr("bias_attr"); MKLDNNMD md(input, w, with_bias); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 41f9c0b8a9bbca91c5c54a55ca652ce40f214f99..381771f157d78fb04e54f0a07c40e4df2c91441a 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fc_op.h" +#include namespace paddle { namespace operators { @@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { auto w_dims = ctx->GetInputDim("W"); std::vector output_shape({in_dims[0], w_dims[1]}); - PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2, + PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, "Fully Connected input should be 2-D or 4-D tensor."); - PADDLE_ENFORCE(w_dims.size() == 2, - "Fully Connected input should be 2-D tensor."); + PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, + "Fully Connected input should be 2-D or 4-D tensor."); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Out"); @@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Input", - "(Tensor) The input tensor of fully connected operator. " - "The format of input tensor is NCHW, where N is batch size, C is the " - "number of channels, H is the height of the feature, " - "and W is the width of the feature."); + AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("W", "(Tensor), The second input tensor of fc op."); - AddOutput("Out", - "(Tensor) The output tensor of fully connected operator. " - "The format of output tensor is also NCHW, " - "where N is batch size, C is the number of channels, " - "H is the height of the feature, " - "and W is the width of the feature."); + AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false);