From 2811ea4440c8dd3ecaedc7476f9a7c0cb1519a2c Mon Sep 17 00:00:00 2001
From: mozga-intel
Date: Wed, 28 Mar 2018 10:45:52 +0200
Subject: [PATCH] Implementation of MKLDNN FC

---
 paddle/fluid/operators/CMakeLists.txt         |  29 +-
 paddle/fluid/operators/fc_mkldnn_op.cc        | 410 ++++++++++++++++++
 paddle/fluid/operators/fc_mkldnn_op.h         |  47 ++
 python/paddle/fluid/layers/nn.py              |  32 +-
 .../fluid/tests/unittests/test_fc_op.py       |  99 +++++
 5 files changed, 603 insertions(+), 14 deletions(-)
 create mode 100644 paddle/fluid/operators/fc_mkldnn_op.cc
 create mode 100644 paddle/fluid/operators/fc_mkldnn_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_fc_op.py

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 9ed79453b96..6c79998f074 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -1,6 +1,14 @@
 file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
-string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}")
 string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}")
+if(WITH_MKLDNN)
+  string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}")
+else()
+  foreach(item ${GENERAL_OPS})
+    if(${item} MATCHES ".*_mkldnn_op")
+      list(REMOVE_ITEM GENERAL_OPS ${item})
+    endif()
+  endforeach(item)
+endif()
 list(REMOVE_DUPLICATES GENERAL_OPS)
 set(DEPS_OPS "")
 set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/pybind.h)
@@ -80,7 +88,12 @@ function(op_library TARGET)
   endif()
 
   list(LENGTH cc_srcs cc_srcs_len)
-  if (${cc_srcs_len} EQUAL 0)
+  if(WITH_MKLDNN)
+    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
+    if (${cc_srcs_len} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0)
+      message(FATAL_ERROR "The op library ${TARGET} should contain at least one .cc file")
+    endif()
+  elseif(${cc_srcs_len} EQUAL 0)
     message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
   endif()
@@ -109,7 +122,16 @@ function(op_library TARGET)
   # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
   # Note that it's enough to just adding one operator to pybind in a *_op.cc file.
   # And for detail pybind information, please see generated paddle/pybind/pybind.h.
-  file(READ ${TARGET}.cc TARGET_CONTENT)
+  # This replacement is needed when the operator has no CPU kernel, only an MKLDNN one.
+  string(REPLACE "_op" "_mkldnn_op" target_mkldnn_file "${TARGET}")
+  if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
+    file(READ ${TARGET}.cc TARGET_CONTENT)
+  elseif(WITH_MKLDNN AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${target_mkldnn_file}.cc)
+    file(READ ${target_mkldnn_file}.cc TARGET_CONTENT)
+  else()
+    message(FATAL_ERROR "Cannot read the ${TARGET} file from ${CMAKE_CURRENT_SOURCE_DIR}")
+  endif()
+
   string(REGEX MATCH "REGISTER_OP\\(.*REGISTER_OP\\(" multi_register "${TARGET_CONTENT}")
   string(REGEX MATCH "REGISTER_OP\\([a-z0-9_]*," one_register "${multi_register}")
   if (one_register STREQUAL "")
@@ -224,7 +246,6 @@ op_library(recurrent_op DEPS executor)
 op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
-
 if (WITH_GPU)
   op_library(conv_op DEPS vol2col depthwise_conv im2col)
 else()
diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc
new file mode 100644
index 00000000000..48655d36fc9
--- /dev/null
+++ b/paddle/fluid/operators/fc_mkldnn_op.cc
@@ -0,0 +1,410 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fc_mkldnn_op.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+using paddle::platform::MKLDNNDeviceContext;
+
+void FCOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Input"),
+                 "Input(Input) of Fully Connected should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "Output(Out) of Fully Connected should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("W"),
+                 "Input(W) of Fully Connected should not be null.");
+
+  auto in_dims = ctx->GetInputDim("Input");
+  auto w_dims = ctx->GetInputDim("W");
+  std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+
+  PADDLE_ENFORCE(in_dims.size() == 4,
+                 "Fully Connected input should be a 4-D tensor.");
+
+  PADDLE_ENFORCE(w_dims.size() == 2,
+                 "Fully Connected weights should be a 2-D tensor.");
+
+  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  ctx->ShareLoD("Input", "Out");
+}
+
+framework::OpKernelType FCOp::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library{framework::LibraryType::kMKLDNN};
+
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout = framework::StringToDataLayout(data_format);
+
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+      ctx.GetPlace(), layout, library);
+}
+
+void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
+  auto in_dims = ctx->GetInputDim("Input");
+  auto w_dims = ctx->GetInputDim("W");
+
+  if (ctx->HasOutput(framework::GradVarName("Input"))) {
+    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+  }
+  if (ctx->HasOutput(framework::GradVarName("W"))) {
+    ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
+  }
+}
+
+framework::OpKernelType FCOpGrad::GetExpectedKernelType(
+    const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library{framework::LibraryType::kMKLDNN};
+
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout = framework::StringToDataLayout(data_format);
+
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+      ctx.GetPlace(), layout, library);
+}
+
+class FCOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "Input",
+        "(Tensor) The input tensor of the fully connected operator. "
+        "The format of the input tensor is NCHW, where N is batch size, "
+        "C is the number of channels, H is the height of the feature, "
+        "and W is the width of the feature.");
+    AddInput("W", "(Tensor) The weight tensor of the fully connected operator.");
+    AddOutput("Out",
+              "(Tensor) The output tensor of the fully connected operator. "
" + "The format of output tensor is also NCHW, " + "where N is batch size, C is the number of channels, " + "H is the height of the feature, " + "and W is the width of the feature."); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("with_bias", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "data_format", + "(string, default NCHW) Only used in " + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault("AnyLayout"); + AddComment(R"DOC( +)DOC"); + } +}; + +struct MKLDNNMatrixSize final { + explicit MKLDNNMatrixSize(const std::vector& in, + const std::vector& w) + : mb{in[0]}, ic{in[1]}, oc{w[1]}, h{in[2]}, w{in[3]} {} + + bool is_spatial() const { return h > 1 && w > 1; } + + const int mb; + const int ic; + const int oc; + const int h, w; +}; + +template +class MKLDNNMD { + public: + explicit MKLDNNMD(const T* in, const T* w, bool bias) + : sz_(std::unique_ptr(new MKLDNNMatrixSize( + paddle::framework::vectorize2int(in->dims()), + paddle::framework::vectorize2int(w->dims())))) { + with_bias_ = bias; + } + + mkldnn::memory::desc dst() const { + return platform::MKLDNNMemDesc({sz_->mb, sz_->oc}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nc); + } + + mkldnn::memory::desc src() const { + return sz_->is_spatial() + ? platform::MKLDNNMemDesc({sz_->mb, sz_->ic, sz_->h, sz_->w}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nchw) + : platform::MKLDNNMemDesc({sz_->mb, sz_->ic}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nc); + } + + mkldnn::memory::desc weights() const { + return sz_->is_spatial() + ? platform::MKLDNNMemDesc({sz_->oc, sz_->ic, sz_->h, sz_->w}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::oihw) + : platform::MKLDNNMemDesc({sz_->oc, sz_->ic}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::oi); + } + + mkldnn::memory::desc bias() const { + return with_bias_ + ? 
+               ? platform::MKLDNNMemDesc({sz_->oc},
+                                         mkldnn::memory::data_type::f32,
+                                         mkldnn::memory::format::format_undef)
+               : platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
+                                         mkldnn::memory::format::format_undef);
+  }
+
+ private:
+  std::unique_ptr<MKLDNNMatrixSize> sz_;
+  bool with_bias_;
+};
+
+class MKLDNNMemory {
+ public:
+  MKLDNNMemory(MKLDNNMD<Tensor>* t, const mkldnn::engine& e)
+      : md_{t}, engine_{e} {}
+  virtual ~MKLDNNMemory() = default;
+
+  template <typename Output>
+  mkldnn::memory dst(const Output* out) {
+    return mkldnn::memory({md_->dst(), engine_},
+                          static_cast<void*>(const_cast<Output*>(out)));
+  }
+
+  template <typename Output>
+  mkldnn::memory dst(Output* out) {
+    return mkldnn::memory({md_->dst(), engine_}, out);
+  }
+
+  template <typename Input>
+  mkldnn::memory src(const Input* in) {
+    return mkldnn::memory({md_->src(), engine_},
+                          static_cast<void*>(const_cast<Input*>(in)));
+  }
+
+  template <typename Weight>
+  mkldnn::memory weights(const Weight* w) {
+    return mkldnn::memory({md_->weights(), engine_},
+                          static_cast<void*>(const_cast<Weight*>(w)));
+  }
+
+  mkldnn::memory bias() {
+    return mkldnn::memory(
+        mkldnn::memory::primitive_desc(md_->bias(), engine_));
+  }
+
+ private:
+  MKLDNNMD<Tensor>* md_;
+  const mkldnn::engine& engine_;
+};
+
+template <typename T>
+class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto input = ctx.Input<Tensor>("Input");
+    auto w = ctx.Input<Tensor>("W");
+
+    PADDLE_ENFORCE(input->dims().size() == 4,
+                   "Input must have 4 dimensions, i.e. NCHW");
+    PADDLE_ENFORCE(w->dims().size() == 2,
+                   "Weights must have 2 dimensions, i.e. be a matrix");
+
+    bool with_bias = ctx.Attr<bool>("with_bias");
+    MKLDNNMD<Tensor> md(input, w, with_bias);
+
+    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
+        FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
+                           with_bias, mkldnn_engine);
+
+    const std::string key = ctx.op().Output("Out");
+    const std::string key_fc_pd = key + "@fc_pd";
+
+    dev_ctx.SetBlob(key_fc_pd, pd);
+
+    MKLDNNMemory mem(&md, mkldnn_engine);
+
+    const T* input_data = input->data<T>();
+    const T* w_data = w->data<T>();
+
+    auto output = ctx.Output<Tensor>("Out");
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    auto dst_memory = mem.dst(output_data);
+    auto src_memory = mem.src(input_data);
+    auto weights_memory = mem.weights(w_data);
+    auto bias_memory = mem.bias();
+
+    auto forward = with_bias ? mkldnn::inner_product_forward(
+                                   *pd, src_memory, weights_memory,
+                                   bias_memory, dst_memory)
+                             : mkldnn::inner_product_forward(
+                                   *pd, src_memory, weights_memory,
+                                   dst_memory);
+
+    std::vector<mkldnn::primitive> pipeline = {forward};
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+  }
+
+ private:
+  std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>
+  FcFwdPrimitiveDesc(const mkldnn::memory::desc& src,
+                     const mkldnn::memory::desc& weights,
+                     const mkldnn::memory::desc& dst,
+                     const mkldnn::memory::desc& bias, const bool with_bias,
+                     const mkldnn::engine& engine) const {
+    auto desc = with_bias
+                    ? mkldnn::inner_product_forward::desc(
+                          mkldnn::prop_kind::forward, src, weights, bias, dst)
+                    : mkldnn::inner_product_forward::desc(
+                          mkldnn::prop_kind::forward, src, weights, dst);
+
+    auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine);
+    return std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>(pd);
+  }
+};
+
+template <typename T>
+class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    T* input_grad_data = nullptr;
+    T* w_grad_data = nullptr;
+
+    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
+
+    if (input_grad) {
+      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+    }
+    if (w_grad) {
+      w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
+    }
+
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const T* input_data = input->data<T>();
+
+    const Tensor* w = ctx.Input<Tensor>("W");
+    const T* w_data = w->data<T>();
+
+    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const T* out_grad_data = out_grad->data<T>();
+
+    bool with_bias = ctx.Attr<bool>("with_bias");
+
+    MKLDNNMD<Tensor> md(input, w, with_bias);
+    MKLDNNMemory mem(&md, mkldnn_engine);
+
+    auto dst_memory = mem.dst(out_grad_data);
+    auto src_memory = mem.src(input_data);
+    auto weights_memory = mem.weights(w_data);
+    auto bias_memory = mem.bias();
+
+    const std::string key = ctx.op().Input("Out");
+    const std::string key_fc_pd = key + "@fc_pd";
+
+    auto pd =
+        std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
+            dev_ctx.GetBlob(key_fc_pd));
+
+    PADDLE_ENFORCE(pd != nullptr, "Failed to find key_fc_pd in device context");
+
+    if (w_grad) {
+      auto weights_grad_memory = mem.weights(w_grad_data);
+
+      mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd =
+          FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(),
+                                    md.bias(), with_bias, *pd, mkldnn_engine);
+
+      auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
+          bwd_weight_pd, src_memory, dst_memory, weights_grad_memory,
+          bias_memory);
+
+      std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
+      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+    }
+
+    if (input_grad) {
+      auto src_grad_memory = mem.src(input_grad_data);
+
+      mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd =
+          FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd,
+                                 mkldnn_engine);
+
+      auto bwd_data_prim = mkldnn::inner_product_backward_data(
+          bwd_data_pd, dst_memory, weights_memory, src_grad_memory);
+
+      std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
+      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+    }
+  }
+
+ private:
+  mkldnn::inner_product_backward_weights::primitive_desc
+  FcBwdWeightsPrimitiveDesc(
+      const mkldnn::memory::desc& src,
+      const mkldnn::memory::desc& diff_weights,
+      const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias,
+      const bool with_bias,
+      const mkldnn::inner_product_forward::primitive_desc& pd,
+      const mkldnn::engine& engine) const {
+    auto bwd_weight_desc = with_bias
+                   ? mkldnn::inner_product_backward_weights::desc(
+                         src, diff_weights, bias, diff_dst)
+                   : mkldnn::inner_product_backward_weights::desc(
+                         src, diff_weights, diff_dst);
+
+    return mkldnn::inner_product_backward_weights::primitive_desc(
+        bwd_weight_desc, engine, pd);
+  }
+
+  mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc(
+      const mkldnn::memory::desc& diff_src,
+      const mkldnn::memory::desc& weights,
+      const mkldnn::memory::desc& diff_dst,
+      const mkldnn::inner_product_forward::primitive_desc& pd,
+      const mkldnn::engine& engine) const {
+    auto bwd_data_desc =
+        mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst);
+    return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc,
+                                                               engine, pd);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, fc_grad,
+            paddle::operators::FCOpGrad);
+
+REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
+                   paddle::operators::FCMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   paddle::operators::FCMKLDNNGradOpKernel<float>);
diff --git a/paddle/fluid/operators/fc_mkldnn_op.h b/paddle/fluid/operators/fc_mkldnn_op.h
new file mode 100644
index 00000000000..9e6c66491dd
--- /dev/null
+++ b/paddle/fluid/operators/fc_mkldnn_op.h
@@ -0,0 +1,47 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class FCOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class FCOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 3d13133bf25..bfae205bcf5 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -86,6 +86,7 @@ def fc(input,
        param_attr=None,
        bias_attr=None,
        use_mkldnn=False,
+       with_bias=False,
        act=None,
        name=None):
     """
@@ -133,6 +134,8 @@ def fc(input,
         bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias of this layer.
             If it is set to None, no bias will be added to the output units.
         act (str, default None): Activation to be applied to the output of this layer.
+        use_mkldnn (bool): Whether to use the MKLDNN kernel. Valid only when
+            the MKLDNN library is installed.
+            Default: False.
+        with_bias (bool): Whether the MKLDNN FC kernel should add a bias.
+            Default: False.
         name (str, default None): The name of this layer.
 
     Returns:
@@ -162,16 +165,25 @@ def fc(input,
         w = helper.create_parameter(
             attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
         tmp = helper.create_tmp_variable(dtype)
-        helper.append_op(
-            type="mul",
-            inputs={"X": input_var,
-                    "Y": w},
-            outputs={"Out": tmp},
-            attrs={
-                "x_num_col_dims": num_flatten_dims,
-                "y_num_col_dims": 1,
-                'use_mkldnn': use_mkldnn
-            })
+        if not use_mkldnn:
+            helper.append_op(
+                type="mul",
+                inputs={"X": input_var,
+                        "Y": w},
+                outputs={"Out": tmp},
+                attrs={
+                    "x_num_col_dims": num_flatten_dims,
+                    "y_num_col_dims": 1,
+                    'use_mkldnn': use_mkldnn
+                })
+        else:
+            helper.append_op(
+                type="fc",
+                inputs={"Input": input_var,
+                        "W": w},
+                outputs={"Out": tmp},
+                attrs={"use_mkldnn": use_mkldnn,
+                       "with_bias": with_bias})
         mul_results.append(tmp)
 
     # sum
diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py
new file mode 100644
index 00000000000..3f547f3c484
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fc_op.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def fully_connected_naive(input, weights, bias_data=None):
+    in_n, in_c, in_h, in_w = input.shape
+    w_h, w_c = weights.shape
+
+    x_data = np.reshape(input, [in_n, in_c * in_h * in_w])
+    # The MKLDNN kernel reads the weight buffer in (oc, ic * h * w) order,
+    # so mimic that interpretation here.
+    w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w)))
+
+    if bias_data is None:
+        result = np.dot(x_data, w_data)
+    else:
+        result = np.dot(x_data, w_data) + bias_data
+
+    return result
+
+
+class MatrixGenerate:
+    def __init__(self, mb, ic, oc, h, w):
+        self.input = np.random.random((mb, ic, h, w)).astype("float32")
+        self.weights = np.random.random((ic * h * w, oc)).astype("float32")
+
+
+class TestFCMKLDNNOp(OpTest):
+    def init_op_type(self):
+        # Hook for subclasses to override the default configuration.
+        pass
+
+    def setUp(self):
+        self.op_type = "fc"
+        self.use_mkldnn = True
+        self.with_bias = True
+        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
+        self.init_op_type()
+
+        self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
+
+        self.attrs = {
+            'use_mkldnn': self.use_mkldnn,
+            'with_bias': self.with_bias
+        }
+
+        self.outputs = {
+            'Out': fully_connected_naive(self.matrix.input,
+                                         self.matrix.weights)
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9)
+
+    def test_check_grad_no_weight(self):
+        self.check_grad(
+            ['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W'))
+
+
+class TestFCMKLDNNOp1(TestFCMKLDNNOp):
+    def init_op_type(self):
+        self.matrix = MatrixGenerate(2, 15, 48, 2, 2)
+
+
+class TestFCMKLDNNOp2(TestFCMKLDNNOp):
+    def init_op_type(self):
+        self.matrix = MatrixGenerate(2, 32, 40, 1, 1)
+
+
+class TestFCMKLDNNOp3(TestFCMKLDNNOp):
+    def init_op_type(self):
+        self.matrix = MatrixGenerate(2, 2, 4, 1, 1)
+
+
+class TestFCMKLDNNOp4(TestFCMKLDNNOp):
+    def init_op_type(self):
+        self.with_bias = False
+        self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
+
+
+class TestFCMKLDNNOp5(TestFCMKLDNNOp):
+    def init_op_type(self):
+        self.with_bias = False
+        self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
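
For reviewers, below is a minimal sketch of how the new code path is exercised from Python once this patch is applied. It assumes a Paddle build configured with WITH_MKLDNN=ON (so the MKLDNN "fc" kernel is registered); the variable names and shapes are illustrative only, not part of the patch.

import numpy as np
import paddle.fluid as fluid

# Build a small network that routes through the new MKLDNN FC operator.
# The MKLDNN kernel requires a 4-D NCHW input, e.g. (batch, 10, 3, 3).
data = fluid.layers.data(name='data', shape=[10, 3, 3], dtype='float32')
out = fluid.layers.fc(input=data, size=15, use_mkldnn=True)

place = fluid.CPUPlace()  # the kernel is registered for CPUPlace only
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x = np.random.random((1, 10, 3, 3)).astype('float32')
result = exe.run(fluid.default_main_program(),
                 feed={'data': x},
                 fetch_list=[out])
print(result[0].shape)  # expected: (1, 15)

With use_mkldnn=False the layer falls back to the existing mul-based path, so the two branches can be compared on the same graph definition.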