提交 ec38041b 编写于 作者: C Chunwei

Merge branch 'chunwei/clear-scale' into 'incubate/lite'

add scale eliminate pass

See merge request inference/paddlelite!18
......@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out));
std::vector<std::string> tensors_with_order({
"a", "fc_0.w_0", "fc_0.tmp_0", "scale_0.tmp_0",
"a", "fc_0.w_0", "scale_0.tmp_0",
});
for (const auto& tensor_name : tensors_with_order) {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include <chrono> // NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
......
......@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
......@@ -5,21 +5,25 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
add_subdirectory(elimination)
cc_library(mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers})
SRCS
fusion/fc_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers})
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite
......@@ -73,7 +77,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
......@@ -83,11 +87,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
lite_cc_test(test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
lite_cc_test(test_lite_elementwise_add_activation_fuse
SRCS fusion/elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_cc_test(test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${ops_lite}
)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace {
class Eliminator : public FuseBase {
public:
void BuildPattern() override {
auto* pre_op = OpNode("preop"); // the previous op's output need update
// TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("scale", "X");
auto* scale_op = OpNode("scale", "scale")
->assert_op_attr<float>("scale", 1.)
->assert_op_attr<float>("bias", 0.);
auto* out = VarNode("out")->assert_is_op_output("scale", "Out");
*pre_op >> *x >> *scale_op >> *out;
// The pre_op will be eliminated, and a new output-updated op will insert.
x->AsIntermediate(); // x is pre_op's output, need to update
}
private:
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto& pre_op = matched.at("preop")->AsStmt();
auto op_info = *pre_op.op_info();
op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
matched.at("out")->AsArg().name);
pre_op.ResetOp(op_info, graph->valid_places());
GraphSafeRemoveNodes(graph, {matched.at("scale")});
IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
}
};
} // namespace
class IdentityScaleEliminatePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
Eliminator eliminator;
eliminator(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
paddle::lite::mir::IdentityScaleEliminatePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto* main_block = program_desc->MutableBlock(0);
auto* feed_op = main_block->AppendOp();
auto* scale_op = main_block->AppendOp();
auto* fetch_op = main_block->AppendOp();
main_block->Var("x");
main_block->Var("feed");
main_block->Var("scale_out");
main_block->Var("fetch_out");
scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("feed")->GetMutable<lite::Tensor>();
scope->Var("scale_out")->GetMutable<lite::Tensor>();
scope->Var("fetch_out")->GetMutable<lite::Tensor>();
feed_op->SetType("feed");
feed_op->SetInput("X", {"x"});
feed_op->SetAttr("col", 1);
feed_op->SetOutput("Out", {"feed"});
scale_op->SetType("scale");
scale_op->SetInput("X", {"feed"});
scale_op->SetOutput("Out", {"scale_out"});
scale_op->SetAttr("scale", 1.f);
scale_op->SetAttr("bias", 0.f);
scale_op->SetAttr("bias_after_scale", true);
fetch_op->SetType("fetch");
fetch_op->SetInput("X", {"scale_out"});
fetch_op->SetOutput("Out", {"fetch"});
fetch_op->SetAttr("col", 1);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
LOG(INFO) << Visualize(graph.get());
return graph;
}
TEST(identity_test, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass");
ASSERT_TRUE(pass);
pass->Apply(graph);
ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(feed)
USE_LITE_OP(fetch)
USE_LITE_OP(scale)
USE_MIR_PASS(identity_scale_eliminate_pass)
......@@ -11,7 +11,7 @@ cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
set(mir_fusers
fuse_fc
fuse_conv_elementwise_add_activation
fuse_conv_bn
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op;
auto conv = matched.at("conv2d")->stmt()->op();
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
auto conv_old = matched.at("conv2d")->stmt()->op();
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto op_desc = GenOpDesc(matched);
auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op;
auto old_op = matched.at("add")->stmt()->op();
auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope);
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
......
......@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
VLOG(4) << stmt;
insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front()));
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
}
}
......
......@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if (node.IsArg()) {
key = node.AsArg().name;
} else {
key = node.AsStmt().op_type + std::to_string(id++);
key = node.AsStmt().op_type() + std::to_string(id++);
}
if (node.IsStmt()) {
......
......@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
auto& inst = node.AsStmt();
if (inst.op_type != "io_copy") continue;
if (inst.op_type() != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsStmt().valid_kernels;
auto& kernels = node.AsStmt().kernels();
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArg().type;
const auto* outy = node.outlinks.front()->AsArg().type;
......
......@@ -13,3 +13,62 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
const OpInfo *mir::Node::Stmt::op_info() const {
CHECK(op_);
return op_->op_info();
}
Place mir::Node::Stmt::place() const {
CHECK(!valid_kernels_.empty());
return valid_kernels_.front()->place();
}
KernelBase &mir::Node::Stmt::picked_kernel() {
CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type();
return *valid_kernels_.front();
}
OpInfo *mir::Node::Stmt::mutable_op_info() {
CHECK(op_);
return op_->mutable_op_info();
}
void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
const std::vector<Place> &valid_places,
lite::Scope *scope) {
CHECK((op_ && op_->scope()) || scope) << "Either scope should be set";
lite::Scope *the_scope = scope ? scope : op_->scope();
op_->Attach(op_desc, the_scope);
// Recreate the kernels with the latest OpInfo.
valid_kernels_.clear();
if (!op_ || op_->op_info()->Type() != op_desc.Type()) {
op_ = LiteOpRegistry::Global().Create(op_desc.Type());
CHECK(op_) << "No op found for " << op_desc.Type();
}
valid_kernels_ = op_->CreateKernels(valid_places);
}
std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
os << "Statement " << other.op_type() << " " << other.place();
return os;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
auto &x = AsArg();
x.name = name;
x.id = id;
return x;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name) {
auto &x = AsArg();
x.name = name;
return x;
}
} // namespace lite
} // namespace paddle
......@@ -41,32 +41,40 @@ class Node {
kUnk,
};
struct Stmt {
std::string op_type;
class Stmt {
// The kernel instances this Statement contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
std::vector<std::unique_ptr<KernelBase>> valid_kernels_;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape
std::shared_ptr<OpLite> op_; // we hold op to run InferShape
const OpInfo* op_info() {
CHECK(op);
return op->op_info();
}
public:
// Refresh the operator and kernels with the latest OpInfo.
void ResetOp(const cpp::OpDesc& op_desc,
const std::vector<Place>& valid_places,
lite::Scope* scope = nullptr);
Place place() const {
CHECK(!valid_kernels.empty());
return valid_kernels.front()->place();
}
std::string op_type() const { return op_info()->Type(); }
const OpInfo* op_info() const;
OpInfo* mutable_op_info();
KernelBase& picked_kernel() {
CHECK(!valid_kernels.empty()) << "no kernel for " << op_type;
return *valid_kernels.front();
void SetKernels(std::vector<std::unique_ptr<KernelBase>>&& kernels) {
valid_kernels_ = std::move(kernels);
}
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Statement " << other.op_type << " " << other.place();
return os;
std::vector<std::unique_ptr<KernelBase>>& kernels() {
return valid_kernels_;
}
void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
const std::shared_ptr<OpLite> op() const { return op_; }
Place place() const;
KernelBase& picked_kernel();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
// Description.
std::string desc;
};
struct Arg {
......@@ -78,26 +86,16 @@ class Node {
bool is_weight{false};
};
Arg& AsArg(const std::string& name, int id) {
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name, int id);
Arg& AsArg(const std::string& name) {
auto& x = AsArg();
x.name = name;
return x;
}
Arg& AsArg(const std::string& name);
Stmt& AsStmt(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) {
auto& x = AsStmt();
x.op_type = op_type;
x.op = op;
x.valid_kernels = std::move(kernels);
x.SetOp(op);
x.SetKernels(std::move(kernels));
return x;
}
......@@ -142,7 +140,7 @@ class Node {
}
if (other.IsStmt()) {
auto& arg = other.AsStmt();
os << "Statement " << arg.op_type;
os << "Statement " << arg.op_type();
}
return os;
}
......
......@@ -139,14 +139,13 @@ struct PMNode {
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
asserts_.push_back([=](const Node* x) {
if (x && x->IsStmt()) {
auto* op_info = x->stmt()->op_info();
return op_info->HasAttr(attr_name) &&
op_info->GetAttr<T>(attr_name) == attr;
} else {
return false;
}
return false;
});
return this;
}
......
......@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
LOG(INFO) << "keys: " << key2nodes_.size();
std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) {
for (const auto &key : keys) {
......
......@@ -49,7 +49,13 @@ class FuseBase {
virtual void BuildPattern() = 0;
// Generate an operator desc with a matched subgraph.
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0;
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
return cpp::OpDesc();
}
PMNode* OpNode(const std::string& key) {
return GetOrCreateNode(key)->assert_is_op();
}
PMNode* OpNode(const std::string& key, const std::string& op_type);
......
......@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
......@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("w");
main_block->Var("out");
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>();
......
......@@ -23,19 +23,19 @@ namespace mir {
void BuildGraph(SSAGraph* g) {
g->mutable_nodes().emplace_back();
Node& o1 = g->mutable_nodes().back();
o1.AsStmt().op_type = "op1";
o1.AsStmt().desc = "op1";
g->mutable_nodes().emplace_back();
Node& o2 = g->mutable_nodes().back();
o2.AsStmt().op_type = "op2";
o2.AsStmt().desc = "op2";
g->mutable_nodes().emplace_back();
Node& o3 = g->mutable_nodes().back();
o3.AsStmt().op_type = "op3";
o3.AsStmt().desc = "op3";
g->mutable_nodes().emplace_back();
Node& o4 = g->mutable_nodes().back();
o4.AsStmt().op_type = "op4";
o4.AsStmt().desc = "op4";
g->mutable_nodes().emplace_back();
Node& o5 = g->mutable_nodes().back();
o5.AsStmt().op_type = "op5";
o5.AsStmt().desc = "op5";
g->mutable_nodes().emplace_back();
Node& v1 = g->mutable_nodes().back();
v1.AsArg("var1");
......@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3)
auto* o2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op2";
return node && node->IsStmt() && node->stmt()->desc == "op2";
});
auto* o3 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op3";
return node && node->IsStmt() && node->stmt()->desc == "op3";
});
auto* v2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
......@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var
auto* any_op = x.mutable_pattern()->NewNode(
[](const Node* node) {
return node->IsStmt() && (node->stmt()->op_type == "op2" ||
node->stmt()->op_type == "op3");
return node->IsStmt() &&
(node->stmt()->desc == "op2" || node->stmt()->desc == "op3");
},
"OP0");
auto* any_var =
......@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int count = 0;
PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
SSAGraph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> "
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> "
<< s.at(any_var)->arg()->name << " -> "
<< s.at(any_op1)->stmt()->op_type;
<< s.at(any_op1)->stmt()->desc;
count++;
};
......@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher matcher;
auto* op2 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op2";
return x && x->IsStmt() && x->stmt()->desc == "op2";
},
"op2");
auto* op3 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op3";
return x && x->IsStmt() && x->stmt()->desc == "op3";
},
"op3");
auto* v2 = matcher.mutable_pattern()
......
......@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Device related attributes
const std::vector<Place> &valid_places() const { return valid_places_; }
void SetValidPlaces(const std::vector<Place> &x) { valid_places_ = x; }
private:
mir::Node *Argument(const std::string &name);
// Check the bidirectional connection.
......@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
std::vector<Place> valid_places_;
};
// Remove the link between a -> b.
......
......@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.valid_kernels.empty()) << "No kernels found for "
<< instruct.op_type;
for (auto&& kernel : instruct.valid_kernels) {
CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type();
for (auto&& kernel : instruct.kernels()) {
size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel));
}
......@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.valid_kernels.front()->name();
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
}
}
......
......@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op->DebugString() << " "
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
......@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsStmt().op->scope()->Var(io_copy_output_name);
inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
cpp::OpDesc op_desc;
......@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
......@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name,
io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
inst_node->AsStmt().op->scope());
inst_node->AsStmt().ResetOp(*inst_node->AsStmt().op_info(),
graph->valid_places());
std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp;
}
for (auto& kernel : inst_node->AsStmt().valid_kernels) {
inst_node->AsStmt().op->AttachKernel(kernel.get());
for (auto& kernel : inst_node->AsStmt().kernels()) {
inst_node->AsStmt().op()->AttachKernel(kernel.get());
}
graph->CheckValid();
......
......@@ -24,9 +24,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
#endif
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
......@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for (const auto& v : graph->inputs()) {
// the feed op might in the inputs
if (v->IsStmt()) {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type();
continue;
}
}
......@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for (auto& x : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue;
if (inst.op_type() == "io_copy") continue;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG(4) << "inferencing op " << inst.op_type;
VLOG(4) << "Infering op " << inst.op_info()->Repr();
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
......@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
}
}
VLOG(3) << "inst " << inst.op_info()->Repr();
for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
<< node_name << " in Inst "
<< inst.op_type();
VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
if (!x_out->AsArg().type) {
......
......@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets.insert(place.target);
}
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels;
}
......@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_ = scope;
op_info_.reset(
new OpInfo(opdesc)); // Force clean the out-of-date infomation.
return AttachImpl(opdesc, scope);
return AttachImpl(*op_info(), scope);
}
const Tensor *OpLite::GetTensor(lite::Scope *scope,
......
......@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
}
return false;
}
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
void UpdateAllOutputs(const std::string &from, const std::string &to) {
for (auto &item : outputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
};
} // namespace lite
......
......@@ -43,6 +43,8 @@ class Optimizer {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
graph_->SetValidPlaces(valid_places);
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
......@@ -51,6 +53,7 @@ class Optimizer {
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
#endif
......
......@@ -140,7 +140,7 @@ class RuntimeProgram {
void Run() {
for (auto& inst : instructions_) {
VLOG(4) << ">> Running kernel: " << inst;
VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr();
inst.Run();
}
}
......
......@@ -191,7 +191,6 @@ class TensorBase {
template <typename TensorT>
bool TensorCompareWith(const TensorT &a, const TensorT &b) {
if (a.dims() != b.dims()) return false;
LOG(INFO) << "data_size: " << a.data_size();
if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false;
return true;
}
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(
*x, param.x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(
*y, param.y_num_col_dims)
: *y;
Tensor x_matrix, y_matrix;
if (x->dims().size() > 2) {
x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
} else {
x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* z = &param.output->raw_tensor();
auto z_dim = z->dims();
......@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, param.y_num_col_dims)
: static_cast<const Tensor&>(*y);
Tensor x_matrix, y_matrix;
if (x->dims().size() > 2) {
x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
} else {
x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat;
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite)
cc_library(compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite)
lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite ${tensor_lite} scope_lite
......
cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite)
......@@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <vector>
......@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute.
template <typename T>
T GetAttr(const std::string& name) const;
std::string Repr() const {
std::stringstream ss;
ss << Type();
ss << "(";
for (auto& arg : InputArgumentNames()) {
ss << arg << ":";
for (auto val : Input(arg)) {
ss << val << " ";
}
}
ss << ") -> (";
for (auto& arg : OutputArgumentNames()) {
ss << arg << ":";
for (auto val : Output(arg)) {
ss << val << " ";
}
}
ss << ")";
return ss.str();
}
};
} // namespace lite
......
......@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template <typename T>
T GetAttr(const std::string &name) const;
std::string DebugString() const { return desc_.DebugString(); }
private:
std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
......
......@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front();
auto *var = scope->FindVar(input);
CHECK(var);
param_.x = var->GetMutable<Tensor>();
param_.x = &var->Get<Tensor>();
var = scope->FindVar(W);
CHECK(var) << "no var called " << W;
param_.y = var->GetMutable<Tensor>();
param_.y = &var->Get<Tensor>();
var = scope->FindVar(out);
CHECK(var) << "no var called " << out;
param_.output = var->GetMutable<Tensor>();
......
......@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op
struct MulParam {
lite::Tensor* x{};
lite::Tensor* y{};
const lite::Tensor* x{};
const lite::Tensor* y{};
lite::Tensor* output{};
int x_num_col_dims{1};
......
......@@ -20,6 +20,7 @@
#include <typeinfo>
#include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small.
......@@ -116,9 +117,9 @@ struct variant {
if (type_id == typeid(T).hash_code())
return *reinterpret_cast<const T*>(&data);
else
throw std::invalid_argument("unmatched type");
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
// << typeid(T).name();
throw std::invalid_argument(
string_format("unmatched type, store as %d, but want to get %s",
type_id, typeid(T).name()));
return *reinterpret_cast<const T*>(&data);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册