未验证 提交 efc5392d 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #12676 from tensor-tang/refine/op/fc

refine fc op
......@@ -170,6 +170,9 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......@@ -300,12 +303,6 @@ op_library(channel_recv_op DEPS concurrency)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF
# Because the fully connected layer has only one MKLDNN's operator
if(NOT WITH_MKLDNN)
list(REMOVE_ITEM GENERAL_OPS fc_op)
endif(NOT WITH_MKLDNN)
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
......
......@@ -125,13 +125,16 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW");
// TODO(intel friends): the native weight format is io,
// but the mkldnn weight format is oihw, which may need be transposed.
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
bool with_bias = ctx.Attr<bool>("bias_attr");
bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias);
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
......@@ -154,6 +157,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_memory = mem.dst(output_data);
auto src_memory = mem.src(input_data);
auto weights_memory = mem.weights(w_data);
// TODO(intel friends): bias memory should also be obtain from bias->data()
auto bias_memory = mem.bias();
auto forward = with_bias ? mkldnn::inner_product_forward(
......@@ -216,7 +220,8 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const T* out_grad_data = out_grad->data<T>();
bool with_bias = ctx.Attr<bool>("bias_attr");
auto bias = ctx.Input<Tensor>("Bias");
bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias);
MKLDNNMemory mem(&md, mkldnn_engine);
......
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
namespace paddle {
namespace operators {
......@@ -25,16 +28,24 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
"Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of Fully Connected should not be null.");
// NCHW
auto in_dims = ctx->GetInputDim("Input");
// IO, I=C*H*W
auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
}
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
"Fully Connected input should be 2-D tensor.");
PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
"Fully Connected input and weigth size do not match.");
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out");
......@@ -42,9 +53,12 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN};
framework::DataLayout layout{framework::DataLayout::kMKLDNN};
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
......@@ -60,27 +74,39 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasOutput(framework::GradVarName("W"))) {
ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Should have bias grad");
auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
}
}
framework::OpKernelType FCOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN};
framework::DataLayout layout{framework::DataLayout::kMKLDNN};
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
}
void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddInput("Input",
"(Tensor), The input tensor of fully connected operator with format "
"(NCHW). ");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("bias_attr", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Fully Connected Operator.
......@@ -94,9 +120,47 @@ void FCOpMaker::Make() {
)DOC");
}
template <typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<Tensor>("Out");
auto in_dims = input->dims();
auto w_dims = w->dims();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
static_cast<T>(1), input_data, w_data, static_cast<T>(0),
output_data);
if (bias) {
const T* bias_data = bias->data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int bs = 0; bs < in_dims[0]; bs++) {
blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
output_data + bs * w_dims[1]);
}
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker,
namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad);
REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
......@@ -22,6 +22,7 @@ def fully_connected_naive(input, weights, bias_data=None):
w_h, w_c = weights.shape
x_data = np.reshape(input, [in_n, in_c * in_h * in_w])
# this transpose should be implemented at C code
w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w)))
result = None
......@@ -43,15 +44,11 @@ class TestFCMKLDNNOp(OpTest):
def setUp(self):
self.op_type = "fc"
self.use_mkldnn = True
self.with_bias = True
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {
'use_mkldnn': self.use_mkldnn,
'with_bias': self.with_bias
}
self.attrs = {'use_mkldnn': self.use_mkldnn, }
self.outputs = {
'Out': fully_connected_naive(self.matrix.input, self.matrix.weights)
......@@ -85,13 +82,11 @@ class TestFCMKLDNNOp3(TestFCMKLDNNOp):
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.with_bias = False
self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.with_bias = False
self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def fc_refer(matrix, with_bias):
in_n, in_c, in_h, in_w = matrix.input.shape
w_i, w_o = matrix.weights.shape
x_data = np.reshape(matrix.input, [in_n, in_c * in_h * in_w])
w_data = np.reshape(matrix.weights, [w_i, w_o])
b_data = np.reshape(matrix.bias, [1, w_o])
result = None
if with_bias:
result = np.dot(x_data, w_data) + b_data
else:
result = np.dot(x_data, w_data)
return result
class MatrixGenerate:
def __init__(self, mb, ic, oc, h, w):
self.input = np.random.random((mb, ic, h, w)).astype("float32")
self.weights = np.random.random((ic * h * w, oc)).astype("float32")
self.bias = np.random.random((1, oc)).astype("float32")
class TestFCOp(OpTest):
def setUp(self):
self.op_type = "fc"
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
self.with_bias = True
if self.with_bias:
self.inputs = {
'Input': self.matrix.input,
'W': self.matrix.weights,
'Bias': self.matrix.bias
}
else:
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {'use_mkldnn': False}
self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)}
def test_check_output(self):
self.check_output()
class TestFCOpBiasBoth(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
for with_bias in {True, False}:
self.with_bias = with_bias
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpBiasBoth):
def init_op_type(self):
self.init_shapes(2, 8, 10, 1, 1)
class TestFCOp2(TestFCOpBiasBoth):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOp4(TestFCOpBiasBoth):
def init_op_type(self):
self.init_shapes(1, 32, 64, 3, 3)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册