diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index ba6c363b86287b4ccdaaf1faccfa12031d338a57..322981c5827a26c1a35c9574a7c0d17492cb4f8d 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -3,10 +3,12 @@
 cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
 cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
 cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
 cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
+add_subdirectory(fusion)
 cc_library(mir_passes
   SRCS fc_fuse_pass.cc
        conv_elementwise_add_relu_fuse_pass.cc
+       conv_bn_fuse_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_transform_pass.cc
@@ -55,6 +57,7 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern_matcher_lite)
 
 lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
+
 # TODO(wz) replace framework/proto to lite proto.
 if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
     # it depends on the fluid/framework/proto, that is too heavy for mobile execution.
diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..562ec7f45073a13f37c7f44ebcae0fb13fbb8b42
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ConvBNFuser fuser("conv2d");
+  fuser(graph.get());
+
+  fusion::ConvBNFuser fuser2("depthwise_conv2d");
+  fuser2(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
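
(Annotation, not part of the patch.) The pass rests on the standard conv+BN folding identity. At inference time, batch norm computes per output channel $i$:

```latex
y_i = \frac{\gamma_i\,(\mathrm{conv}(x, W)_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i
\quad\Longrightarrow\quad
y_i = \mathrm{conv}(x,\ \alpha_i W_i) + (\beta_i - \alpha_i \mu_i),
\qquad \alpha_i = \frac{\gamma_i}{\sqrt{\sigma_i^2 + \epsilon}}
```

So the BN can be removed by scaling each output-channel slice of the conv filter by $\alpha_i$ and emitting the leftover per-channel constant $\beta_i - \alpha_i \mu_i$ as a bias; that is exactly what `ConvBNFuser::ComputeFusedWeight` does further down in this patch, with the bias applied through a new `elementwise_add` op.
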
diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5164c906525a55f04d83a7cb22f1a75b3a20c5d
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvBNFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
index 5ada0a2c60dabf36c5d1081dcb5023baac3b173a..2cde3d25a6910a87083dd16d96d1c19f00d24ddf 100644
--- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
@@ -40,8 +40,8 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
   auto* conv2d_1 = main_block->AppendOp();
   auto* conv2d_2 = main_block->AppendOp();
   auto* add_1 = main_block->AppendOp();
-  auto* add_2 = main_block->AppendOp();
   auto* relu_1 = main_block->AppendOp();
+  auto* add_2 = main_block->AppendOp();
   auto* relu_2 = main_block->AppendOp();
 
   main_block->Var("input_1");
@@ -123,8 +123,8 @@ TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
   auto scope = std::make_shared<lite::Scope>();
   auto graph = BuildGraph(&program_desc, scope, places);
 
-  ASSERT_EQ(graph->nodes().size(),
-            11UL /*vars*/ + 6UL /*ops*/ + 2UL /*feed op + fetch op*/);
+  Visualize(graph.get());
+  ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
   Visualize(graph.get());
 }
diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
index e0816c1be56665602bff7dce685f21084d82a0a1..1aecfdaed02d6f82e3829d076126adfddf686763 100644
--- a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
@@ -1,11 +1,18 @@
 cc_library(fuse_fc SRCS fc_fuser.cc DEPS pattern_matcher_high_api)
-cc_library(conv_elementwise_add_relu
+cc_library(fuse_conv_elementwise_add_relu
         SRCS conv_elementwise_add_relu_fuser.cc
         DEPS pattern_matcher_high_api)
+cc_library(fuse_conv_bn
+        SRCS conv_bn_fuser.cc
+        DEPS pattern_matcher_high_api)
 
 set(mir_fusers
     fuse_fc
-    conv_elementwise_add_relu
+    fuse_conv_elementwise_add_relu
+    fuse_conv_bn
     CACHE INTERNAL "fusers")
+
+lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
+    DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7ce20c4d6e28d8368397510ea912ede647224226
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/program.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  auto* main_block = program_desc->MutableBlock(0);
+  auto* conv_op = main_block->AppendOp();
+  auto* bn_op = main_block->AppendOp();
+  main_block->Var("conv_i");
+  main_block->Var("conv_param");
+  main_block->Var("conv_out");
+
+  main_block->Var("bn_scale");
+  main_block->Var("bn_bias");
+  main_block->Var("bn_mean");
+  main_block->Var("bn_var");
+  main_block->Var("bn_out");
+  main_block->Var("bn_mean_out");
+  main_block->Var("bn_var_out");
+  main_block->Var("bn_saved_mean");
+  main_block->Var("bn_saved_var");
+
+  scope->Var("conv_i")->GetMutable<lite::Tensor>();
+  auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
+  conv_param_t->Resize(lite::DDim(conv_param_shape));
+  conv_param_t->mutable_data<float>();
+  scope->Var("conv_out")->GetMutable<lite::Tensor>();
+  auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_scale_shape = {3};
+  bn_scale_t->Resize(lite::DDim(bn_scale_shape));
+  bn_scale_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_bias_shape = {3};
+  bn_bias_t->Resize(lite::DDim(bn_bias_shape));
+  bn_bias_t->mutable_data<float>();
+
+  auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
+  bn_mean_t->Resize(lite::DDim(bn_bias_shape));
+  bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
+  bn_var_t->Resize(lite::DDim(bn_bias_shape));
+  bn_var_t->mutable_data<float>();
+
+  scope->Var("bn_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
+
+  conv_op->SetType("conv2d");
+  conv_op->SetInput("Input", {"conv_i"});
+  conv_op->SetInput("Filter", {"conv_param"});
+  conv_op->SetOutput("Output", {"conv_out"});
+  const std::vector<int> strides({1, 1});
+  const std::vector<int> paddings({1, 1});
+  const std::vector<int> dilations({1, 1});
+  const int groups = 1;
+  conv_op->SetAttr("strides", strides);
+  conv_op->SetAttr("paddings", paddings);
+  conv_op->SetAttr("dilations", dilations);
+  conv_op->SetAttr("groups", groups);
+
+  bn_op->SetType("batch_norm");
+  bn_op->SetInput("X", {"conv_out"});
+  bn_op->SetInput("Bias", {"bn_bias"});
+  bn_op->SetInput("Mean", {"bn_mean"});
+  bn_op->SetInput("Scale", {"bn_scale"});
+  bn_op->SetInput("Variance", {"bn_var"});
+
+  bn_op->SetOutput("Y", {"bn_out"});
+  bn_op->SetOutput("MeanOut", {"bn_mean_out"});
+  bn_op->SetOutput("VarianceOut", {"bn_var_out"});
+  bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
+  bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
+  float eps = 1e-5;
+  bn_op->SetAttr("epsilon", eps);
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(pattern_matcher2, test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  const int num_nodes = graph->nodes().size();
+  auto* fuser = new ConvBNFusePass;
+  fuser->Apply(graph);
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 8UL /*nodes removed*/ + 1UL /*eltwise_add node*/);
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(conv2d);
+USE_LITE_OP(batch_norm);
+USE_LITE_OP(elementwise_add);
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e753f8a858dbfe9cbe7a5f29e473524ac9196f70
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -0,0 +1,128 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ConvBNFuser::BuildPattern() {
+  auto* conv_input =
+      VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
+  auto* conv_weight = VarNode("conv_weight")
+                          ->assert_is_op_input(conv_type_, "Filter")
+                          ->AsInput();
+  auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
+  auto* conv_out = VarNode("conv_out")
+                       ->assert_is_op_output(conv_type_, "Output")
+                       ->assert_is_op_input("batch_norm", "X");
+
+  auto* bn_scale = VarNode("bn_scale")
+                       ->assert_is_op_input("batch_norm", "Scale")
+                       ->AsIntermediate();
+  auto* bn_bias =
+      VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
+  auto* bn_mean = VarNode("bn_mean")
+                      ->assert_is_op_input("batch_norm", "Mean")
+                      ->AsIntermediate();
+  auto* bn_var = VarNode("bn_variance")
+                     ->assert_is_op_input("batch_norm", "Variance")
+                     ->AsIntermediate();
+  auto* bn =
+      OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
+
+  auto* bn_out =
+      VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
+  auto* bn_mean_out = VarNode("bn_mean_out")
+                          ->assert_is_op_output("batch_norm", "MeanOut")
+                          ->AsIntermediate();
+  auto* bn_var_out = VarNode("bn_var_out")
+                         ->assert_is_op_output("batch_norm", "VarianceOut")
+                         ->AsIntermediate();
+  auto* bn_saved_mean = VarNode("bn_saved_mean")
+                            ->assert_is_op_output("batch_norm", "SavedMean")
+                            ->AsIntermediate();
+  auto* bn_saved_var = VarNode("bn_saved_var")
+                           ->assert_is_op_output("batch_norm", "SavedVariance")
+                           ->AsIntermediate();
+
+  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+
+  bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
+      .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
+}
+
+void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
+  auto conv = matched.at("conv2d")->stmt()->op;
+  auto* scope = conv->scope();
+  auto& valid_places = conv->valid_places();
+
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  auto conv_weight_d = conv_weight_t->mutable_data<float>();
+  auto conv_weight_dims = conv_weight_t->dims();
+  size_t weight_num = conv_weight_t->data_size();
+
+  auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
+                        ->GetMutable<lite::Tensor>();
+  size_t bias_size = bn_scale_t->data_size();
+  auto bn_scale_d = bn_scale_t->mutable_data<float>();
+  PADDLE_ENFORCE(bias_size == conv_weight_dims[0],
+                 "The BN bias's size should be equal to the size of the first "
+                 "dim size of the conv weights");
+
+  auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_mean_d = bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
+                      ->GetMutable<lite::Tensor>();
+  auto bn_var_d = bn_var_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_bias_d = bn_bias_t->mutable_data<float>();
+  auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
+
+  ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
+                     eps, bias_size, weight_num / bias_size);
+
+  eltwise_op->Attach(op_desc, scope);
+  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
+}
+
+cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  op_desc.SetType("elementwise_add");
+  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
+  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
+  op_desc.SetAttr("axis", 1);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
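
(Annotation, not part of the patch.) The arithmetic performed by `ComputeFusedWeight` can be sanity-checked in isolation. The standalone sketch below (all names, shapes, and values made up for illustration) folds a BN into a tiny two-channel filter and verifies that "conv then BN" and "scaled conv plus bias" agree:

```cpp
// Illustrative sanity check of the conv+BN folding arithmetic.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const int h = 2, w = 3;  // h: output channels, w: weights per channel
  float weight[h * w] = {1.f, 2.f, -1.f, 0.5f, -2.f, 3.f};
  float x[w] = {0.3f, -1.2f, 0.7f};  // one im2col column
  float scale[h] = {1.5f, 0.8f}, bias[h] = {0.1f, -0.2f};
  float mean[h] = {0.4f, -0.3f}, var[h] = {0.25f, 4.f};
  const float eps = 1e-5f;

  for (int i = 0; i < h; ++i) {
    // Reference: BN applied to the raw conv output.
    float conv_out = 0.f;
    for (int j = 0; j < w; ++j) conv_out += weight[i * w + j] * x[j];
    float ref =
        scale[i] * (conv_out - mean[i]) / std::sqrt(var[i] + eps) + bias[i];

    // Folded: scale the filter row, fold the rest into an additive bias.
    float alpha = scale[i] / std::sqrt(var[i] + eps);
    float fused_bias = bias[i] - mean[i] * alpha;
    float fused = fused_bias;
    for (int j = 0; j < w; ++j) fused += alpha * weight[i * w + j] * x[j];

    assert(std::fabs(ref - fused) < 1e-5f);
    std::printf("channel %d: ref=%f fused=%f\n", i, ref, fused);
  }
  return 0;
}
```
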
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..a591d20717e2b18771f27b709580d6a07d32bca2
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cmath>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvBNFuser : public FuseBase {
+ public:
+  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
+                          float* bias_d, float* conv_weight_d, float eps, int h,
+                          int w) {
+    // var_d is reused in place to hold alpha = scale / sqrt(var + eps).
+    for (int i = 0; i < h; i++) {
+      var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
+    }
+    // bias_d becomes the fused bias: bias - mean * alpha.
+    for (int i = 0; i < h; i++) {
+      bias_d[i] += (-mean_d[i]) * var_d[i];
+    }
+    // Scale each output-channel slice of the conv filter by alpha.
+    for (int i = 0; i < h; i++) {
+      for (int j = 0; j < w; j++) {
+        conv_weight_d[i * w + j] *= var_d[i];
+      }
+    }
+  }
+
+ private:
+  std::string conv_type_{"conv2d"};
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h
index 67ee47a9e12fde139a81e5b21759645a87e6b098..a5fd90dac482d434afb624216aad875e12350c36 100644
--- a/paddle/fluid/lite/core/mir/node.h
+++ b/paddle/fluid/lite/core/mir/node.h
@@ -71,12 +71,20 @@ class Node {
   struct Arg {
     std::string name;
+    int id{0};
     const Type* type{};
     // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
     bool is_weight{false};
   };
 
+  Arg& AsArg(const std::string& name, int id) {
+    auto& x = AsArg();
+    x.name = name;
+    x.id = id;
+    return x;
+  }
+
   Arg& AsArg(const std::string& name) {
     auto& x = AsArg();
     x.name = name;
diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h
index e0110a8e3b27fd11c03fa15732b634e3d7978682..b65e1d53d0796854091ca9bcea1da234d0dd9419 100644
--- a/paddle/fluid/lite/core/mir/passes.h
+++ b/paddle/fluid/lite/core/mir/passes.h
@@ -31,3 +31,4 @@ USE_MIR_PASS(generate_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
 USE_MIR_PASS(runtime_context_assign_pass);
+USE_MIR_PASS(lite_conv_bn_fuse_pass);
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc
index 8a83bd242bd0b9430cd668f3ff7efd559c1b1186..7524312db8b88055fe27344bb860906ee9c0d63f 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher.cc
+++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc
@@ -407,20 +407,22 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
   return this;
 }
 
-PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
-  assert_is_var();
-  asserts_.emplace_back([=](const Node *x) {
-    for (auto *op : x->outlinks) {
-      if (op && op->IsStmt()) {
-        auto *op_info = op->stmt()->op_info();
-        if (op_info->Type() == op_type) {
-          return true;
-        }
-      }
-    }
-    return false;
-  });
-  return this;
+bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
+                 size_t nth) {
+  PADDLE_ENFORCE(var->IsArg());
+  PADDLE_ENFORCE(op->IsStmt());
+  auto op_info = op->stmt()->op_info();
+  if (op_info->Output(argument).size() <= nth) return false;
+  return var->arg()->name == op_info->Output(argument)[nth];
+}
+
+bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
+                size_t nth) {
+  PADDLE_ENFORCE(var->IsArg());
+  PADDLE_ENFORCE(op->IsStmt());
+  auto op_info = op->stmt()->op_info();
+  if (op_info->Input(argument).size() <= nth) return false;
+  return var->arg()->name == op_info->Input(argument)[nth];
 }
 
 PMNode *PMNode::assert_is_op_input(const std::string &op_type,
@@ -436,8 +438,8 @@ PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
   assert_is_op_input(op_type);
   asserts_.emplace_back([=](const Node *x) {
     for (auto *op : x->outlinks) {
-      if (op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
-          IsNthInput(*x, *op, argument, nth))
+      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
+          IsNthInput(x, op, argument, nth))
         return true;
     }
     return false;
@@ -445,14 +447,41 @@ PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
   return this;
 }
 
-bool IsNthInput(const Node &var, const Node &op, const std::string &argument,
-                int nth) {
-  CHECK(var.IsArg());
-  CHECK(op.IsStmt());
-  if (!HasInput(op, argument) ||
-      static_cast<int>(op.stmt()->op_info()->Input(argument).size()) <= nth)
+PMNode *PMNode::assert_is_op_output(const std::string &op_type,
+                                    const std::string &argument) {
+  assert_is_var();
+  assert_is_op_nth_output(op_type, argument, 0);
+  return this;
+}
+
+PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
+                                        const std::string &argument, int nth) {
+  assert_is_var();
+  asserts_.emplace_back([=](const Node *x) {
+    for (auto *op : x->inlinks) {
+      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
+          IsNthOutput(x, op, argument, nth))
+        return true;
+    }
+    return false;
+  });
+  return this;
+}
+
+PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](const Node *x) {
+    for (auto *op : x->outlinks) {
+      if (op && op->IsStmt()) {
+        auto *op_info = op->stmt()->op_info();
+        if (op_info->Type() == op_type) {
+          return true;
+        }
+      }
+    }
     return false;
-  return var.arg()->name == op.stmt()->op_info()->Input(argument)[nth];
+  });
+  return this;
 }
 
 bool HasInput(const Node &op, const std::string &argument) {
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h
index f2862a229e3eea5c621227c77011f26f1dc1ed29..ff9fbce35ddf3f601a441bb6105dc658505cbe0e 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher.h
+++ b/paddle/fluid/lite/core/mir/pattern_matcher.h
@@ -129,8 +129,13 @@ struct PMNode {
   PMNode* assert_is_op_input(const std::string& op_type);
   PMNode* assert_is_op_input(const std::string& op_type,
                              const std::string& argument);
+  PMNode* assert_is_op_output(const std::string& op_type,
+                              const std::string& argument);
+
   PMNode* assert_is_op_nth_input(const std::string& op_type,
                                  const std::string& argument, int nth);
+  PMNode* assert_is_op_nth_output(const std::string& op_type,
+                                  const std::string& argument, int nth);
 
   template <typename T>
   PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
index beee4d32acb733007c088f3101fec02ccf94e8a4..7a46bb9a93d95b9379c961d8044fbdfcd04e7ab4 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
@@ -124,8 +124,7 @@ TEST(pattern_matcher_high_api, graph_test) {
   auto scope = std::make_shared<lite::Scope>();
   auto graph = BuildGraph(&program_desc, scope, places);
 
-  ASSERT_EQ(graph->nodes().size(),
-            7UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
+  ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/);
   Visualize(graph.get());
 }
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc
index 82507067c4726b271013cf4a69e95c5045b091a8..ba99a681f79db0406ce1ddd0bb53c0c4ad19a0bc 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph.cc
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <memory>
 #include <set>
+#include <unordered_map>
 #include <utility>
 
 namespace paddle {
@@ -25,6 +26,8 @@ namespace mir {
 bool SSAGraph::CheckBidirectionalConnection() {
   LOG(INFO) << "node count " << node_storage_.size();
   for (auto &node : node_storage_) {
+    if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type();
+    if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id;
     for (auto *in : node.inlinks) {
       CHECK(in->outlinks.end() !=
             std::find(in->outlinks.begin(), in->outlinks.end(), &node));
@@ -93,31 +96,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
   return res;
 }
 
-void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
-  for (const auto &name : program.tmp_vars()) {
-    CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
-                                   << name;
-    VLOG(5) << "create arg node " << name;
-    node_storage_.emplace_back();
-    auto &new_node = node_storage_.back();
-    new_node.AsArg(name);
-    arguments_[name] = &new_node;
-  }
-}
-
-void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
-  // create weight nodes.
-  for (const auto &name : program.weights()) {
-    CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
-                                   << name;
-    VLOG(5) << "create arg node " << name;
-    node_storage_.emplace_back();
-    auto &new_node = node_storage_.back();
-    new_node.AsArg(name);
-    arguments_[name] = &new_node;
-  }
-}
-
 Node *SSAGraph::GraphCreateInstructNode(
     const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
   node_storage_.emplace_back();
@@ -135,29 +113,50 @@ Node *SSAGraph::GraphCreateInstructNode(
 void SSAGraph::Build(const Program &program,
                      const std::vector<Place> &valid_places) {
   CHECK(node_storage_.empty());
-  GraphCreateTmpVarNodes(program);
-  GraphCreateWeightVarNodes(program);
-  CHECK(CheckNodesRoleSet());
+
+  auto weights_name = program.weights();
+  auto is_weights = [&](const std::string &name) -> bool {
+    auto it = std::find(weights_name.begin(), weights_name.end(), name);
+    if (it == weights_name.end()) return false;
+    return true;
+  };
+
+  std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
   for (auto &op : program.ops()) {
+    LOG(INFO) << op->op_info()->Type();
     auto *op_node = GraphCreateInstructNode(op, valid_places);
+    LOG(INFO) << "input:";
     for (const std::string &name : op->op_info()->input_names()) {
-      auto *arg = Argument(name);
-      CHECK(arg->IsRoleSet());
-      DirectedLink(arg, op_node);
+      LOG(INFO) << name;
+      mir::Node *arg_node = nullptr;
+      if (arg_update_node_map_.count(name)) {
+        arg_node = arg_update_node_map_.at(name);
+      } else {
+        node_storage_.emplace_back();
+        arg_node = &node_storage_.back();
+        arg_node->AsArg(name, node_storage_.size() - 1);
+        arg_update_node_map_[name] = arg_node;
+      }
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(arg_node, op_node);
     }
+    LOG(INFO) << "output:";
     for (const std::string &name : op->op_info()->output_names()) {
-      if (!arguments_.count(name)) {
-        NewArgumentNode(name);
-      }
-      auto *arg = arguments_.at(name);
-      CHECK(arg->IsRoleSet());
-      DirectedLink(op_node, arg);
+      LOG(INFO) << name;
+      node_storage_.emplace_back();
+      auto *arg_node = &node_storage_.back();
+      arg_node->AsArg(name, node_storage_.size() - 1);
+      arg_update_node_map_[name] = arg_node;
+
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(op_node, arg_node);
     }
     CHECK(CheckLinksRoleSet());
   }
-  MarkArgumentWeights(program);
+
+  CHECK(CheckNodesRoleSet());
   CheckValid();
 }
 
@@ -227,10 +226,9 @@ bool SSAGraph::CheckLinksRoleSet() {
 
 Node *SSAGraph::NewArgumentNode(const std::string &name) {
   node_storage_.emplace_back();
-  CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
-  arguments_[name] = &node_storage_.back();
-  node_storage_.back().AsArg(name);
-  return &node_storage_.back();
+  auto &arg_node = node_storage_.back();
+  arg_node.AsArg(name, node_storage_.size() - 1);
+  return &arg_node;
 }
 
 Node *SSAGraph::NewInstructNode() {
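
(Annotation, not part of the patch.) The rewritten `Build` drops the global name-keyed `arguments_` map in favor of `arg_update_node_map_`, which tracks the *most recent write* of each variable. Every op output gets a fresh argument node, so a variable written twice yields two distinct nodes, as SSA requires; readers link to whichever node last wrote the name. A standalone sketch of that bookkeeping, with plain strings in place of `mir::Node` (names hypothetical):

```cpp
// Per-write definition tracking, mirroring arg_update_node_map_ in
// SSAGraph::Build. Every output write creates a fresh numbered node.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  struct Op {
    std::string type;
    std::vector<std::string> ins, outs;
  };
  // "b" is written twice; SSA needs two distinct nodes for it.
  std::vector<Op> program = {{"f", {"a"}, {"b"}}, {"g", {"b"}, {"b"}}};

  int next_id = 0;
  std::unordered_map<std::string, int> last_def;
  for (const auto& op : program) {
    for (const auto& in : op.ins) {
      if (!last_def.count(in)) last_def[in] = next_id++;  // graph input
      std::cout << in << "#" << last_def[in] << " -> " << op.type << "\n";
    }
    for (const auto& out : op.outs) {
      last_def[out] = next_id++;  // a fresh node per write
      std::cout << op.type << " -> " << out << "#" << last_def[out] << "\n";
    }
  }
}
```
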
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h
index 5cad1478c225a6551fcd653ca4e79b58360e3724..7c0e6cef498c5c555c1cee6ab334e6be556a9897 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.h
+++ b/paddle/fluid/lite/core/mir/ssa_graph.h
@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
   void Build(const Program &program, const std::vector<Place> &valid_places);
   void RemoveNode(const mir::Node *node);
 
-  mir::Node *Argument(const std::string &name);
-
   std::vector<mir::Node *> StmtTopologicalOrder();
 
   // The inputs of the graph.
@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
                                  const std::vector<Place> &valid_places);
 
  private:
-  void GraphCreateTmpVarNodes(const Program &program);
-  void GraphCreateWeightVarNodes(const Program &program);
-
+  mir::Node *Argument(const std::string &name);
   // Check the bidirectional connection.
   bool CheckBidirectionalConnection();
   bool CheckNodesRoleSet();
diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
index 25789d34dca2fa90dbb8c7a415da651c44cc6d12..12dd2dcff0607bea46f41e7f5698ad2fb7e12404 100644
--- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
             << " for kernel " << inst.op->DebugString() << " "
             << *in->AsArg().type << " -> " << *decl_arg_type;
     // Add an IoCopy instruction to make the input compatible with other dist.
-    AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph,
-                  inst_node, valid_places_);
+    AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
+                  valid_places_);
   }
 }
 
 void TypeTargetTransformPass::AddIoCopyInst(
-    const Type& from, const Type& to, const std::string& var, SSAGraph* graph,
+    const Type& from, const Type& to, Node* in, SSAGraph* graph,
     Node* inst_node, const std::vector<Place>& valid_places) {
   CHECK(!valid_places.empty()) << "valid_place should be set";
   // var -> new_transform_op -> new_var -> inst
   // So there will be a new Argument node and a new IoCopy Statement Node.
+  CHECK(in->IsArg());
   auto node_id = [&] { return graph->nodes().size(); };
-  auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
+  auto io_copy_output_name =
+      in->AsArg().name + "/trans/" + std::to_string(node_id());
   auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
   auto* io_copy_inst = graph->NewInstructNode();
 
@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
   // Create IoCopy Instruction.
   cpp::OpDesc op_desc;
   op_desc.SetType("io_copy");
-  op_desc.SetInput("Input", {var});
+  op_desc.SetInput("Input", {in->AsArg().name});
   op_desc.SetOutput("Out", {io_copy_output_name});
 
   io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
   io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
 
   // Remove the old link
-  RemoveDirectedLink(graph->Argument(var), inst_node);
+  RemoveDirectedLink(in, inst_node);
 
   // Update the original instruction OpDesc.
   // Update its input to the io_copy_output_name
 
   // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
-  DirectedLink(graph->Argument(var), io_copy_inst);
+  DirectedLink(in, io_copy_inst);
   DirectedLink(io_copy_inst, io_copy_output_arg);
   DirectedLink(io_copy_output_arg, inst_node);
 
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var,
+  UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
                 io_copy_output_name);
 
   inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.h b/paddle/fluid/lite/core/mir/type_target_transform_pass.h
index 838c0bcdabc92717d4b62bda25b77df1bad6dc5d..052e3297abbe806c24f89eb7469cb1fe69246ff3 100644
--- a/paddle/fluid/lite/core/mir/type_target_transform_pass.h
+++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.h
@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
 
   void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
 
-  void AddIoCopyInst(const Type& from, const Type& to, const std::string& var,
+  void AddIoCopyInst(const Type& from, const Type& to, Node* in,
                      SSAGraph* graph, Node* inst_node,
                      const std::vector<Place>& valid_places);
 
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
index 4d555d638a91e17796a68ed3397c22d138084e5a..2128c6d2014bf8879743ebf7190b3a95a3bc4186 100644
--- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
@@ -13,7 +13,10 @@
 // limitations under the License.
 
 #pragma once
+#include <map>
 #include <memory>
+#include <string>
+#include <vector>
 #include "paddle/fluid/lite/core/mir/pass.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
 
@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
     //   LOG(INFO) << "- inferencing type " <<
     // deal with inputs
     VLOG(4) << "inferencing op " << inst.op_type;
-    for (auto& arg_name : inst.op_info()->input_argnames()) {
+    // TODO(zhaolong): Add check if the node's name in op's arguments.
+
+    auto get_argname = [&](
+        const std::string& node_name,
+        const std::map<std::string, std::vector<std::string>>& argname_map)
+        -> std::string {
+      for (auto& ele : argname_map) {
+        auto it =
+            std::find(ele.second.begin(), ele.second.end(), node_name);
+        if (it != ele.second.end()) return ele.first;
+      }
+      return "";
+    };
+
+    for (auto* x_in : x->inlinks) {
+      std::string node_name = x_in->AsArg().name;
+      std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
+      CHECK(arg_name.size() > 0) << "can not found op arguments for node "
+                                 << node_name;
       VLOG(3) << "-- input arg_name " << arg_name;
-      // check if inputs's place is set, if not set, update them with the
-      // kernel's declaration.
       auto type = inst.picked_kernel().GetInputDeclType(arg_name);
-      auto arg_names = inst.op_info()->inputs().at(arg_name);
-
-      for (auto& arg_name : arg_names) {
-        VLOG(3) << "--- var " << arg_name;
-        auto* node = graph->RetrieveArgument(arg_name);
-        CHECK(node) << "argument " << arg_name << " not exists in the graph";
-        auto& arg_node = node->AsArg();
-        if (!arg_node.type) {
-          VLOG(4) << "set type " << *type << " " << node;
-          arg_node.type = type;
-        }
+      if (!x_in->AsArg().type) {
+        VLOG(4) << "set type " << *type << " " << x_in;
+        x_in->AsArg().type = type;
       }
     }
 
-    for (auto& arg_name : inst.op_info()->output_argnames()) {
+    for (auto* x_out : x->outlinks) {
+      std::string node_name = x_out->AsArg().name;
+      std::string arg_name =
+          get_argname(node_name, inst.op_info()->outputs());
+      CHECK(arg_name.size() > 0) << "can not found op arguments for node "
+                                 << node_name;
       VLOG(3) << "-- output arg_name " << arg_name;
       auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
-      auto arg_names = inst.op_info()->outputs().at(arg_name);
-      // check if outputs's place is set, if not set, update them with the
-      // kernel's declaration.
-      for (auto& arg_name : arg_names) {
-        VLOG(3) << "--- var " << arg_name;
-        auto* node = graph->RetrieveArgument(arg_name);
-        CHECK(node) << "argument " << arg_name << " not exists in the graph";
-        auto& arg_node = node->AsArg();
-        if (!arg_node.type) {
-          node->AsArg().type = type;
-          VLOG(3) << "set type " << *type;
-        }
+      if (!x_out->AsArg().type) {
+        VLOG(4) << "set type " << *type << " " << x_out;
+        x_out->AsArg().type = type;
       }
     }
   }
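
(Annotation, not part of the patch.) Since argument nodes are now visited through `inlinks`/`outlinks` rather than looked up by name, the `get_argname` lambda inverts the op's argument map: given a variable name, it returns which slot ("Input", "Filter", ...) that variable occupies. A standalone sketch of the same reverse lookup, with hypothetical data:

```cpp
// Reverse lookup from variable name to op argument slot, mirroring the
// get_argname lambda above (standalone; data is made up for illustration).
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::string get_argname(
    const std::string& node_name,
    const std::map<std::string, std::vector<std::string>>& argname_map) {
  for (const auto& ele : argname_map) {
    if (std::find(ele.second.begin(), ele.second.end(), node_name) !=
        ele.second.end())
      return ele.first;  // the slot whose variable list contains node_name
  }
  return "";
}

int main() {
  std::map<std::string, std::vector<std::string>> inputs = {
      {"Input", {"conv_i"}}, {"Filter", {"conv_param"}}};
  std::cout << get_argname("conv_param", inputs) << "\n";  // prints "Filter"
}
```
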
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index b78408a67401453ec7c674e13443febac551d94c..651cd981c76d0d88f4c7294c0d19b1b0acbc76d4 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -49,17 +49,19 @@ class Optimizer {
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
-          "lite_fc_fuse_pass",              //
-          "static_kernel_pick_pass",        //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "type_target_transform_pass",     //
-          "argument_type_display_pass",     //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "io_copy_kernel_pick_pass",       //
-          "variable_place_inference_pass",  //
-          "runtime_context_assign_pass",    //
+          "lite_conv_bn_fuse_pass",                   //
+          "lite_conv_elementwise_add_act_fuse_pass",  //
+          "lite_fc_fuse_pass",                        //
+          "static_kernel_pick_pass",                  //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "type_target_transform_pass",               //
+          "argument_type_display_pass",               //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "io_copy_kernel_pick_pass",                 //
+          "variable_place_inference_pass",            //
+          "runtime_context_assign_pass",              //
       }});
     } else {
       RunPasses(passes);
diff --git a/paddle/fluid/lite/kernels/x86/conv_compute.cc b/paddle/fluid/lite/kernels/x86/conv_compute.cc
index 9d2de5be452c7e4f2f66086a62283ef802157af8..35d0de82de4c1b896f3f65726d74f50b17a9f227 100644
--- a/paddle/fluid/lite/kernels/x86/conv_compute.cc
+++ b/paddle/fluid/lite/kernels/x86/conv_compute.cc
@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     lite::Tensor col_matrix;
     if (is_expand) {
       col.Resize(col_shape);
+      col.mutable_data<T>();
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
     }
@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); @@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 691ff743b173d269f09eebfcc6db91b5b32423af..09c05ecb6f5ec5d4ef2523fcbf3994aeea7cbaea 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -19,6 +19,7 @@ cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) +cc_library(batch_norm_op_lite SRCS batch_norm.cc DEPS ${op_DEPS}) set(ops_lite fc_op_lite @@ -38,6 +39,7 @@ set(ops_lite concat_op_lite conv_op_lite pool_op_lite + batch_norm_op_lite CACHE INTERNAL "ops lite") lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc diff --git a/paddle/fluid/lite/operators/batch_norm.cc b/paddle/fluid/lite/operators/batch_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..80388e13050eaaaccf145ea3784c0e1e34886d81 --- /dev/null +++ b/paddle/fluid/lite/operators/batch_norm.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/batch_norm.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool BatchNormOpLite::CheckShape() const { return true; } + +bool BatchNormOpLite::InferShape() const { return true; } + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOpLite); diff --git a/paddle/fluid/lite/operators/batch_norm.h b/paddle/fluid/lite/operators/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..90815768e6bd60275b6096900e6e86be080a3a42 --- /dev/null +++ b/paddle/fluid/lite/operators/batch_norm.h @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class BatchNormOpLite : public OpLite {
+ public:
+  BatchNormOpLite() {}
+
+  explicit BatchNormOpLite(const std::string &type) : OpLite(type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  // TODO(Superjomn) replace framework::OpDesc with a lite one.
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto bias = op_desc.Input("Bias").front();
+    auto mean = op_desc.Input("Mean").front();
+    auto scale = op_desc.Input("Scale").front();
+    auto variance = op_desc.Input("Variance").front();
+
+    auto out = op_desc.Output("Y").front();
+    auto mean_out = op_desc.Output("MeanOut").front();
+    auto var_out = op_desc.Output("VarianceOut").front();
+    auto saved_mean = op_desc.Output("SavedMean").front();
+    auto saved_var = op_desc.Output("SavedVariance").front();
+
+    auto *var = scope->FindVar(x);
+    param_.x = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(bias);
+    param_.bias = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(mean);
+    param_.mean = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(scale);
+    param_.scale = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(variance);
+    param_.var = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(out);
+    param_.out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(mean_out);
+    param_.mean_out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(var_out);
+    param_.var_out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(saved_mean);
+    param_.saved_mean = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(saved_var);
+    param_.saved_var = var->GetMutable<lite::Tensor>();
+
+    param_.eps = op_desc.GetAttr<float>("epsilon");
+
+    return true;
+  }
+
+  std::string DebugString() const override { return "batch_norm"; }
+
+ private:
+  mutable BatchNormParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index 23b21cb276442d4e1da8b83557007a132c9de3fb..87b4c8dd5fbe63baf25a2a6df8e13b2c3db5fb53 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -237,6 +237,22 @@ struct SGDParam {
   lite::Tensor* ParamOut{};
 };
 
+// Batch norm parameters.
+struct BatchNormParam {
+  lite::Tensor* x{};
+  lite::Tensor* bias{};
+  lite::Tensor* mean{};
+  lite::Tensor* scale{};
+  lite::Tensor* var{};
+  lite::Tensor* out{};
+  lite::Tensor* mean_out{};
+  lite::Tensor* var_out{};
+  lite::Tensor* saved_mean{};
+  lite::Tensor* saved_var{};
+
+  float eps{1e-5};
+};
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
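
(Annotation, not part of the patch.) For orientation, `BatchNormParam` describes the usual inference-mode computation: per channel c, y = scale[c] * (x - mean[c]) / sqrt(var[c] + eps) + bias[c], applied over NCHW data. A minimal standalone sketch with plain arrays in place of `lite::Tensor` (illustrative only; no kernel in this patch implements it):

```cpp
// Inference-mode batch norm over NCHW data, one (alpha, beta) per channel:
//   y = scale[c] * (x - mean[c]) / sqrt(var[c] + eps) + bias[c]
#include <cmath>
#include <cstdio>

void batch_norm_nchw(const float* x, float* y, int n, int c, int hw,
                     const float* scale, const float* bias, const float* mean,
                     const float* var, float eps) {
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < c; ++j) {
      const float alpha = scale[j] / std::sqrt(var[j] + eps);
      const float beta = bias[j] - mean[j] * alpha;
      const float* xp = x + (i * c + j) * hw;
      float* yp = y + (i * c + j) * hw;
      for (int k = 0; k < hw; ++k) yp[k] = alpha * xp[k] + beta;
    }
}

int main() {
  const float x[4] = {1.f, 2.f, 3.f, 4.f};  // n=1, c=2, hw=2
  float y[4];
  const float scale[2] = {1.f, 0.5f}, bias[2] = {0.f, 1.f};
  const float mean[2] = {1.5f, 3.5f}, var[2] = {0.25f, 0.25f};
  batch_norm_nchw(x, y, 1, 2, 2, scale, bias, mean, var, 1e-5f);
  for (float v : y) std::printf("%f\n", v);
}
```
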