“17994e38aa9b11b3df0e3dcf440606e849e5be4f”上不存在“paddle/legacy/gserver/tests/rnn_data_provider.py”
提交 0c39b97b 编写于 作者: M Michał Gallus 提交者: Tao Luo

[MKL-DNN] Add Fully Connected Op for inference only(#15226)

* fuse mul and elementwise add to fc

* Reimplement the FC forward operator

* Fix FC MKLDNN integration by transposing weights

* Add FC MKLDNN Pass

test=develop

* FC MKLDNN Pass: change memcpy to std::copy

* Fix MKLDNN FC handling of mismatch input and weights dims

* Lower tolerance for MKL-DNN in resnet50 test

test=develop

* Adjust FC to support MKLDNN Op placement

test=develop

* Adjust Placement Op to set use_mkldnn attribute for graph

test=develop

* MKLDNN FC: fix weights format so that gemm version is called

test=develop

* FC MKLDNN: Remove tolerance decrease from tester_helper

* FC MKL-DNN: Refactor the code, change input reorder to weight reorder

* MKL-DNN FC: Introduce operator caching

test=develop

* FC MKL-DNN: Fix the tensor type in ExpectedKernelType

test=develop

* FC MKL-DNN: fix style changes

test=develop

* FC MKL-DNN: fallback to native on non-supported dim sizes

test=develop

* FC MKLDNN: fix CMake paths

test=develop

* FC MKLDNN: Refine placement pass graph mkldnn attribute

test=develop

* Fix Transpiler error for fuse_conv_eltwise

test=develop

* Fix missing STL includes in files

test=develop

* FC MKL-DNN: Enable new output size computation

Also, refine pass to comply with newest interface.
test=develop

* FC MKL-DNN: enable only when fc_mkldnn_pass is enabled

* FC MKL-DNN: Allow Weights to use oi or io format

* FC MKL-DNN: Adjust UT to work with correct dims

test=develop

* Enable MKL DEBUG for resnet50 analyzer

test=develop

* FC MKL-DNN: Improve Hashing function

test=develop

* FC MKL-DNN: Fix shape for fc weights in transpiler

* FC MKL-DNN: Update input pointer in re-used fc primitive

* Add log for not handling fc fuse for unsupported dims

test=develop

* FC MKL-DNN: Move transpose from pass to Op Kernel

test=develop

* FC MKL-DNN: Disable transpose in unit test

test=develop

* FC MKL-DNN: Remove fc_mkldnn_pass from default list

* Correct Flag for fake data analyzer tests

test=develop

* FC MKL-DNN: Add comment about fc mkldnn pass disablement

test=develop

* FC MKL-DNN: Disable fc in int8 tests

test=develop
上级 21138eb1
......@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
......
......@@ -88,6 +88,7 @@ if(WITH_MKLDNN)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
......@@ -80,6 +81,7 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
......@@ -14,7 +14,10 @@
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -896,6 +899,33 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
->AsOutput()
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
return fc_out_var;
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
......
......@@ -517,6 +517,25 @@ struct FC : public PatternBase {
PATTERN_DECL_NODE(Out);
};
// MKL-DNN's FC with bias
// op: fc
// named node:
// fc
// w, bias, output
struct FCMKLDNN : public PatternBase {
FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_mkldnn") {}
PDNode* operator()(PDNode* x, bool with_bias);
// declare operator node's name
PATTERN_DECL_NODE(fc);
// declare variable node's name
PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
Init("fc_mkldnn_pass", graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
->AsInput()
->assert_is_op_input("fc", "Input");
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(x, true /*with bias*/);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Handle FC MKL-DNN pass";
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "do not perform fc fuse";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
OpDesc* desc = fc->Op();
auto in_size = fc->inputs[0]->Var()->GetShape().size();
if (in_size != 2 && in_size != 4) {
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
return;
}
desc->SetAttr("use_mkldnn", true);
PADDLE_ENFORCE(subgraph.count(x));
found_fc_count++;
};
gpd(graph, handler);
AddStatis(found_fc_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Transpose weights of FC to comply with MKL-DNN interface
*/
class FCMKLDNNPass : public FusePassBase {
public:
virtual ~FCMKLDNNPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
......@@ -24,6 +25,9 @@ void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies MKL-DNN placement strategy.";
const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
if (!graph->Has("use_mkldnn")) {
graph->Set<bool>("use_mkldnn", new bool(true));
}
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
......
......@@ -146,16 +146,19 @@ void CpuPassStrategy::EnableMKLDNN() {
if (!use_mkldnn_) {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>(
{"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass"})) {
for (auto &pass : std::vector<std::string>({
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass"
})) {
passes_.push_back(pass);
}
}
......
......@@ -33,8 +33,10 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2)
endfunction()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
if(mkl_debug)
set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
endif()
download_model(${install_dir} ${model_name})
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
......@@ -143,15 +145,15 @@ inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose
# googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz")
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false)
# resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true)
# mobilenet with depthwise_conv op
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false)
# int8 image classification tests
if(WITH_MKLDNN)
......
......@@ -152,6 +152,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
config.EnableMKLDNN();
config.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> outputs;
......
......@@ -200,8 +200,9 @@ void profile(bool use_mkldnn = false) {
cfg.EnableMKLDNN();
// Enable all the mkldnn supported ops except conv3d in dam
std::unordered_set<std::string> op_list = {"softmax", "elementwise_add",
"relu"};
"relu", "fc"};
cfg.SetMKLDNNOp(op_list);
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> outputs;
......
......@@ -100,6 +100,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -48,6 +48,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> outputs;
......@@ -79,6 +80,7 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -149,6 +149,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
}
if (use_mkldnn) {
cfg->EnableMKLDNN();
cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
}
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time
......
......@@ -189,6 +189,7 @@ void profile(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -85,6 +85,7 @@ void profile(bool use_mkldnn = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
// cfg.pass_builder()->TurnOnDebug();
std::vector<std::vector<PaddleTensor>> outputs;
......
......@@ -12,299 +12,266 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <mkldnn/include/mkldnn_types.h>
#include <memory>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using framework::DataLayout;
using framework::Tensor;
using framework::LoDTensor;
using framework::DDim;
using framework::ExecutionContext;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
using mkldnn::memory;
using mkldnn::inner_product_forward;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::prop_kind;
template <typename T>
class MKLDNNMD {
class FCPrimitiveFactory {
public:
explicit MKLDNNMD(const T* in, const T* w, bool bias)
: in(paddle::framework::vectorize2int(in->dims())),
w(paddle::framework::vectorize2int(w->dims())) {
with_bias_ = bias;
}
explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {}
inner_product_forward CreateFcPrimitive(const LoDTensor* input,
const Tensor* weights,
const Tensor* bias, LoDTensor* output,
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output);
if (fc_) {
UpdateDataPointers(ctx, output, input);
return *fc_;
}
auto src_desc = CreateMemDescriptor(input, input->format());
input_ = CreateMemory(src_desc, input);
mkldnn::memory::desc dst() const {
return platform::MKLDNNMemDesc({in[0], w[1]},
mkldnn::memory::data_type::f32,
mkldnn::memory::format::nc);
}
weights_ = TransposeWeights(weights);
if (src_desc.data.ndims == 4) {
weights_ = CreateFourDimWeightsMemory(input, weights);
}
auto dst_desc = CreateMemDescriptor(output, memory::format::any);
mkldnn::memory::desc src() const {
return is_spatial()
? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]},
mkldnn::memory::data_type::f32,
mkldnn::memory::format::nchw)
: platform::MKLDNNMemDesc({in[0], in[1]},
mkldnn::memory::data_type::f32,
mkldnn::memory::format::nc);
fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
return *fc_;
}
mkldnn::memory::desc weights() const {
return is_spatial()
? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]},
mkldnn::memory::data_type::f32,
mkldnn::memory::format::oihw)
: platform::MKLDNNMemDesc({w[1], in[1]},
mkldnn::memory::data_type::f32,
mkldnn::memory::format::oi);
private:
void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out,
const Tensor* in) {
input_->set_data_handle(const_cast<T*>(in->data<T>()));
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
if (out->format() == memory::format::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format;
out->set_format((memory::format)output_format);
}
}
mkldnn::memory::desc bias() const {
return with_bias_
? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32,
mkldnn::memory::format::format_undef)
: platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
mkldnn::memory::format::format_undef);
memory::format MatchWeightFormat(memory::format fmt) {
using format = memory::format;
switch (fmt) {
case format::nChw16c:
return format::oIhw16i;
case format::nChw8c:
return format::oIhw8i;
case format::nchw:
return format::oihw;
default:
return format::format_undef;
}
}
private:
bool is_spatial() const { return in.size() > 1 && w.size() > 1; }
mkldnn::memory Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc, const void* src_data) {
auto src_mem = memory({src_desc, engine_}, const_cast<void*>(src_data));
auto dst_mem = memory({dst_desc, engine_});
std::vector<int> in;
std::vector<int> w;
bool with_bias_;
bool is_spatial_;
};
auto reorder = mkldnn::reorder(src_mem, dst_mem);
stream(stream::kind::eager).submit({reorder}).wait();
class MKLDNNMemory {
public:
MKLDNNMemory(MKLDNNMD<Tensor>* t, const mkldnn::engine& e)
: md_(t), engine_(e) {}
virtual ~MKLDNNMemory() = default;
template <typename Output>
mkldnn::memory dst(const Output* out) {
return mkldnn::memory({md_->dst(), engine_},
static_cast<void*>(const_cast<float*>(out)));
return dst_mem;
}
template <typename Output>
mkldnn::memory dst(Output* out) {
return mkldnn::memory({md_->dst(), engine_}, out);
static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims,
memory::format format) {
return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
format);
}
template <typename Input>
mkldnn::memory src(const Input* in) {
return mkldnn::memory({md_->src(), engine_},
static_cast<void*>(const_cast<float*>(in)));
static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
memory::format format) {
auto dims = framework::vectorize2int(tensor->dims());
return CreateMemDescriptor(dims, format);
}
template <typename Weight>
mkldnn::memory weights(const Weight* w) {
return mkldnn::memory({md_->weights(), engine_},
static_cast<void*>(const_cast<float*>(w)));
mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
const Tensor* tensor) {
return CreateMemory(desc, tensor->data<T>());
}
mkldnn::memory bias() {
return mkldnn::memory(mkldnn::memory::primitive_desc(md_->bias(), engine_));
mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
const void* data) {
return memory({desc, engine_}, const_cast<void*>(data));
}
private:
MKLDNNMD<Tensor>* md_;
const mkldnn::engine& engine_;
};
template <typename T>
class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
mkldnn::memory TransposeWeights(const Tensor* weights) {
auto dims = framework::vectorize2int(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor(dims, memory::format::io);
auto dst_desc = CreateMemDescriptor(dims, memory::format::oi);
return Reorder(src_desc, dst_desc, weights->data<T>());
}
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
inner_product_forward CreateFcPrimitive(const memory& src_memory,
const memory& weights_memory,
const memory::desc& dst_desc,
const Tensor* bias, Tensor* output,
const ExecutionContext& ctx) {
const auto weights_desc = weights_memory.get_primitive_desc().desc();
const auto src_desc = src_memory.get_primitive_desc().desc();
if (bias) {
auto bias_desc = CreateMemDescriptor(bias, bias->format());
bias_ = CreateMemory(bias_desc, bias);
auto fc_prim_desc =
CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc);
output_ = CreateDstMemory(fc_prim_desc, ctx, output);
return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
*bias_, *output_);
} else {
auto fc_prim_desc = CreateFcPrimDesc(src_desc, weights_desc, dst_desc);
output_ = CreateDstMemory(fc_prim_desc, ctx, output);
return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
*output_);
}
}
auto input = ctx.Input<framework::LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_desc) {
auto fc_desc =
inner_product_forward::desc(prop_kind::forward_scoring, input_desc,
weights_desc, bias_desc, dst_desc);
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW");
// TODO(intel friends): the native weight format is io,
// but the mkldnn weight format is oihw, which may need be transposed.
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
return inner_product_forward::primitive_desc(fc_desc, engine_);
}
bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias);
mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& dst_desc) {
auto fc_desc = inner_product_forward::desc(prop_kind::forward, input_desc,
weights_desc, dst_desc);
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
with_bias, mkldnn_engine);
return inner_product_forward::primitive_desc(fc_desc, engine_);
}
const std::string key = ctx.op().Output("Out");
const std::string key_fc_pd = key + "@fc_pd";
mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
const Tensor* weights) {
auto input_dims = framework::vectorize2int(input->dims());
auto weight_dims = framework::vectorize2int(weights->dims());
auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
dev_ctx.SetBlob(key_fc_pd, pd);
auto dst_format = MatchWeightFormat(input->format());
auto src_desc = CreateMemDescriptor(dims, memory::format::oihw);
auto dst_desc = CreateMemDescriptor(dims, dst_format);
MKLDNNMemory mem(&md, mkldnn_engine);
return Reorder(src_desc, dst_desc, weights_->get_data_handle());
}
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
mkldnn::memory CreateDstMemory(
const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx, Tensor* output) {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
auto buffer_size = dst_prim_desc.get_size();
T* output_data = output->mutable_data<T>(
ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, buffer_size);
output->set_format((memory::format)dst_prim_desc.desc().data.format);
return memory(dst_prim_desc, to_void_cast<T>(output_data));
}
auto output = ctx.Output<framework::LoDTensor>("Out");
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, LoDTensor* output) {
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
}
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto dst_memory = mem.dst(output_data);
auto src_memory = mem.src(input_data);
auto weights_memory = mem.weights(w_data);
// TODO(intel friends): bias memory should also be obtain from bias->data()
auto bias_memory = mem.bias();
private:
const mkldnn::engine& engine_;
boost::optional<memory> bias_;
boost::optional<memory> input_;
boost::optional<memory> output_;
boost::optional<memory> weights_;
boost::optional<inner_product_forward> fc_;
};
auto forward = with_bias ? mkldnn::inner_product_forward(
*pd, src_memory, weights_memory, bias_memory,
dst_memory)
: mkldnn::inner_product_forward(
*pd, src_memory, weights_memory, dst_memory);
static std::string GetHash(const Tensor* input, const Tensor* weights,
const std::string& suffix) {
auto dim2str = [](const DDim& operand_dims) {
std::string str = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
str += std::to_string(operand_dims[i]) + "-";
}
return str;
};
return std::to_string((unsigned)input->format()) + dim2str(weights->dims()) +
suffix;
}
std::vector<mkldnn::primitive> pipeline = {forward};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
template <typename T>
std::shared_ptr<FCPrimitiveFactory<T>> GetPrimitiveFactory(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
const Tensor* input, const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = GetHash(input, weights, ctx.op().Output("Out"));
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T>>(dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator = std::make_shared<FCPrimitiveFactory<T>>(mkldnn_engine);
dev_ctx.SetBlob(key, prim_creator);
}
private:
std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>
FcFwdPrimitiveDesc(const mkldnn::memory::desc& src,
const mkldnn::memory::desc& weights,
const mkldnn::memory::desc& dst,
const mkldnn::memory::desc& bias, const bool with_bias,
const mkldnn::engine& engine) const {
auto desc = with_bias
? mkldnn::inner_product_forward::desc(
mkldnn::prop_kind::forward, src, weights, bias, dst)
: mkldnn::inner_product_forward::desc(
mkldnn::prop_kind::forward, src, weights, dst);
auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine);
return std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>(pd);
}
};
return prim_creator;
}
template <typename T>
class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
class FCMKLDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
T* input_grad_data = nullptr;
T* w_grad_data = nullptr;
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
const Tensor* input = ctx.Input<Tensor>("Input");
const T* input_data = input->data<T>();
const Tensor* w = ctx.Input<Tensor>("W");
const T* w_data = w->data<T>();
if (input_grad) {
input_grad->Resize(input->dims());
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (w_grad) {
w_grad->Resize(w->dims());
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
}
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const T* out_grad_data = out_grad->data<T>();
auto input = ctx.Input<LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias);
MKLDNNMemory mem(&md, mkldnn_engine);
auto dst_memory = mem.dst(out_grad_data);
auto src_memory = mem.src(input_data);
auto weights_memory = mem.weights(w_data);
auto bias_memory = mem.bias();
auto output = ctx.Output<LoDTensor>("Out");
const std::string key = ctx.op().Input("Out");
const std::string key_fc_pd = key + "@fc_pd";
auto prim_creator =
GetPrimitiveFactory<T>(dev_ctx, ctx, input, w, mkldnn_engine);
auto fc = prim_creator->CreateFcPrimitive(input, w, bias, output, ctx);
stream(stream::kind::eager).submit({fc}).wait();
auto pd =
std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
dev_ctx.GetBlob(key_fc_pd));
PADDLE_ENFORCE(pd != nullptr, "Fail to find key_fc_pd in device context");
if (w_grad) {
auto weights_grad_memory = mem.weights(w_grad_data);
mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd =
FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
with_bias, *pd, mkldnn_engine);
auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
bwd_weight_pd, src_memory, dst_memory, weights_grad_memory,
bias_memory);
std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
if (input_grad) {
auto src_grad_memory = mem.src(input_grad_data);
mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd =
FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd,
mkldnn_engine);
auto bwd_data_prim = mkldnn::inner_product_backward_data(
bwd_data_pd, dst_memory, weights_memory, src_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
}
private:
mkldnn::inner_product_backward_weights::primitive_desc
FcBwdWeightsPrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias,
const bool with_bias,
const mkldnn::inner_product_forward::primitive_desc& pd,
const mkldnn::engine& engine) const {
auto bwd_weight_desc = with_bias
? mkldnn::inner_product_backward_weights::desc(
src, diff_weights, bias, diff_dst)
: mkldnn::inner_product_backward_weights::desc(
src, diff_weights, diff_dst);
return mkldnn::inner_product_backward_weights::primitive_desc(
bwd_weight_desc, engine, pd);
}
mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc(
const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
const mkldnn::memory::desc& diff_dst,
const mkldnn::inner_product_forward::primitive_desc& pd,
const mkldnn::engine& engine) const {
auto bwd_data_desc =
mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst);
return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc,
engine, pd);
output->set_layout(DataLayout::kMKLDNN);
}
};
} // namespace operators
......@@ -312,6 +279,3 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::FCMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::FCMKLDNNGradOpKernel<float>);
......@@ -20,34 +20,30 @@ from paddle.fluid.tests.unittests.op_test import OpTest
def fully_connected_naive(input, weights, bias_data=None):
in_n, in_c, in_h, in_w = input.shape
w_h, w_c = weights.shape
x_data = np.reshape(input, [in_n, in_c * in_h * in_w])
# this transpose should be implemented at C code
w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w)))
result = None
if not bias_data:
result = np.dot(x_data, w_data)
result = np.dot(input, weights)
else:
result = np.dot(x_data, w_data) + bias_data
result = np.dot(input, weights) + bias_data
return result
class MatrixGenerate:
def __init__(self, mb, ic, oc, h, w):
self.input = np.random.random((mb, ic, h, w)).astype("float32")
self.input = np.random.random((mb, ic * h * w)).astype("float32")
self.weights = np.random.random((ic * h * w, oc)).astype("float32")
class TestFCMKLDNNOp(OpTest):
def create_data(self):
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
def setUp(self):
self.op_type = "fc"
self.use_mkldnn = True
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
self.create_data()
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {'use_mkldnn': self.use_mkldnn, }
......@@ -60,37 +56,16 @@ class TestFCMKLDNNOp(OpTest):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9)
pass
def test_check_grad_no_weight(self):
self.check_grad(
['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W'))
pass
class TestFCMKLDNNOp1(TestFCMKLDNNOp):
def init_op_type(self):
def create_data(self):
self.matrix = MatrixGenerate(2, 15, 48, 2, 2)
class TestFCMKLDNNOp2(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 40, 1, 1)
class TestFCMKLDNNOp3(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 2, 4, 1, 1)
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 48, 2, 2)
class TestFCMKLDNNOp4(TestFCMKLDNNOp):
def init_op_type(self):
self.matrix = MatrixGenerate(2, 32, 1000, 6, 6)
if __name__ == "__main__":
unittest.main()
......@@ -76,6 +76,7 @@ class InferenceTranspiler(object):
self._fuse_conv_relu_mkldnn(
program) # ResNet residual block merging
self._fuse_bn_relu_mkldnn(program)
self._fuse_mul_add_mkldnn(program)
self._is_test_pass(program)
......@@ -387,6 +388,62 @@ class InferenceTranspiler(object):
# And a better solution will be considered later.
program = program.clone()
def _fuse_mul_add_mkldnn(self, program):
'''
Transpile the program by fusing Mul+Add layers to FC layer with the MKL-DNN inner product.
The MUL following a Elementwise_add layer can be replaced by the MKL-DNN FC.
The Elementwise add's bias input 'Y' has to be added into the
MKL-DNN-based FC input 'Bias'.
The operator transformation is:
- before:
- MUL->elementwise_add -> any_other_op
- after:
- FC -> any_other_op
The transpile stages are:
1. insert a new MKL-DNN-based FC operator with `Bias` input
taken from the Elementwise add's input 'Y' (bias),
2. fuse the parameters of MUL and Elemenwise add,
3. remove the MUL, elementwise_add operators,
4. make the input of the deleted Elementwise add operator to be the input of the
new FC operator,
5. remove unused variables,
Args:
program (Program): program to transpile
'''
self.block = program.block(0)
self.input_map = {} # store the input names should be adjusted
i = 0
while i < len(self.block.ops):
# find a elementwise add op
if self.block.ops[i].type == 'elementwise_add':
add_op = self.block.ops[i]
add_idx = i
mul_idx = -1
# find the preceding mul op
for j in reversed(range(add_idx)):
if self.block.ops[j].type == 'mul':
mul_out_name = self.block.ops[j].output_arg_names[0]
if self.block.ops[j].output_arg_names[
0] in add_op.input_arg_names:
mul_op = self.block.ops[j]
mul_idx = j
break
if mul_idx < 0:
i += 1
continue
# create and insert a new fc op
fc_op_new = self._insert_fc_op(add_idx + 1, mul_op, add_op)
# remove the old operators
self.block._remove_op(add_idx)
self.block._remove_op(mul_idx)
# restart scanning for elementwise add from the deleted mul's index
i = mul_idx
i += 1
self._adjust_input()
self._remove_unused_var()
program = program.clone()
# ====================== private transpiler functions =====================
def _insert_bias_op(self, index, current_op, bn_op):
'''
......@@ -509,6 +566,42 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _insert_fc_op(self, index, mul_op, add_op):
'''
Construct a new FC operator by copying the old Mul and adding the
'Y' input taken from the Elementwise add's input 'Y'.
:param index: insert location of FC
:type index: Int
:param mul_op: MUL operator to be copied
:type mul_op: Operator
:param add_op: Elementwise add operator taken bias from
:type add_op: Operator
:return: fc_op_new
:type: Operator
'''
def get_op_outputs(op, names):
result = {}
for name in names:
result[name] = self.block.var(op.output(name)[0])
return result
fc_inputs = {}
fc_inputs['Input'] = self.block.var(mul_op.input('X')[0])
fc_inputs['W'] = self.block.var(mul_op.input('Y')[0])
fc_inputs['Bias'] = self.block.var(add_op.input('Y')[0])
fc_outputs = get_op_outputs(add_op, ['Out'])
fc_attrs = {}
fc_attrs['use_mkldnn'] = True
fc_op_new = self.block._insert_op(
index,
type='fc',
inputs=fc_inputs,
outputs=fc_outputs,
attrs=fc_attrs)
return fc_op_new
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册