Commit 8864f459 authored by nhzlx

1. Delete unused files.

2. Refine after merge from GitHub.
Parent 0c3191f4
@@ -43,8 +43,6 @@ endif()
 # lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
 # endif()
-message(STATUS "!!!" ${ops_lite})
-message(STATUS "!!!" ${arm_kernels})
 lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
     DEPS
     cxx_api_lite
...
@@ -14,5 +14,9 @@ set(mir_fusers
     fuse_conv_bn
     CACHE INTERNAL "fusers")
+
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  return()
+endif()
 
 lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
     DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
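Note on the new guard: in CMake, `return()` ends processing of the current CMakeLists.txt, so LITE_WITH_LIGHT_WEIGHT_FRAMEWORK (mobile) builds still export the `mir_fusers` cache variable above but never reach the `lite_cc_test` call that follows the guard; the host-side fuse-pass test is simply not registered there.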
@@ -84,9 +84,9 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
                          ->GetMutable<lite::Tensor>();
   size_t bias_size = bn_scale_t->data_size();
   auto bn_scale_d = bn_scale_t->mutable_data<float>();
-  PADDLE_ENFORCE(bias_size == conv_weight_dims[0],
-                 "The BN bias's size should be equal to the size of the first "
-                 "dim size of the conv weights");
+  CHECK(bias_size == conv_weight_dims[0])
+      << "The BN bias's size should be equal to the size of the first "
+      << "dim size of the conv weights";
   auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
                        ->GetMutable<lite::Tensor>();
...
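An aside on why this check exists (standard conv+BN folding, not quoted from the patch): the fuser rewrites the conv weights and bias so the batch-norm layer can be dropped, and BN carries one (scale, bias, mean, variance) entry per conv output channel, hence `bias_size` must equal `conv_weight_dims[0]`. Per output channel o, with γ = Scale, β = Bias, μ = Mean, σ² = Variance:

```latex
W'_o = \frac{\gamma_o}{\sqrt{\sigma_o^2 + \varepsilon}} \, W_o,
\qquad
b'_o = \frac{\gamma_o}{\sqrt{\sigma_o^2 + \varepsilon}} \,(b_o - \mu_o) + \beta_o
```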
@@ -409,8 +409,8 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
 bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
                  size_t nth) {
-  PADDLE_ENFORCE(var->IsArg());
-  PADDLE_ENFORCE(op->IsStmt());
+  CHECK(var->IsArg());
+  CHECK(op->IsStmt());
   auto op_info = op->stmt()->op_info();
   if (op_info->Output(argument).size() <= nth) return false;
   return var->arg()->name == op_info->Output(argument)[nth];
@@ -418,8 +418,8 @@ bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
 bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
                 size_t nth) {
-  PADDLE_ENFORCE(var->IsArg());
-  PADDLE_ENFORCE(op->IsStmt());
+  CHECK(var->IsArg());
+  CHECK(op->IsStmt());
   auto op_info = op->stmt()->op_info();
   if (op_info->Input(argument).size() <= nth) return false;
   return var->arg()->name == op_info->Input(argument)[nth];
...
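On the PADDLE_ENFORCE → CHECK swaps throughout this commit (my reading; the patch itself doesn't say): PADDLE_ENFORCE is fluid's throwing assertion from platform/enforce.h, while CHECK is the glog-style macro lite already uses elsewhere, which aborts on failure and accepts a streamed message, so one more fluid dependency is avoided in the light-weight build. A minimal usage sketch, assuming glog (or a compatible lite logging shim) provides the macros:

```cpp
#include <glog/logging.h>  // CHECK/CHECK_EQ; lite builds may route these to their own logging

bool Positive(int x) { return x > 0; }

int main() {
  int n = 3;
  // Aborts with the streamed message if the condition is false.
  CHECK(Positive(n)) << "n must be positive, got " << n;
  // Comparison variants print both operands on failure.
  CHECK_EQ(n, 3) << "unexpected n";
  return 0;
}
```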
@@ -26,8 +26,6 @@ namespace mir {
 bool SSAGraph::CheckBidirectionalConnection() {
   LOG(INFO) << "node count " << node_storage_.size();
   for (auto &node : node_storage_) {
-    if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type();
-    if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id;
     for (auto *in : node.inlinks) {
       CHECK(in->outlinks.end() !=
             std::find(in->outlinks.begin(), in->outlinks.end(), &node));
@@ -123,11 +121,8 @@ void SSAGraph::Build(const Program &program,
   std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
   for (auto &op : program.ops()) {
-    LOG(INFO) << op->op_info()->Type();
     auto *op_node = GraphCreateInstructNode(op, valid_places);
-    LOG(INFO) << "input:";
     for (const std::string &name : op->op_info()->input_names()) {
-      LOG(INFO) << name;
       mir::Node *arg_node = nullptr;
       if (arg_update_node_map_.count(name)) {
         arg_node = arg_update_node_map_.at(name);
@@ -141,9 +136,7 @@ void SSAGraph::Build(const Program &program,
       CHECK(arg_node->IsRoleSet());
       DirectedLink(arg_node, op_node);
     }
-    LOG(INFO) << "output:";
     for (const std::string &name : op->op_info()->output_names()) {
-      LOG(INFO) << name;
       node_storage_.emplace_back();
       auto *arg_node = &node_storage_.back();
       arg_node->AsArg(name, node_storage_.size() - 1);
...
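The deletions above strip unconditional LOG(INFO) tracing that fired on every graph build. If the trace is ever worth keeping, a hedged alternative (assuming glog semantics for lite's logging macros) is verbosity-gated VLOG, which stays silent unless explicitly enabled:

```cpp
#include <glog/logging.h>

int main() {
  // Compiled in, but skipped at runtime unless the verbosity level is >= 4.
  VLOG(4) << "building SSA graph node";
  // Enable with the --v flag or the GLOG_v environment variable, e.g.:
  //   GLOG_v=4 ./cxx_api_lite_bin
  return 0;
}
```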
@@ -46,27 +46,33 @@ class Optimizer {
     SpecifyKernelPickTactic(kernel_pick_factor);
     InitTargetTypeTransformPass();
-#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+    // #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
-          "graph_visualze",                           //
           "lite_conv_bn_fuse_pass",                   //
-          "graph_visualze",                           //
           "lite_conv_elementwise_add_act_fuse_pass",  //
-          "graph_visualze",                           //
           "lite_fc_fuse_pass",                        //
           "static_kernel_pick_pass",                  //
+          /*
           "variable_place_inference_pass",            //
           "argument_type_display_pass",               //
           "type_target_transform_pass",               //
-          "argument_type_display_pass",               //
+          // "argument_type_display_pass",            //
           "variable_place_inference_pass",            //
           "argument_type_display_pass",               //
           "io_copy_kernel_pick_pass",                 //
           "variable_place_inference_pass",            //
+          "graph_visualze",                           //
+          */
           "runtime_context_assign_pass",              //
       }});
     } else {
       RunPasses(passes);
     }
-#endif
+    // #endif
     exec_scope_ = program.exec_scope();
   }
...
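Two things worth noting in this hunk: commenting out the #ifndef/#endif pair makes the default pass pipeline run even in LITE_WITH_LIGHT_WEIGHT_FRAMEWORK builds, and the place/type-inference passes are parked inside a /* */ block rather than deleted, which reads as work-in-progress from the "refine after merge" cleanup (callers can still supply an explicit list via `passes`). Also, the pass name "graph_visualze" keeps its misspelling; presumably it matches the name the pass is registered under, so correcting it here alone would break the lookup.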
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm.h"
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BatchNormOpLite::CheckShape() const { return true; }
bool BatchNormOpLite::InferShape() const { return true; }
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOpLite);
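CheckShape and InferShape are stubs for now, which is enough for the conv+bn fuser: it only needs the op registered so batch_norm nodes can be matched and folded. For illustration only (not part of the commit), a non-stub InferShape might look like the sketch below, assuming lite::Tensor exposes the dims()/Resize() interface the other lite operators use:

```cpp
bool BatchNormOpLite::InferShape() const {
  // Inference-mode batch_norm is elementwise over X, so Y keeps X's shape.
  param_.out->Resize(param_.x->dims());
  return true;
}
```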
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {
namespace operators {

class BatchNormOpLite : public OpLite {
 public:
  BatchNormOpLite() {}

  explicit BatchNormOpLite(const std::string &type) : OpLite(type) {}

  bool CheckShape() const override;

  bool InferShape() const override;

  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  // TODO(Superjomn) replace framework::OpDesc with a lite one.
  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
    auto x = op_desc.Input("X").front();
    auto bias = op_desc.Input("Bias").front();
    auto mean = op_desc.Input("Mean").front();
    auto scale = op_desc.Input("Scale").front();
    auto variance = op_desc.Input("Variance").front();

    auto out = op_desc.Output("Y").front();
    auto mean_out = op_desc.Output("MeanOut").front();
    auto var_out = op_desc.Output("VarianceOut").front();
    auto saved_mean = op_desc.Output("SavedMean").front();
    auto saved_var = op_desc.Output("SavedVariance").front();

    auto *var = scope->FindVar(x);
    param_.x = var->GetMutable<Tensor>();
    var = scope->FindVar(bias);
    param_.bias = var->GetMutable<Tensor>();
    var = scope->FindVar(mean);
    param_.mean = var->GetMutable<Tensor>();
    var = scope->FindVar(scale);
    param_.scale = var->GetMutable<Tensor>();
    var = scope->FindVar(variance);
    param_.var = var->GetMutable<Tensor>();
    var = scope->FindVar(out);
    param_.out = var->GetMutable<Tensor>();
    var = scope->FindVar(mean_out);
    param_.mean_out = var->GetMutable<Tensor>();
    var = scope->FindVar(var_out);
    param_.var_out = var->GetMutable<Tensor>();
    var = scope->FindVar(saved_mean);
    param_.saved_mean = var->GetMutable<Tensor>();
    var = scope->FindVar(saved_var);
    param_.saved_var = var->GetMutable<Tensor>();

    param_.eps = op_desc.GetAttr<float>("epsilon");

    return true;
  }

  std::string DebugString() const override { return "batch_norm"; }

 private:
  mutable BatchNormParam param_;
};

}  // namespace operators
}  // namespace lite
}  // namespace paddle
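For reference (standard batch-norm inference semantics, not quoted from this file): with the `epsilon` attribute read above, the kernel this op feeds is expected to compute

```latex
y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
```

where μ/σ² are the Mean/Variance inputs, γ is Scale, and β is Bias; MeanOut/VarianceOut and the Saved* outputs only matter in training mode.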
@@ -82,7 +82,7 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   param_.variance =
       scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
   param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
-  param_.is_test = op_desc.GetAttr<bool>("is_test");
+  param_.is_test = op_desc.GetAttr<int>("is_test");
   param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
   if (!param_.is_test) {
     param_.mean_out =
...
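The GetAttr&lt;bool&gt; → GetAttr&lt;int&gt; change for is_test presumably tracks how the converted cpp::OpDesc now stores that attribute (as an int rather than a bool); the `if (!param_.is_test)` branch below behaves identically either way, since any nonzero int converts to true.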