未验证 提交 d8ba9626 编写于 作者: Z Zhaolong Xing 提交者: GitHub

incubate/lite: batch norm pattern detect (#18021)

* bn pattern

* 1. add conv bn fusion pass(conv2d, depthwise_conv2d)
2. modify the graph to ssagraph
test=develop
上级 e71bf9f7
...@@ -3,10 +3,12 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node) ...@@ -3,10 +3,12 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion) add_subdirectory(fusion)
cc_library(mir_passes cc_library(mir_passes
SRCS fc_fuse_pass.cc SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_transform_pass.cc type_target_transform_pass.cc
...@@ -55,6 +57,7 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern ...@@ -55,6 +57,7 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
# TODO(wz) replace framework/proto to lite proto. # TODO(wz) replace framework/proto to lite proto.
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution. # it depends on the fluid/framework/proto, that is too heavy for mobile execution.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvBNFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvBNFuser fuser2("depthwise_conv2d");
fuser2(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvBNFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -40,8 +40,8 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -40,8 +40,8 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
auto* conv2d_1 = main_block->AppendOp(); auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp(); auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp(); auto* add_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp(); auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp(); auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1"); main_block->Var("input_1");
...@@ -123,8 +123,8 @@ TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { ...@@ -123,8 +123,8 @@ TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places); auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(), Visualize(graph.get());
11UL /*vars*/ + 6UL /*ops*/ + 2UL /*feed op + fetch op*/); ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get()); Visualize(graph.get());
} }
......
cc_library(fuse_fc cc_library(fuse_fc
SRCS fc_fuser.cc SRCS fc_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
cc_library(conv_elementwise_add_relu cc_library(fuse_conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers set(mir_fusers
fuse_fc fuse_fc
conv_elementwise_add_relu fuse_conv_elementwise_add_relu
fuse_conv_bn
CACHE INTERNAL "fusers") CACHE INTERNAL "fusers")
lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv_op = main_block->AppendOp();
auto* bn_op = main_block->AppendOp();
main_block->Var("conv_i");
main_block->Var("conv_param");
main_block->Var("conv_out");
main_block->Var("bn_scale");
main_block->Var("bn_bias");
main_block->Var("bn_mean");
main_block->Var("bn_var");
main_block->Var("bn_out");
main_block->Var("bn_mean_out");
main_block->Var("bn_var_out");
main_block->Var("bn_saved_mean");
main_block->Var("bn_saved_var");
scope->Var("conv_i")->GetMutable<lite::Tensor>();
auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
conv_param_t->Resize(lite::DDim(conv_param_shape));
conv_param_t->mutable_data<float>();
scope->Var("conv_out")->GetMutable<lite::Tensor>();
auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_scale_shape = {3};
bn_scale_t->Resize(lite::DDim(bn_scale_shape));
bn_scale_t->mutable_data<float>();
auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_bias_shape = {3};
bn_bias_t->Resize(lite::DDim(bn_bias_shape));
bn_bias_t->mutable_data<float>();
auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
bn_mean_t->Resize(lite::DDim(bn_bias_shape));
bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
bn_var_t->Resize(lite::DDim(bn_bias_shape));
bn_var_t->mutable_data<float>();
scope->Var("bn_out")->GetMutable<lite::Tensor>();
scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
conv_op->SetType("conv2d");
conv_op->SetInput("Input", {"conv_i"});
conv_op->SetInput("Filter", {"conv_param"});
conv_op->SetOutput("Output", {"conv_out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
conv_op->SetAttr("strides", strides);
conv_op->SetAttr("paddings", paddings);
conv_op->SetAttr("dilations", dilations);
conv_op->SetAttr("groups", groups);
bn_op->SetType("batch_norm");
bn_op->SetInput("X", {"conv_out"});
bn_op->SetInput("Bias", {"bn_bias"});
bn_op->SetInput("Mean", {"bn_mean"});
bn_op->SetInput("Scale", {"bn_scale"});
bn_op->SetInput("Variance", {"bn_var"});
bn_op->SetOutput("Y", {"bn_out"});
bn_op->SetOutput("MeanOut", {"bn_mean_out"});
bn_op->SetOutput("VarianceOut", {"bn_var_out"});
bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
float eps = 1e-5;
bn_op->SetAttr("epsilon", eps);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(pattern_matcher2, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvBNFusePass;
fuser->Apply(graph);
ASSERT_EQ(graph->nodes().size(),
num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d);
USE_LITE_OP(batch_norm);
USE_LITE_OP(elementwise_add);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvBNFuser::BuildPattern() {
auto* conv_input =
VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* conv_weight = VarNode("conv_weight")
->assert_is_op_input(conv_type_, "Filter")
->AsInput();
auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
auto* conv_out = VarNode("conv_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("batch_norm", "X");
auto* bn_scale = VarNode("bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* bn_bias =
VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
auto* bn_mean = VarNode("bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* bn_var = VarNode("bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* bn =
OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
auto* bn_out =
VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
auto* bn_mean_out = VarNode("bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* bn_var_out = VarNode("bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* bn_saved_mean = VarNode("bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* bn_saved_var = VarNode("bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
.LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
}
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op;
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
auto conv_weight_d = conv_weight_t->mutable_data<float>();
auto conv_weight_dims = conv_weight_t->dims();
size_t weight_num = conv_weight_t->data_size();
auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
->GetMutable<lite::Tensor>();
size_t bias_size = bn_scale_t->data_size();
auto bn_scale_d = bn_scale_t->mutable_data<float>();
PADDLE_ENFORCE(bias_size == conv_weight_dims[0],
"The BN bias's size should be equal to the size of the first "
"dim size of the conv weights");
auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_mean_d = bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_var_d = bn_var_t->mutable_data<float>();
auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_bias_d = bn_bias_t->mutable_data<float>();
auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
eps, bias_size, weight_num / bias_size);
eltwise_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("elementwise_add");
op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
op_desc.SetAttr("axis", 1);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvBNFuser : public FuseBase {
public:
explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
float* bias_d, float* conv_weight_d, float eps, int h,
int w) {
for (int i = 0; i < h; i++) {
var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
}
for (int i = 0; i < h; i++) {
bias_d[i] += (-mean_d[i]) * var_d[i];
}
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
conv_weight_d[i * w + j] *= var_d[i];
}
}
}
private:
std::string conv_type_{"conv2d"};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -71,12 +71,20 @@ class Node { ...@@ -71,12 +71,20 @@ class Node {
struct Arg { struct Arg {
std::string name; std::string name;
int id{0};
const Type* type{}; const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly // Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place. // so that some weight related optimization can take place.
bool is_weight{false}; bool is_weight{false};
}; };
Arg& AsArg(const std::string& name, int id) {
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name) { Arg& AsArg(const std::string& name) {
auto& x = AsArg(); auto& x = AsArg();
x.name = name; x.name = name;
......
...@@ -31,3 +31,4 @@ USE_MIR_PASS(generate_program_pass); ...@@ -31,3 +31,4 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
...@@ -407,37 +407,60 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) { ...@@ -407,37 +407,60 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
return this; return this;
} }
PMNode *PMNode::assert_is_op_input(const std::string &op_type) { bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
PADDLE_ENFORCE(var->IsArg());
PADDLE_ENFORCE(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Output(argument).size() <= nth) return false;
return var->arg()->name == op_info->Output(argument)[nth];
}
bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
PADDLE_ENFORCE(var->IsArg());
PADDLE_ENFORCE(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Input(argument).size() <= nth) return false;
return var->arg()->name == op_info->Input(argument)[nth];
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var(); assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](const Node *x) { asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) { for (auto *op : x->outlinks) {
if (op && op->IsStmt()) { if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
auto *op_info = op->stmt()->op_info(); IsNthInput(x, op, argument, nth))
if (op_info->Type() == op_type) {
return true; return true;
} }
}
}
return false; return false;
}); });
return this; return this;
} }
PMNode *PMNode::assert_is_op_input(const std::string &op_type, PMNode *PMNode::assert_is_op_output(const std::string &op_type,
const std::string &argument) { const std::string &argument) {
assert_is_var(); assert_is_var();
assert_is_op_nth_input(op_type, argument, 0); assert_is_op_nth_output(op_type, argument, 0);
return this; return this;
} }
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type, PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
const std::string &argument, int nth) { const std::string &argument, int nth) {
assert_is_var(); assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](const Node *x) { asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) { for (auto *op : x->inlinks) {
if (op->IsStmt() && op->stmt()->op_info()->Type() == op_type && if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthInput(*x, *op, argument, nth)) IsNthOutput(x, op, argument, nth))
return true; return true;
} }
return false; return false;
...@@ -445,14 +468,20 @@ PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type, ...@@ -445,14 +468,20 @@ PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
return this; return this;
} }
bool IsNthInput(const Node &var, const Node &op, const std::string &argument, PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
int nth) { assert_is_var();
CHECK(var.IsArg()); asserts_.emplace_back([=](const Node *x) {
CHECK(op.IsStmt()); for (auto *op : x->outlinks) {
if (!HasInput(op, argument) || if (op && op->IsStmt()) {
static_cast<int>(op.stmt()->op_info()->Input(argument).size()) <= nth) auto *op_info = op->stmt()->op_info();
if (op_info->Type() == op_type) {
return true;
}
}
}
return false; return false;
return var.arg()->name == op.stmt()->op_info()->Input(argument)[nth]; });
return this;
} }
bool HasInput(const Node &op, const std::string &argument) { bool HasInput(const Node &op, const std::string &argument) {
......
...@@ -129,8 +129,13 @@ struct PMNode { ...@@ -129,8 +129,13 @@ struct PMNode {
PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type, PMNode* assert_is_op_input(const std::string& op_type,
const std::string& argument); const std::string& argument);
PMNode* assert_is_op_output(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_nth_input(const std::string& op_type, PMNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth); const std::string& argument, int nth);
PMNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
template <typename T> template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
......
...@@ -124,8 +124,7 @@ TEST(pattern_matcher_high_api, graph_test) { ...@@ -124,8 +124,7 @@ TEST(pattern_matcher_high_api, graph_test) {
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places); auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(), ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/);
7UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
Visualize(graph.get()); Visualize(graph.get());
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <set> #include <set>
#include <unordered_map>
#include <utility> #include <utility>
namespace paddle { namespace paddle {
...@@ -25,6 +26,8 @@ namespace mir { ...@@ -25,6 +26,8 @@ namespace mir {
bool SSAGraph::CheckBidirectionalConnection() { bool SSAGraph::CheckBidirectionalConnection() {
LOG(INFO) << "node count " << node_storage_.size(); LOG(INFO) << "node count " << node_storage_.size();
for (auto &node : node_storage_) { for (auto &node : node_storage_) {
if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type();
if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id;
for (auto *in : node.inlinks) { for (auto *in : node.inlinks) {
CHECK(in->outlinks.end() != CHECK(in->outlinks.end() !=
std::find(in->outlinks.begin(), in->outlinks.end(), &node)); std::find(in->outlinks.begin(), in->outlinks.end(), &node));
...@@ -93,31 +96,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() { ...@@ -93,31 +96,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
return res; return res;
} }
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars()) {
CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights()) {
CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
Node *SSAGraph::GraphCreateInstructNode( Node *SSAGraph::GraphCreateInstructNode(
const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) { const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
node_storage_.emplace_back(); node_storage_.emplace_back();
...@@ -135,29 +113,50 @@ Node *SSAGraph::GraphCreateInstructNode( ...@@ -135,29 +113,50 @@ Node *SSAGraph::GraphCreateInstructNode(
void SSAGraph::Build(const Program &program, void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) { const std::vector<Place> &valid_places) {
CHECK(node_storage_.empty()); CHECK(node_storage_.empty());
GraphCreateTmpVarNodes(program);
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
auto weights_name = program.weights();
auto is_weights = [&](const std::string &name) -> bool {
auto it = std::find(weights_name.begin(), weights_name.end(), name);
if (it == weights_name.end()) return false;
return true;
};
std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
for (auto &op : program.ops()) { for (auto &op : program.ops()) {
LOG(INFO) << op->op_info()->Type();
auto *op_node = GraphCreateInstructNode(op, valid_places); auto *op_node = GraphCreateInstructNode(op, valid_places);
LOG(INFO) << "input:";
for (const std::string &name : op->op_info()->input_names()) { for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name); LOG(INFO) << name;
CHECK(arg->IsRoleSet()); mir::Node *arg_node = nullptr;
DirectedLink(arg, op_node); if (arg_update_node_map_.count(name)) {
arg_node = arg_update_node_map_.at(name);
} else {
node_storage_.emplace_back();
arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
} }
for (const std::string &name : op->op_info()->output_names()) { if (is_weights(name)) arg_node->AsArg().is_weight = true;
if (!arguments_.count(name)) { CHECK(arg_node->IsRoleSet());
NewArgumentNode(name); DirectedLink(arg_node, op_node);
} }
auto *arg = arguments_.at(name); LOG(INFO) << "output:";
CHECK(arg->IsRoleSet()); for (const std::string &name : op->op_info()->output_names()) {
DirectedLink(op_node, arg); LOG(INFO) << name;
node_storage_.emplace_back();
auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
DirectedLink(op_node, arg_node);
} }
CHECK(CheckLinksRoleSet()); CHECK(CheckLinksRoleSet());
} }
MarkArgumentWeights(program); CHECK(CheckNodesRoleSet());
CheckValid(); CheckValid();
} }
...@@ -227,10 +226,9 @@ bool SSAGraph::CheckLinksRoleSet() { ...@@ -227,10 +226,9 @@ bool SSAGraph::CheckLinksRoleSet() {
Node *SSAGraph::NewArgumentNode(const std::string &name) { Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back(); node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name; auto &arg_node = node_storage_.back();
arguments_[name] = &node_storage_.back(); arg_node.AsArg(name, node_storage_.size() - 1);
node_storage_.back().AsArg(name); return &arg_node;
return &node_storage_.back();
} }
Node *SSAGraph::NewInstructNode() { Node *SSAGraph::NewInstructNode() {
......
...@@ -40,8 +40,6 @@ class SSAGraph : GraphBase { ...@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
void Build(const Program &program, const std::vector<Place> &valid_places); void Build(const Program &program, const std::vector<Place> &valid_places);
void RemoveNode(const mir::Node *node); void RemoveNode(const mir::Node *node);
mir::Node *Argument(const std::string &name);
std::vector<mir::Node *> StmtTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
// The inputs of the graph. // The inputs of the graph.
...@@ -68,9 +66,7 @@ class SSAGraph : GraphBase { ...@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
const std::vector<Place> &valid_places); const std::vector<Place> &valid_places);
private: private:
void GraphCreateTmpVarNodes(const Program &program); mir::Node *Argument(const std::string &name);
void GraphCreateWeightVarNodes(const Program &program);
// Check the bidirectional connection. // Check the bidirectional connection.
bool CheckBidirectionalConnection(); bool CheckBidirectionalConnection();
bool CheckNodesRoleSet(); bool CheckNodesRoleSet();
......
...@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, ...@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
<< " for kernel " << inst.op->DebugString() << " " << " for kernel " << inst.op->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist. // Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph, AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
inst_node, valid_places_); valid_places_);
} }
} }
void TypeTargetTransformPass::AddIoCopyInst( void TypeTargetTransformPass::AddIoCopyInst(
const Type& from, const Type& to, const std::string& var, SSAGraph* graph, const Type& from, const Type& to, Node* in, SSAGraph* graph,
Node* inst_node, const std::vector<Place>& valid_places) { Node* inst_node, const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node. // So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); }; auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id()); auto io_copy_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode(); auto* io_copy_inst = graph->NewInstructNode();
...@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Create IoCopy Instruction. // Create IoCopy Instruction.
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
op_desc.SetType("io_copy"); op_desc.SetType("io_copy");
op_desc.SetInput("Input", {var}); op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name}); op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope()); io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
...@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
// Remove the old link // Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node); RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc. // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(graph->Argument(var), io_copy_inst); DirectedLink(in, io_copy_inst);
DirectedLink(io_copy_inst, io_copy_output_arg); DirectedLink(io_copy_inst, io_copy_output_arg);
DirectedLink(io_copy_output_arg, inst_node); DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information // reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var, UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
io_copy_output_name); io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
......
...@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass { ...@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddIoCopyInst(const Type& from, const Type& to, const std::string& var, void AddIoCopyInst(const Type& from, const Type& to, Node* in,
SSAGraph* graph, Node* inst_node, SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places); const std::vector<Place>& valid_places);
......
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <map>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
...@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
// LOG(INFO) << "- inferencing type " << // LOG(INFO) << "- inferencing type " <<
// deal with inputs // deal with inputs
VLOG(4) << "inferencing op " << inst.op_type; VLOG(4) << "inferencing op " << inst.op_type;
for (auto& arg_name : inst.op_info()->input_argnames()) { // TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
const std::string& node_name,
const std::map<std::string, std::vector<std::string>>& argname_map)
-> std::string {
for (auto& ele : argname_map) {
auto it =
std::find(ele.second.begin(), ele.second.end(), node_name);
if (it != ele.second.end()) return ele.first;
}
return "";
};
for (auto* x_in : x->inlinks) {
std::string node_name = x_in->AsArg().name;
std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- input arg_name " << arg_name; VLOG(3) << "-- input arg_name " << arg_name;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name); auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->inputs().at(arg_name); if (!x_in->AsArg().type) {
VLOG(4) << "set type " << *type << " " << x_in;
for (auto& arg_name : arg_names) { x_in->AsArg().type = type;
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type;
}
} }
} }
for (auto& arg_name : inst.op_info()->output_argnames()) { for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- output arg_name " << arg_name; VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
auto arg_names = inst.op_info()->outputs().at(arg_name); if (!x_out->AsArg().type) {
// check if outputs's place is set, if not set, update them with the VLOG(4) << "set type " << *type << " " << x_out;
// kernel's declaration. x_out->AsArg().type = type;
for (auto& arg_name : arg_names) {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
node->AsArg().type = type;
VLOG(3) << "set type " << *type;
}
} }
} }
} }
......
...@@ -49,6 +49,8 @@ class Optimizer { ...@@ -49,6 +49,8 @@ class Optimizer {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{{
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_fc_fuse_pass", // "lite_fc_fuse_pass", //
"static_kernel_pick_pass", // "static_kernel_pick_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
......
...@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
lite::Tensor col_matrix; lite::Tensor col_matrix;
if (is_expand) { if (is_expand) {
col.Resize(col_shape); col.Resize(col_shape);
col.mutable_data<T>();
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
...@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, ...@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, ...@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -19,6 +19,7 @@ cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) ...@@ -19,6 +19,7 @@ cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
cc_library(batch_norm_op_lite SRCS batch_norm.cc DEPS ${op_DEPS})
set(ops_lite set(ops_lite
fc_op_lite fc_op_lite
...@@ -38,6 +39,7 @@ set(ops_lite ...@@ -38,6 +39,7 @@ set(ops_lite
concat_op_lite concat_op_lite
conv_op_lite conv_op_lite
pool_op_lite pool_op_lite
batch_norm_op_lite
CACHE INTERNAL "ops lite") CACHE INTERNAL "ops lite")
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm.h"
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BatchNormOpLite::CheckShape() const { return true; }
bool BatchNormOpLite::InferShape() const { return true; }
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class BatchNormOpLite : public OpLite {
public:
BatchNormOpLite() {}
explicit BatchNormOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto bias = op_desc.Input("Bias").front();
auto mean = op_desc.Input("Mean").front();
auto scale = op_desc.Input("Scale").front();
auto variance = op_desc.Input("Variance").front();
auto out = op_desc.Output("Y").front();
auto mean_out = op_desc.Output("MeanOut").front();
auto var_out = op_desc.Output("VarianceOut").front();
auto saved_mean = op_desc.Output("SavedMean").front();
auto saved_var = op_desc.Output("SavedVariance").front();
auto *var = scope->FindVar(x);
param_.x = var->GetMutable<Tensor>();
var = scope->FindVar(bias);
param_.bias = var->GetMutable<Tensor>();
var = scope->FindVar(mean);
param_.mean = var->GetMutable<Tensor>();
var = scope->FindVar(scale);
param_.scale = var->GetMutable<Tensor>();
var = scope->FindVar(variance);
param_.var = var->GetMutable<Tensor>();
var = scope->FindVar(out);
param_.out = var->GetMutable<Tensor>();
var = scope->FindVar(mean_out);
param_.mean_out = var->GetMutable<Tensor>();
var = scope->FindVar(var_out);
param_.var_out = var->GetMutable<Tensor>();
var = scope->FindVar(saved_mean);
param_.saved_mean = var->GetMutable<Tensor>();
var = scope->FindVar(saved_var);
param_.saved_var = var->GetMutable<Tensor>();
param_.eps = op_desc.GetAttr<float>("epsilon");
return true;
}
std::string DebugString() const override { return "batch_norm"; }
private:
mutable BatchNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -237,6 +237,22 @@ struct SGDParam { ...@@ -237,6 +237,22 @@ struct SGDParam {
lite::Tensor* ParamOut{}; lite::Tensor* ParamOut{};
}; };
//
struct BatchNormParam {
lite::Tensor* x{};
lite::Tensor* bias{};
lite::Tensor* mean{};
lite::Tensor* scale{};
lite::Tensor* var{};
lite::Tensor* out{};
lite::Tensor* mean_out{};
lite::Tensor* var_out{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_var{};
float eps{1e-5};
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册