Commit e133df60 authored by tensor-tang

enable native fc forward

Parent 038cbf79
paddle/fluid/operators/fc_mkldnn_op.cc
@@ -128,6 +128,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
                    "Input must be with 2 or 4 dimensions, i.e. NCHW");
+    // TODO(intel): the src weight is IO and the MKLDNN weight needs to be transposed!
     PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
                    "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
...
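A note on the TODO above: Paddle stores the fc weight as IO (one row per input channel, one column per output), while MKLDNN expects the OI layout, so the weight has to be transposed before it is handed to MKLDNN. A minimal sketch of that conversion in plain C++ (hypothetical helper for illustration, not the MKLDNN API):

#include <vector>

// Transpose an IO-layout weight (I rows, O cols, row-major) into OI layout.
std::vector<float> TransposeIOToOI(const std::vector<float>& w_io, int I, int O) {
  std::vector<float> w_oi(w_io.size());
  for (int i = 0; i < I; ++i)
    for (int o = 0; o < O; ++o)
      w_oi[o * I + i] = w_io[i * O + o];  // (i, o) in IO becomes (o, i) in OI
  return w_oi;
}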
paddle/fluid/operators/fc_op.cc
@@ -15,6 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
+
+DECLARE_int32(paddle_num_threads);

 namespace paddle {
 namespace operators {
@@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Out(Output) of Fully Connected should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("W"),
                  "W(Input) of Fully Connected should not be null.");
+  // NCHW
   auto in_dims = ctx->GetInputDim("Input");
+  // IO, I=C*H*W
   auto w_dims = ctx->GetInputDim("W");
   std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});

   if (ctx->HasInput("Bias")) {
     auto bias_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
-    PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0],
+    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
                       "The shape of Bias must be [1, dim].");
   }
   PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
                  "Fully Connected input should be 2-D or 4-D tensor.");
-  PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
-                 "Fully Connected input should be 2-D or 4-D tensor.");
-  PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0],
-                    framework::product(in_dims) / in_dims[0],
-                    "Fully Connected input and weight size do not match.");
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
+                    "Fully Connected weight should be a 2-D tensor.");
+  PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
+                    "Fully Connected input and weight size do not match.");
   ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
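For reference, the shape contract enforced above: the input is NCHW (or already 2-D) and is flattened to (N, C*H*W); the weight is (I, O) with I = C*H*W; Out comes out as (N, O). A standalone sketch of the same checks in plain C++ (hypothetical helper, dims are illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors FCOp::InferShape: flatten all input dims past the batch dim and
// require the result to equal the weight's input dim I.
std::vector<int64_t> InferFcOutShape(const std::vector<int64_t>& in_dims,
                                     const std::vector<int64_t>& w_dims) {
  assert(in_dims.size() == 2 || in_dims.size() == 4);
  assert(w_dims.size() == 2);
  int64_t flattened = 1;
  for (size_t i = 1; i < in_dims.size(); ++i) flattened *= in_dims[i];
  assert(flattened == w_dims[0]);  // I must equal C*H*W (or the 2-D inner dim)
  return {in_dims[0], w_dims[1]};  // (N, O)
}

// e.g. InferFcOutShape({8, 3, 4, 5}, {60, 10}) yields {8, 10}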
@@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   }

   if (ctx->HasInput("Bias")) {
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
+                   "Should have bias grad");
     auto bias_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias"));
     ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
   }
 }
@@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
 }

 void FCOpMaker::Make() {
-  AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
-  AddInput("W", "(Tensor), The second input tensor of fc op.");
-  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D")
+  AddInput("Input",
+           "(Tensor), The input tensor of fully connected operator with format "
+           "(NCHW). ");
+  AddInput("W", "(Tensor), The weight of fc op with shape (I, O).");
+  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O)")
       .AsDispensable();
   AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
   AddAttr<bool>("use_mkldnn",
@@ -120,25 +123,32 @@ template <typename T>
 class FCOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
-    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
-    auto b = ctx.Input<Tensor>("Bias");
+    auto bias = ctx.Input<Tensor>("Bias");
+    auto output = ctx.Output<Tensor>("Out");
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();
+
+    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
+    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
-    auto output = ctx.Output<Tensor>("Out");
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    auto in_dims = ctx->GetInputDim("Input");
-    auto w_dims = ctx->GetInputDim("W");
-    std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+
+    blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
+              static_cast<T>(1), input_data, w_data, static_cast<T>(0),
+              output_data);
     if (bias) {
       const T* bias_data = bias->data<T>();
+#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
+      for (int bs = 0; bs < in_dims[0]; bs++) {
+        blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
+                  output_data + bs * w_dims[1]);
+      }
     }
   }
 };
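The enabled native forward computes Out = Input * W with a single GEMM and then adds the bias row to each of the N output rows with AXPY, parallelized over rows when FLAGS_paddle_num_threads > 1. A plain-loop reference of the same math (illustrative only, not the Blas wrapper used above):

#include <vector>

// Reference of what the GEMM + AXPY sequence above computes:
// out[n][o] = sum_i in[n][i] * w[i][o] + bias[o]
void FcForwardRef(const std::vector<float>& in,    // N x I, row-major
                  const std::vector<float>& w,     // I x O, row-major (IO)
                  const std::vector<float>& bias,  // O entries, may be empty
                  std::vector<float>* out,         // N x O, row-major
                  int N, int I, int O) {
  out->assign(static_cast<size_t>(N) * O, 0.f);
  for (int n = 0; n < N; ++n) {
    for (int i = 0; i < I; ++i) {
      const float x = in[n * I + i];
      for (int o = 0; o < O; ++o) {
        (*out)[n * O + o] += x * w[i * O + o];  // the CblasNoTrans GEMM
      }
    }
    if (!bias.empty()) {
      for (int o = 0; o < O; ++o) (*out)[n * O + o] += bias[o];  // the AXPY
    }
  }
}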
@@ -150,5 +160,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
-REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel<float>,
-                       ops::FCMKLDNNOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);