diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index b18f426883402c9f2ec17bb58ad985b41302709a..68eca6e328da9510552f77760aea915c24292a49 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -81,8 +81,7 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
 
 Executor::~Executor() {
 #ifdef PADDLE_WITH_MKLDNN
-  // Clear mkl-dnn cache, unless explicitly
-  // (as set in constructor) marked not to do so
+  // Clear mkl-dnn cache,
   // this is needed to have mkl-dnn unit tests working
   if (platform::is_cpu_place(place_)) {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 5e6da1b33497d9dda1279616ee16b2c73a13c845..2d7af7df59bd5eb8d7c37125e36a4994527f8af3 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -146,6 +146,11 @@ if (WITH_MKLDNN)
     cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
     cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
     cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
+    set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
+    if (WITH_GPU)
+        set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
+    endif()
+    cc_test(test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS ${TEST_CONV_BN_PASS_DEPS})
     cc_test(test_scale_matmul_fuse_pass SRCS mkldnn/scale_matmul_fuse_pass_tester.cc DEPS scale_matmul_fuse_pass)
     cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
     cc_test(test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc DEPS mkldnn_inplace_pass)
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 9474ca23a6431158bb945bbb1be96c80364c2d2d..7313ef2cc35dd7c386c11252def211db34d665ad 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
+#include <algorithm>
 #include <functional>
 #include <string>
 #include <vector>
@@ -278,9 +279,53 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
     // update weights and biases
     float epsilon =
         BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
-    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
-                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
-                               epsilon, conv_type());
+
+    // if bias is an input to other ops as well then we cannot overwrite it
+    // so we create separate elementwise Y in nodes
+    if (eltwise_y_in->outputs.size() > 1) {
+      // Make a copy of eltwise Y input tensor
+      // Create eltwise_y (conv bias) variable
+      VarDesc eltwise_y_in_desc(patterns::PDNodeName(
+          name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
+      eltwise_y_in_desc.SetShape(
+          framework::vectorize(eltwise_y_in_tensor->dims()));
+      eltwise_y_in_desc.SetDataType(eltwise_y_in_tensor->type());
+      eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
+      eltwise_y_in_desc.SetPersistable(true);
+      auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
+      auto* eltwise_y_in_tensor_ex =
+          scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
+
+      // Initialize eltwise_y
+      TensorCopy(*eltwise_y_in_tensor, platform::CPUPlace(),
+                 eltwise_y_in_tensor_ex);
+
+      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
+                                 *bn_mean, *bn_variance, eltwise_y_in_tensor_ex,
+                                 epsilon, conv_type());
+      // Set new var
+      eltwise->Op()->RenameInput(eltwise_y_in->Name(),
+                                 eltwise_y_in_node->Name());
+      // Link new bias node to eltwise
+      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
+      // unlink original bias from eltwise_op
+      eltwise_y_in->outputs.erase(
+          std::remove_if(
+              eltwise_y_in->outputs.begin(), eltwise_y_in->outputs.end(),
+              [&](Node* n) { return n->id() == eltwise->id(); }),
+          eltwise_y_in->outputs.end());
+      // keep the edge removal symmetric: the eltwise op must also stop
+      // listing the original bias among its inputs
+      eltwise->inputs.erase(
+          std::remove_if(
+              eltwise->inputs.begin(), eltwise->inputs.end(),
+              [&](Node* n) { return n->id() == eltwise_y_in->id(); }),
+          eltwise->inputs.end());
+    } else {
+      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
+                                 *bn_mean, *bn_variance, eltwise_y_in_tensor,
+                                 epsilon, conv_type());
+    }
 
     // Update the elementwise_add node
     eltwise->Op()->SetAttr("axis", 1);
diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5f819ddbfaf8b88732b35119014c34644a1c402b
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
@@ -0,0 +1,303 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include <boost/logic/tribool.hpp>
+#include <random>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/ir/graph_traits.h"
+#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/place.h"
+
+USE_OP(batch_norm);
+USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
+USE_OP(conv2d_transpose);
+USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
+USE_OP(elementwise_add);
+USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
+USE_OP(gelu);
+USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class MKLDNNConvBatchNormPassTest {
+ private:
+  void SetOp(ProgramDesc* prog, const std::string& type,
+             const std::string& name, const std::vector<std::string>& inputs,
+             const std::vector<std::string>& outputs,
+             boost::tribool use_mkldnn) {
+    auto* op = prog->MutableBlock(0)->AppendOp();
+
+    op->SetType(type);
+
+    if (!boost::indeterminate(use_mkldnn))
+      op->SetAttr("use_mkldnn", use_mkldnn);
+
+    if (type == "conv2d_transpose") {
+      op->SetAttr("name", name);
+      op->SetInput("Input", {inputs[0]});
+      op->SetInput("Filter", {inputs[1]});
+      op->SetOutput("Output", {outputs[0]});
+      op->SetAttr("is_test", true);
+      op->SetAttr("strides", std::vector<int>(2, 2));
+    } else if (std::unordered_set<std::string>{"gelu", "leaky_relu", "relu",
+                                               "tanh"}
+                   .count(type)) {
+      op->SetInput("X", inputs);
+      op->SetOutput("Out", {outputs[0]});
+    } else if (type == "elementwise_add") {
+      op->SetAttr("axis", static_cast<int>(1));
+      op->SetInput("X", {inputs[0]});
+      op->SetInput("Y", {inputs[1]});
+      op->SetOutput("Out", {outputs[0]});
+    } else if (type == "batch_norm") {
+      op->SetAttr("is_test", true);
+      op->SetAttr("epsilon", static_cast<float>(1e-5));
+      op->SetInput("X", {inputs[0]});
+      op->SetInput("Scale", {inputs[1]});
+      op->SetInput("Bias", {inputs[2]});
+      op->SetInput("Mean", {inputs[3]});
+      op->SetInput("Variance", {inputs[4]});
+      op->SetOutput("Y", {outputs[0]});
+      op->SetOutput("MeanOut", {outputs[1]});
+      op->SetOutput("VarianceOut", {outputs[2]});
+      op->SetOutput("SavedMean", {outputs[3]});
+      op->SetOutput("SavedVariance", {outputs[4]});
+    } else {
+      FAIL() << "Unexpected operator type.";
+    }
+  }
+
+  ProgramDesc BuildProgramDesc(bool is_elementwise_add) {
+    ProgramDesc prog;
+
+    // params
+    for (auto& v : std::vector<std::string>(
+             {"weights", "weights2", "bias_bn", "scale", "mean", "variance",
+              "saved_mean", "saved_variance", "bias_bn2", "scale2", "mean2",
+              "variance2", "saved_mean2", "saved_variance2"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::LOD_TENSOR);
+      var->SetPersistable(true);
+    }
+
+    // inputs and non-persistent holders
+    for (auto& v : std::vector<std::string>(
+             {"a", "b", "e", "f", "g", "h", "i", "j", "k", "l", "m"})) {
+      auto* var = prog.MutableBlock(0)->Var(v);
+      var->SetType(proto::VarType::LOD_TENSOR);
+    }
+
+    SetOp(&prog, "conv2d_transpose", "conv1",
+          std::vector<std::string>({"a", "weights"}),
+          std::vector<std::string>({"f"}), true);
+    if (is_elementwise_add) {
+      SetOp(&prog, "conv2d_transpose", "conv2",
+            std::vector<std::string>({"b", "weights2"}),
+            std::vector<std::string>({"e"}), true);
+      SetOp(&prog, "elementwise_add", "elementwise_add1",
+            std::vector<std::string>({"f", "g"}),
+            std::vector<std::string>({"h"}), true);
+      SetOp(&prog, "elementwise_add", "elementwise_add2",
+            std::vector<std::string>({"e", "g"}),
+            std::vector<std::string>({"j"}), true);
+      SetOp(&prog, "batch_norm", "batch_norm1",
+            std::vector<std::string>(
+                {"h", "scale", "bias_bn", "mean", "variance"}),
+            std::vector<std::string>(
+                {"i", "mean", "variance", "saved_mean", "saved_variance"}),
+            true);
+      SetOp(&prog, "batch_norm", "batch_norm2",
+            std::vector<std::string>(
+                {"j", "scale2", "bias_bn2", "mean2", "variance2"}),
+            std::vector<std::string>(
+                {"k", "mean2", "variance2", "saved_mean2", "saved_variance2"}),
+            true);
+      SetOp(&prog, "elementwise_add", "elementwise_add3",
+            std::vector<std::string>({"i", "k"}),
+            std::vector<std::string>({"l"}), true);
+    } else {
+      SetOp(&prog, "batch_norm", "batch_norm1",
+            std::vector<std::string>(
+                {"f", "scale", "bias_bn", "mean", "variance"}),
+            std::vector<std::string>(
+                {"l", "mean", "variance", "saved_mean", "saved_variance"}),
+            true);
+    }
+    SetOp(&prog, "gelu", "gelu1", std::vector<std::string>({"l"}),
+          std::vector<std::string>({"m"}), true);
+
+    return prog;
+  }
+
+  void FillTensorWithRandomData(Tensor* tnsr, float lowb, float upb,
+                                platform::CPUPlace place) {
+    float* ptr = tnsr->mutable_data<float>(place);
+    // Initialize input data
+    std::uniform_real_distribution<float> dist(static_cast<float>(lowb),
+                                               static_cast<float>(upb));
+    std::mt19937 engine;
+    for (int64_t i = 0; i < tnsr->numel(); ++i) {
+      ptr[i] = dist(engine);
+    }
+  }
+
+  void CompareTensors(Tensor* tensor1, Tensor* tensor2) {
+    ASSERT_EQ(tensor1->numel(), tensor2->numel());  // sizes must match
+    for (int64_t i = 0; i < tensor1->numel(); ++i) {
+      EXPECT_NEAR(tensor1->data<float>()[i], tensor2->data<float>()[i], 1e-3);
+    }
+  }
+
+ public:
+  void MainTest(bool is_elementwise_add) {
+    auto base_prog = BuildProgramDesc(is_elementwise_add);
+
+    std::unique_ptr<ir::Graph> graph(new ir::Graph(base_prog));
+    Scope scope;
+    auto place = paddle::platform::CPUPlace();
+    NaiveExecutor exe{place};
+
+    auto pass = PassRegistry::Instance().Get(
+        is_elementwise_add ? "conv_transpose_eltwiseadd_bn_fuse_pass"
+                           : "conv_transpose_bn_fuse_pass");
+    graph->SetNotOwned(kParamScopeAttr, &scope);
+
+    auto& prog = graph->OriginProgram();
+
+    exe.CreateVariables(prog, 0, true, &scope);
+    exe.CreateVariables(prog, 0, false, &scope);
+
+    exe.Prepare(&scope, prog, 0, false);
+
+    VLOG(3) << GenScopeTreeDebugInfo(&scope);
+
+    auto* a_tensor = exe.FindTensor("a");
+    auto* b_tensor = exe.FindTensor("b");
+    auto* weights_tensor = exe.FindTensor("weights");
+    auto* weights2_tensor = exe.FindTensor("weights2");
+    auto* g_tensor = exe.FindTensor("g");
+
+    // Batch Norm
+    auto* bias_bn_tensor = exe.FindTensor("bias_bn");  // shift
+    auto* scale_tensor = exe.FindTensor("scale");
+    auto* mean_tensor = exe.FindTensor("mean");
+    auto* variance_tensor = exe.FindTensor("variance");
+    auto* bias_bn2_tensor = exe.FindTensor("bias_bn2");  // shift
+    auto* scale2_tensor = exe.FindTensor("scale2");
+    auto* mean2_tensor = exe.FindTensor("mean2");
+    auto* variance2_tensor = exe.FindTensor("variance2");
+
+    int ic, oc, iw, ih, n, fw, fh;
+
+    n = 1;
+    fw = fh = 2;
+    oc = ic = 24;
+    iw = ih = 160;
+
+    // mb1_ic24oc24_ih8oh16kh2sh2dh0ph0_iw80ow160kw2sw2dw0pw0 deconv
+    a_tensor->Resize({n, ic, ih, iw});
+    weights_tensor->Resize({oc, ic, fh, fw});
+    g_tensor->Resize({oc});
+
+    bias_bn_tensor->Resize({oc});
+    scale_tensor->Resize({oc});
+    mean_tensor->Resize({oc});
+    variance_tensor->Resize({oc});
+    if (is_elementwise_add) {
+      b_tensor->Resize({n, ic, ih, iw});
+      weights2_tensor->Resize({oc, ic, fh, fw});
+      bias_bn2_tensor->Resize({oc});
+      scale2_tensor->Resize({oc});
+      mean2_tensor->Resize({oc});
+      variance2_tensor->Resize({oc});
+    }
+
+    // Input and conv transpose
+    FillTensorWithRandomData(a_tensor, 1.0f, 2.0f, place);
+    FillTensorWithRandomData(g_tensor, 1.0f, 2.0f, place);
+    FillTensorWithRandomData(weights_tensor, 1.0f, 2.0f, place);
+    if (is_elementwise_add) {
+      FillTensorWithRandomData(b_tensor, 1.0f, 2.0f, place);
+      FillTensorWithRandomData(weights2_tensor, 1.0f, 2.0f, place);
+    }
+
+    // First Batch_Norm
+    FillTensorWithRandomData(bias_bn_tensor, 1.0f, 2.0f, place);
+    FillTensorWithRandomData(scale_tensor, 1.0f, 2.0f, place);
+    FillTensorWithRandomData(mean_tensor, 1.0f, 2.0f, place);
+    FillTensorWithRandomData(variance_tensor, 1.0f, 2.0f, place);
+
+    // Second Batch Norm (exists only when elementwise_add is present)
+    if (is_elementwise_add) {
+      FillTensorWithRandomData(bias_bn2_tensor, 1.0f, 2.0f, place);
+      FillTensorWithRandomData(scale2_tensor, 1.0f, 2.0f, place);
+      FillTensorWithRandomData(mean2_tensor, 1.0f, 2.0f, place);
+      FillTensorWithRandomData(variance2_tensor, 1.0f, 2.0f, place);
+    }
+
+    exe.Run();
+
+    // Get result without IR passes applied
+    // Need to copy result over as the same scope is used in both executors
+    // so first result will be overwritten by second
+    auto* m_tensor = exe.FindTensor("m");
+    Tensor no_ir_result;
+    TensorCopy(*m_tensor, place, &no_ir_result);
+
+    graph.reset(pass->Apply(graph.release()));
+
+    // Get Program from graph
+    ProgramDesc optimized_prog;
+    auto graph2program_pass =
+        paddle::framework::ir::PassRegistry::Instance().Get(
+            "graph_to_program_pass");
+    graph2program_pass->SetNotOwned<paddle::framework::ProgramDesc>(
+        "program", &optimized_prog);
+    graph.reset(graph2program_pass->Apply(graph.release()));
+
+    exe.Prepare(&scope, optimized_prog, 0, false);
+    exe.Run();
+
+    auto* ir_result = exe.FindTensor("m");
+
+    // Two graphs. Execute both and compare results
+    CompareTensors(&no_ir_result, ir_result);
+
+    VLOG(3) << DebugString(graph);
+  }
+};
+
+TEST(MKLDNNConvBatchNormPassTest, conv_batch_norm) {
+  MKLDNNConvBatchNormPassTest().MainTest(false);
+}
+
+TEST(MKLDNNConvBatchNormPassTest, conv_elementwise_add_batch_norm) {
+  MKLDNNConvBatchNormPassTest().MainTest(true);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(conv_transpose_bn_fuse_pass);
+USE_PASS(conv_transpose_eltwiseadd_bn_fuse_pass);
+USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index 3bf7d5293475a49f9c751e2750218d7917fbc5f4..78e8b1612648404743e6ba6725777e55d688e662 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <utility>
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
@@ -49,6 +50,14 @@ Graph* Pass::Apply(Graph* graph) const {
     graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
   }
   graph->Get<PassRecorder>(kPassRecorder).insert(Type());
+#ifdef PADDLE_WITH_MKLDNN
+  // Clear mkl-dnn cache:
+  // passes can change params and tensors, so cached data needs to be discarded
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(
+      pool.Get(paddle::platform::CPUPlace()));
+  dev_ctx->ResetBlobMap();
+#endif
   return graph;
 }
 
diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index 9c1638f407d388521c216cca84cca0bcb60acc8b..a5de53e9d07d562c32885b1495981757f45cb5f9 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -118,5 +118,20 @@ void NaiveExecutor::CleanFeedFetchOps() {
   ops_.swap(ops);
 }
 
+NaiveExecutor::~NaiveExecutor() {
+#ifdef PADDLE_WITH_MKLDNN
+  // Clear mkl-dnn cache,
+  // this is needed to have mkl-dnn unit tests working
+  if (platform::is_cpu_place(place_)) {
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto *dev_ctx =
+        static_cast<platform::MKLDNNDeviceContext *>(pool.Get(place_));
+    dev_ctx->ResetBlobMap();
+    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+        paddle::framework::DataLayout::kNCHW);
+  }
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h
index 5e673f68574c4ddaa4c9260367d09e9f62f6b751..81402a650a3e334e273c18b279c241282ac5bf1f 100644
--- a/paddle/fluid/framework/naive_executor.h
+++ b/paddle/fluid/framework/naive_executor.h
@@ -32,6 +32,8 @@ class NaiveExecutor {
  public:
   explicit NaiveExecutor(const platform::Place& place) : place_(place) {}
 
+  ~NaiveExecutor();
+
   // Create child scope.
   // Create variables.
   // @with_feed_fetch_ops: whether to work with the feed and fetch operators.