From 34a8084328921d4d043fc3c8308063d38087e62f Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Thu, 29 Mar 2018 20:31:59 +0200 Subject: [PATCH] Added new fc files, register fc kernel --- paddle/fluid/operators/CMakeLists.txt | 29 +---- paddle/fluid/operators/fc_mkldnn_op.cc | 108 +--------------- paddle/fluid/operators/fc_op.cc | 122 ++++++++++++++++++ .../operators/{fc_mkldnn_op.h => fc_op.h} | 5 + python/paddle/fluid/layers/nn.py | 67 ++++++---- 5 files changed, 178 insertions(+), 153 deletions(-) create mode 100644 paddle/fluid/operators/fc_op.cc rename paddle/fluid/operators/{fc_mkldnn_op.h => fc_op.h} (91%) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6c79998f07..9ed79453b9 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -1,14 +1,6 @@ file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}") string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") -if(WITH_MKLDNN) - string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}") -else() - foreach(item ${GENERAL_OPS}) - if(${item} MATCHES ".*_mkldnn_op") - list(REMOVE_ITEM GENERAL_OPS ${item}) - endif() - endforeach(item) -endif() list(REMOVE_DUPLICATES GENERAL_OPS) set(DEPS_OPS "") set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/pybind.h) @@ -88,12 +80,7 @@ function(op_library TARGET) endif() list(LENGTH cc_srcs cc_srcs_len) - if(WITH_MKLDNN) - list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) - if (${cc_srcs_len} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0) - message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") - endif() - elseif(${cc_srcs_len} EQUAL 0) + if (${cc_srcs_len} EQUAL 0) message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") endif() @@ -122,16 +109,7 @@ function(op_library TARGET) # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. # Note that it's enough to just adding one operator to pybind in a *_op.cc file. # And for detail pybind information, please see generated paddle/pybind/pybind.h. - # This replacing is needed, when the CPU's kernel doesn't exist. - string(REPLACE "_op" "_mkldnn_op" target_mkldnn_file "${TARGET}") - if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) - file(READ ${TARGET}.cc TARGET_CONTENT) - elseif(WITH_MKLDNN AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${target_mkldnn_file}.cc) - file(READ ${target_mkldnn_file}.cc TARGET_CONTENT) - else() - message(FATAL_ERROR "Cannot read the ${TARGET} file from ${CMAKE_CURRENT_SOURCE_DIR}") - endif() - + file(READ ${TARGET}.cc TARGET_CONTENT) string(REGEX MATCH "REGISTER_OP\\(.*REGISTER_OP\\(" multi_register "${TARGET_CONTENT}") string(REGEX MATCH "REGISTER_OP\\([a-z0-9_]*," one_register "${multi_register}") if (one_register STREQUAL "") @@ -246,6 +224,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) + if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) else() diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 48655d36fc..3e006189ef 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/fc_mkldnn_op.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/fc_op.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
@@ -23,105 +23,12 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
-void FCOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "X(Input) of Fully Connected should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                 "Out(Output) of Fully Connected should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("W"),
-                 "W(Input) of Fully Connected should not be null.");
-
-  auto in_dims = ctx->GetInputDim("Input");
-  auto w_dims = ctx->GetInputDim("W");
-  std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
-
-  PADDLE_ENFORCE(in_dims.size() == 4,
-                 "Fully Connected input should be 4-D tensor.");
-
-  PADDLE_ENFORCE(w_dims.size() == 2,
-                 "Fully Connected weights should be 2-D tensor.");
-
-  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
-  ctx->ShareLoD("Input", "Out");
-}
-
-framework::OpKernelType FCOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library{framework::LibraryType::kMKLDNN};
-
-  std::string data_format = ctx.Attr<std::string>("data_format");
-  framework::DataLayout layout = framework::StringToDataLayout(data_format);
-
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-      ctx.GetPlace(), layout, library);
-}
-
-void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
-  auto in_dims = ctx->GetInputDim("Input");
-  auto w_dims = ctx->GetInputDim("W");
-
-  if (ctx->HasOutput(framework::GradVarName("Input"))) {
-    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
-  }
-  if (ctx->HasOutput(framework::GradVarName("W"))) {
-    ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
-  }
-}
-
-framework::OpKernelType FCOpGrad::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library{framework::LibraryType::kMKLDNN};
-
-  std::string data_format = ctx.Attr<std::string>("data_format");
-  framework::DataLayout layout = framework::StringToDataLayout(data_format);
-
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-      ctx.GetPlace(), layout, library);
-}
-
-class FCOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput(
-        "Input",
-        "(Tensor) The input tensor of fully connected operator. "
-        "The format of input tensor is NCHW, where N is batch size, C is the "
-        "number of channels, H is the height of the feature, "
-        "and W is the width of the feature.");
-    AddInput("W", "(Tensor) The second input tensor of fc op.");
-    AddOutput("Out",
-              "(Tensor) The output tensor of pooling operator. "
-              "The format of output tensor is also NCHW, "
-              "where N is batch size, C is the number of channels, "
-              "H is the height of the feature, "
-              "and W is the width of the feature.");
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false);
-    AddAttr<bool>("with_bias",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false);
-    AddAttr<std::string>(
-        "data_format",
-        "(string, default NCHW) Only used in "
-        "An optional string from: \"NHWC\", \"NCHW\". "
-        "Defaults to \"NHWC\". Specify the data format of the output data, "
-        "the input will be transformed automatically. ")
-        .SetDefault("AnyLayout");
-    AddComment(R"DOC(
-)DOC");
-  }
-};
-
 struct MKLDNNMatrixSize final {
   explicit MKLDNNMatrixSize(const std::vector<int>& in,
                             const std::vector<int>& w)
       : mb{in[0]}, ic{in[1]}, oc{w[1]}, h{in[2]}, w{in[3]} {}
 
-  bool is_spatial() const { return h > 1 && w > 1; }
+  bool is_spatial() const { return h > 2 && w > 2; }
 
   const int mb;
   const int ic;
@@ -229,12 +136,12 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
 
-    PADDLE_ENFORCE(input->dims().size() == 4,
-                   "Input must be with 4 dimensions, i.e. NCHW");
+    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2,
+                   "Input must be with 2 or 4 dimensions, i.e. NC or NCHW");
     PADDLE_ENFORCE(w->dims().size() == 2,
                    "Weights must be with 2 dimensions, i.e. NC");
 
-    bool with_bias = ctx.Attr<bool>("with_bias");
+    bool with_bias = ctx.Attr<bool>("bias_attr");
     MKLDNNMD<Tensor> md(input, w, with_bias);
 
     std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
@@ -319,7 +226,7 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     const T* out_grad_data = out_grad->data<T>();
 
-    bool with_bias = ctx.Attr<bool>("with_bias");
+    bool with_bias = ctx.Attr<bool>("bias_attr");
 
     MKLDNNMD<Tensor> md(input, w, with_bias);
     MKLDNNMemory mem(&md, mkldnn_engine);
@@ -400,9 +307,6 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, fc_grad,
-            paddle::operators::FCOpGrad);
-
 REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
                    paddle::operators::FCMKLDNNOpKernel<float>);
 
diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
new file mode 100644
index 0000000000..93b59854db
--- /dev/null
+++ b/paddle/fluid/operators/fc_op.cc
@@ -0,0 +1,122 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fc_op.h"
+
+namespace paddle {
+namespace operators {
+
+void FCOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Input"),
+                 "X(Input) of Fully Connected should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "Out(Output) of Fully Connected should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("W"),
+                 "W(Input) of Fully Connected should not be null.");
+
+  auto in_dims = ctx->GetInputDim("Input");
+  auto w_dims = ctx->GetInputDim("W");
+  std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+
+  PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
+                 "Fully Connected input should be 2-D or 4-D tensor.");
+
+  PADDLE_ENFORCE(w_dims.size() == 2,
+                 "Fully Connected weights (W) should be 2-D tensor.");
+
+  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  ctx->ShareLoD("Input", "Out");
+}
+
+framework::OpKernelType FCOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library{framework::LibraryType::kMKLDNN};
+  framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
+      ctx.GetPlace(), layout, library);
+}
+
+void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
+  auto in_dims = ctx->GetInputDim("Input");
+  auto w_dims = ctx->GetInputDim("W");
+
+  if (ctx->HasOutput(framework::GradVarName("Input"))) {
+    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+  }
+  if (ctx->HasOutput(framework::GradVarName("W"))) {
+    ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
+  }
+}
+
+framework::OpKernelType FCOpGrad::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library{framework::LibraryType::kMKLDNN};
+  framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
+      ctx.GetPlace(), layout, library);
+}
+
+FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(proto, op_checker) {
+  AddInput(
+      "Input",
+      "(Tensor) The input tensor of fully connected operator. "
+      "The format of input tensor is NCHW, where N is batch size, C is the "
+      "number of channels, H is the height of the feature, "
+      "and W is the width of the feature.");
+  AddInput("W", "(Tensor) The second input (weight) tensor of the fc op.");
+  AddOutput("Out",
+            "(Tensor) The output tensor of fully connected operator. "
+            "The format of output tensor is also NCHW, "
+            "where N is batch size, C is the number of channels, "
+            "H is the height of the feature, "
+            "and W is the width of the feature.");
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
+  AddAttr<bool>("bias_attr",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
+  AddComment(R"DOC(
+  Fully Connected Operator.
+
+  The fully connected operation calculates the output based on the input,
+  the weights, and the bias attribute.
+  The size of each dimension of the parameters is checked during infer-shape.
+  Input(Input) is in NCHW or NC format, where N is the batch size, C is the
+  number of channels, H is the height of the feature, and W is the width of
+  the feature.
+  Weights(W) is in OIHW or OI format, where O is the output size, I is the
+  number of input channels, H is the height of the feature, and W is the
+  width of the feature.
+  Output(Out) is in NC format, where N is the batch size and C is the number
+  of channels.
+  The bias matrix is generated by the MKLDNN framework when bias_attr is True.
+  The additional attributes are use_mkldnn and bias_attr.
+  The input(Input) size and output(Out) size may be different.
+
+Example:
+  Input:
+       Input shape: $(N, C_{in}, H_{in}, W_{in})$
+       Weight shape: $(O_{out}, I_{in}, H_{in}, W_{in})$
+       Bias shape: $(O_{out})$
+  Output:
+       Output shape: $(N, C_{out})$
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, fc_grad,
+            paddle::operators::FCOpGrad);
diff --git a/paddle/fluid/operators/fc_mkldnn_op.h b/paddle/fluid/operators/fc_op.h
similarity index 91%
rename from paddle/fluid/operators/fc_mkldnn_op.h
rename to paddle/fluid/operators/fc_op.h
index 9e6c66491d..70fa96440d 100644
--- a/paddle/fluid/operators/fc_mkldnn_op.h
+++ b/paddle/fluid/operators/fc_op.h
@@ -43,5 +43,10 @@ class FCOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override;
 };
 
+class FCOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FCOpMaker(OpProto* proto, OpAttrChecker* op_checker);
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index bfae205bcf..40809b304f 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -86,7 +86,6 @@ def fc(input,
        param_attr=None,
        bias_attr=None,
        use_mkldnn=False,
-       with_bias=False,
        act=None,
        name=None):
     """
@@ -156,16 +155,39 @@ def fc(input,
     dtype = helper.input_dtype()
 
     mul_results = []
-    for input_var, param_attr in helper.iter_inputs_and_params():
-        input_shape = input_var.shape
+    if use_mkldnn:
+        tmp = helper.create_tmp_variable(dtype)
+        input_shape = input.shape
         param_shape = [
             reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
         ] + [size]
         w = helper.create_parameter(
-            attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
-        tmp = helper.create_tmp_variable(dtype)
-        if use_mkldnn == False:
+            attr=helper.param_attr,
+            shape=param_shape,
+            dtype=dtype,
+            is_bias=False)
+        # The C++ fc op only needs a boolean flag here; any non-None
+        # bias_attr enables the MKLDNN-generated bias.
+        with_bias = bias_attr is not None
+        helper.append_op(
+            type="fc",
+            inputs={"Input": input,
+                    "W": w},
+            outputs={"Out": tmp},
+            attrs={"use_mkldnn": use_mkldnn,
+                   "bias_attr": with_bias})
+        return helper.append_activation(tmp)
+    else:
+        for input_var, param_attr in helper.iter_inputs_and_params():
+            input_shape = input_var.shape
+            param_shape = [
+                reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
+            ] + [size]
+
+            w = helper.create_parameter(
+                attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
+            tmp = helper.create_tmp_variable(dtype)
             helper.append_op(
                 type="mul",
                 inputs={"X": input_var,
                         "Y": w},
                 outputs={"Out": tmp},
                 attrs={
                     "x_num_col_dims": num_flatten_dims,
                     "y_num_col_dims": 1,
-                    'use_mkldnn': use_mkldnn
                 })
+            mul_results.append(tmp)
+
+        if len(mul_results) == 1:
+            pre_bias = mul_results[0]
         else:
+            pre_bias = helper.create_tmp_variable(dtype)
             helper.append_op(
-                type="fc",
-                inputs={"Input": input_var,
-                        "W": w},
-                outputs={"Out": tmp},
-                attrs={"use_mkldnn": use_mkldnn,
-                       "with_bias": with_bias})
-        mul_results.append(tmp)
-
-    # sum
-    if len(mul_results) == 1:
-        pre_bias = mul_results[0]
-    else:
-        pre_bias = helper.create_tmp_variable(dtype)
-        helper.append_op(
-            type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
-    # add bias
-    pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims)
-    # add activation
-    return helper.append_activation(pre_activation)
+                type="sum",
+                inputs={"X": mul_results},
+                outputs={"Out": pre_bias})
+        # add bias
+        pre_activation = helper.append_bias_op(
+            pre_bias, dim_start=num_flatten_dims)
+        # add activation
+        return helper.append_activation(pre_activation)
 
 
 def embedding(input,
-- 
GitLab
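
Usage sketch for reviewers: with this patch applied, fluid.layers.fc(use_mkldnn=True) emits the single fused `fc` op instead of the usual mul/sum/bias decomposition. The snippet below is a minimal illustration only, assuming an MKLDNN-enabled CPU build of this branch; the variable names, shapes, and Executor boilerplate are ordinary Fluid usage, not part of the change.

    import numpy as np
    import paddle.fluid as fluid

    # use_mkldnn=True takes the new type="fc" branch in nn.py; the
    # activation is still applied via helper.append_activation.
    data = fluid.layers.data(name="data", shape=[1, 28, 28], dtype="float32")
    out = fluid.layers.fc(input=data, size=128, act="relu", use_mkldnn=True)

    place = fluid.CPUPlace()  # the MKLDNN kernel is registered for CPUPlace only
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    x = np.random.rand(8, 1, 28, 28).astype("float32")  # NCHW batch
    result = exe.run(fluid.default_main_program(),
                     feed={"data": x},
                     fetch_list=[out])
    print(result[0].shape)  # expected (8, 128): N rows, `size` columns

Note that the MKLDNN kernel accepts 2-D (NC) or 4-D (NCHW) input, matching the relaxed PADDLE_ENFORCE above, and flattens trailing dimensions into the inner-product input channels.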