diff --git a/paddle/fluid/lite/core/executor.h b/paddle/fluid/lite/core/executor.h
index 19c9e2767c0fac9f3e5a569abe384720e58cc920..d53eb2b90c65e49aaa66133cd706ddef46b9d2ac 100644
--- a/paddle/fluid/lite/core/executor.h
+++ b/paddle/fluid/lite/core/executor.h
@@ -52,7 +52,7 @@ class Executor {
       ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
       // pick initial kernel
       ops_.back()->PickKernel(valid_places_);
-      ops_.back()->Attach(*op_desc, exec_scope_);
+      ops_.back()->AttachImpl(*op_desc, exec_scope_);
     }
   }
diff --git a/paddle/fluid/lite/core/kernel.cc b/paddle/fluid/lite/core/kernel.cc
index 557ed2162103af44fd478c90f33a4ae042156698..34e0198296072c9740b2170c9b95f3f6f0d4b44c 100644
--- a/paddle/fluid/lite/core/kernel.cc
+++ b/paddle/fluid/lite/core/kernel.cc
@@ -17,11 +17,6 @@
 namespace paddle {
 namespace lite {
 
-bool operator==(const Place &a, const Place &b) {
-  return a.target == b.target && a.precision == b.precision &&
-         a.layout == b.layout;
-}
-
 bool operator<(const Place &a, const Place &b) {
   if (a.target != b.target) return a.target < b.target;
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index b4d230061e9f55af487fd8bef3c90d1a8515b4b9..ee6890aea963132220d863316fcfb0492c9a7bf1 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
+#include "paddle/fluid/lite/core/type_system.h"
 #include "paddle/fluid/lite/core/types.h"
 #include "paddle/fluid/lite/operators/op_params.h"
 #include "paddle/fluid/lite/utils/all.h"
@@ -51,6 +52,7 @@ class KernelBase {
 
   virtual TargetType target() const = 0;
   virtual PrecisionType precision() const = 0;
+  virtual DataLayoutType layout() const = 0;
 
   virtual ~KernelBase() = default;
 
@@ -66,17 +68,21 @@ class KernelBase {
  * registered in the `TypeSystem`.
  */
struct ParamType {
+  // For unsupported types.
   size_t element_type_hash{};
   Place tensor_place{};
+  const Type* type_;
 
   ParamType() = default;
   ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
   ParamType(size_t element_type_hash, const Place& place)
       : element_type_hash(element_type_hash), tensor_place(place) {}
+  ParamType(const Type* type) : type_(type) {}
 };
 
 /*
- * The data types of kernel parameters.
+ * The data types of kernel parameters. It is used to track the types of a
+ * kernel's inputs and outputs.
  */
 struct ParamTypes {
   std::vector<ParamType> inputs;
@@ -115,6 +121,8 @@
  */
 class ParamTypeRegistry {
  public:
+  enum class IO : int { kInput = 0, kOutput };
+
   template <TargetType target, PrecisionType precision,
             DataLayoutType layout = DataLayoutType::kNCHW>
   /*
@@ -130,7 +138,12 @@ class ParamTypeRegistry {
     NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
 
     NewInstance& BindInput(int offset, const ParamType& ptype) {
-      ParamTypeRegistry::Global().Register(
+      ParamTypeRegistry::Global().Register<IO::kInput>(
+          kernel_type_, Place{target, precision, layout}, offset, ptype);
+      return *this;
+    }
+    NewInstance& BindOutput(int offset, const ParamType& ptype) {
+      ParamTypeRegistry::Global().Register<IO::kOutput>(
           kernel_type_, Place{target, precision, layout}, offset, ptype);
       return *this;
     }
@@ -141,8 +154,12 @@ class ParamTypeRegistry {
     std::string kernel_type_;
   };
 
+  template <IO io>
   void Register(const std::string& kernel_type, const Place& place, int offset,
-                ParamType data_type) {}
+                ParamType data_type) {
+    KernelIdTy key{kernel_type, place, io, offset};
+    types_[key] = data_type;
+  }
 
   ParamType Retrive(const Place& place, int offset);
@@ -155,16 +172,15 @@
   ParamTypeRegistry() = default;
 
  public:
-  enum class IO : int { kInput = 0, kOutput };
   // Identification for a Kernel.
-  struct KernelIdT {
+  struct KernelIdTy {
     std::string kernel_type;
     Place place;
     IO io;
     int offset;
   };
 
-  using key_t = KernelIdT;
+  using key_t = KernelIdTy;
   struct KeyCmp {
     bool operator()(const key_t& a, const key_t& b) const;
   };
@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
 
   TargetType target() const override { return Target; }
   PrecisionType precision() const override { return Precision; }
+  DataLayoutType layout() const override { return DataLayout; }
 
   void Touch() {}
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index 55b75b73ac85fd4ef7b8acf5f3167aaa88a3fcc5..8573a239b7b1e689019c903f1051228053a86ef0 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -1,3 +1,8 @@
-cc_library(mir_pass SRCS pass.cc)
 cc_library(mir_node SRCS node.cc)
-cc_library(mir_ssa_graph SRCS ssa_graph.cc)
\ No newline at end of file
+cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
+cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
+cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
+cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
+cc_library(mir_demo_pass SRCS demo_pass.cc DEPS mir_pass)
+
+cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass)
diff --git a/paddle/fluid/lite/core/mir/demo_pass.cc b/paddle/fluid/lite/core/mir/demo_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..63dd44af9240d97505c81b8799c053fae35cc0ed
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/demo_pass.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class DemoPass : public mir::Pass {
+ public:
+  void Apply(std::unique_ptr<SSAGraph>& graph) override {}
+};
+
+bool RegisterDemoPass() {
+  return PassManager::Global().AddNewPass("demo", new DemoPass);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/node.cc b/paddle/fluid/lite/core/mir/node.cc
index e31f70bfef0de2ef2b098a06fd0f97b17dbd6cec..711ff508f23c7d5218a7d788e90b3fe58f154018 100644
--- a/paddle/fluid/lite/core/mir/node.cc
+++ b/paddle/fluid/lite/core/mir/node.cc
@@ -1 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/fluid/lite/core/mir/node.h"
diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h
index daef285870cd3605c555480fab13056cfdc2fcd0..c1c24bce2fa7b0b7bf2bc4bf1d64dd5836c2fc23 100644
--- a/paddle/fluid/lite/core/mir/node.h
+++ b/paddle/fluid/lite/core/mir/node.h
@@ -1,15 +1,87 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
+
 namespace paddle {
 namespace lite {
 namespace mir {
 
+// Node in a MIR graph.
 class Node {
  public:
-  // Tell is instruction.
-  bool IsInstruct() const;
-  // Tell is an argument.
-  bool IsArgument() const;
-};
+  std::list<Node*> inlinks;
+  std::list<Node*> outlinks;
+
+  Node() = default;
+
+  enum class Role {
+    kUnk = -1,
+    kArgument,
+    kInstruct,
+    kNumRoles /*should be last*/
+  };
+
+  struct Instruct {
+    std::string op_type;
+    Place place;
+    // The kernel instances this Instruct contains.
+    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
+  };
+
+  struct Argument {
+    std::string name;
+    Place place;
+  };
+
+  // Set roles.
+  Argument& AsArgument() {
+    if (role_ != Role::kUnk) {
+      CHECK(role_ == Role::kArgument);
+      return *argument_;
+    }
+    role_ = Role::kArgument;
+    argument_.reset(new Argument);
+    return *argument_;
+  }
+  Instruct& AsInstruct() {
+    if (role_ != Role::kUnk) {
+      CHECK(role_ == Role::kInstruct);
+      return *instruct_;
+    }
+    role_ = Role::kInstruct;
+    instruct_.reset(new Instruct);
+    return *instruct_;
+  }
+  // Check roles.
+  bool IsRoleSet() const { return role_ != Role::kUnk; }
+  bool IsInstruct() const { return role_ == Role::kInstruct; }
+  bool IsArgument() const { return role_ == Role::kArgument; }
+
+ private:
+  // Either instruct_ or argument_ is used.
+  std::unique_ptr<Instruct> instruct_;
+  std::unique_ptr<Argument> argument_;
+
+  Role role_{Role::kUnk};
+};
 
 }  // namespace mir
 }  // namespace lite
-}  // namespace paddle
\ No newline at end of file
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass.cc b/paddle/fluid/lite/core/mir/pass.cc
index 31b105ac0cc2208087162ec8e1dc6365fd50d0f6..0c2f03a25641ad9f790094b50e45d166016ee4f7 100644
--- a/paddle/fluid/lite/core/mir/pass.cc
+++ b/paddle/fluid/lite/core/mir/pass.cc
@@ -1 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/fluid/lite/core/mir/pass.h"
diff --git a/paddle/fluid/lite/core/mir/pass.h b/paddle/fluid/lite/core/mir/pass.h
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4264f8937b160cc8e9d24c77a6b3764633e944fb 100644
--- a/paddle/fluid/lite/core/mir/pass.h
+++ b/paddle/fluid/lite/core/mir/pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/mir/node.h"
+#include "paddle/fluid/lite/core/mir/ssa_graph.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class Pass {
+ public:
+  virtual void Apply(std::unique_ptr<SSAGraph>& graph) = 0;
+
+  const std::string& name() const { return name_; }
+
+  virtual ~Pass() = default;
+
+ private:
+  std::string name_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass_manager.cc b/paddle/fluid/lite/core/mir/pass_manager.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7d4bb685ac9708371e43605ddf7abb100a8db758
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/pass_manager.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/pass_manager.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+PassManager::PassManager() {}
+
+// Manually register here.
+extern bool RegisterDemoPass();
+static bool xx __attribute__((unused)) = RegisterDemoPass();
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass_manager.h b/paddle/fluid/lite/core/mir/pass_manager.h
new file mode 100644
index 0000000000000000000000000000000000000000..e2ff6549bd53fbc33e67011845355cd5104841d7
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/pass_manager.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <list>
+#include <map>
+#include <string>
+
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class PassManager {
+ public:
+  static PassManager& Global() {
+    static PassManager x;
+    return x;
+  }
+
+  PassManager();
+
+  void Run() {
+    for (auto& pass : passes_) {
+      LOG(INFO) << "Running MIR pass " << pass->name();
+      pass->Apply(graph_);
+    }
+  }
+
+  bool AddNewPass(const std::string& name, Pass* pass) {
+    passes_.emplace_back(pass);
+    pass_map_.emplace(name, passes_.back().get());
+    return true;
+  }
+
+  // Clear all the passes.
+  void Clear() { passes_.clear(); }
+
+  std::list<std::unique_ptr<mir::Pass>>::iterator passes_begin() {
+    return passes_.begin();
+  }
+  std::list<std::unique_ptr<mir::Pass>>::iterator passes_end() {
+    return passes_.end();
+  }
+  std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_begin()
+      const {
+    return passes_.begin();
+  }
+  std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_end()
+      const {
+    return passes_.end();
+  }
+
+  Pass* LookUp(const std::string& key) {
+    auto it = pass_map_.find(key);
+    CHECK(it != pass_map_.end());
+    return it->second;
+  }
+
+ private:
+  std::unique_ptr<mir::SSAGraph> graph_;
+  std::list<std::unique_ptr<mir::Pass>> passes_;
+  std::map<std::string, mir::Pass*> pass_map_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass_manager_test.cc b/paddle/fluid/lite/core/mir/pass_manager_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..74cf90a49179c3a6c6dc1998223fed535c5e22cc
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/pass_manager_test.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/pass_manager.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+TEST(PassManager, test) {
+  auto* pass = PassManager::Global().LookUp("demo");
+  LOG(INFO) << "pass: " << pass;
+  ASSERT_TRUE(pass != nullptr);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass_registry.cc b/paddle/fluid/lite/core/mir/pass_registry.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8100638fb540a78899863ca04d64b66f2e865b
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/pass_registry.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/pass_registry.h b/paddle/fluid/lite/core/mir/pass_registry.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc743740aee98c93a25fe1f2709c3fcd464ef28c
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/pass_registry.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+#include "paddle/fluid/lite/core/mir/pass_manager.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class PassRegistry {
+ public:
+  PassRegistry(const std::string& name, mir::Pass* pass) {
+    LOG(INFO) << "Registry add MIR pass " << name;
+    PassManager::Global().AddNewPass(name, pass);
+  }
+
+  bool Touch() const { return true; }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc
index 56a60041957c3f41aed3450b2aa2dfbbf3643ac2..54d570cab4b69331a7888463536eab331f2864b4 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph.cc
@@ -1 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "paddle/fluid/lite/core/mir/ssa_graph.h"
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..d7524c1cb63997ba9464e819bdeaa707b17db241 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.h
+++ b/paddle/fluid/lite/core/mir/ssa_graph.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <list>
+#include <map>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/node.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+// A Program is used to represent a code program. In Paddle, a code program
+// contains:
+// - a main block, which is a list of OpLite
+// - a scope, which contains all the weights
+struct Program {
+  std::list<std::unique_ptr<OpLite>> ops;
+  lite::Scope *scope;
+};
+
+// A Graph for MIR. It is built from a list of Ops and a scope.
+class GraphBase {};
+
+class SSAGraph : GraphBase {
+ public:
+  // @param program: the op program
+  // @param valid_places: the valid places the user set for the system.
+  void Build(const Program &program, const std::vector<Place> &valid_places) {
+    for (auto &op : program.ops) {
+      node_storage_.emplace_back();
+      // TODO(Superjomn) remove one valid_places here.
+      op->SetValidPlaces(valid_places);
+      auto &new_node = node_storage_.back();
+      auto &new_kernel = node_storage_.back().AsInstruct();
+      new_kernel.valid_kernels = op->CreateKernels(valid_places);
+
+      CHECK(new_node.inlinks.empty()) << "duplicate Build found";
+      CHECK(new_node.outlinks.empty()) << "duplicate Build found";
+      // collect inputs and outputs
+      for (const std::string &name : op->input_names()) {
+        new_node.inlinks.push_back(arguments_.at(name));
+      }
+      for (const std::string &name : op->output_names()) {
+        new_node.outlinks.push_back(arguments_.at(name));
+      }
+    }
+  }
+
+  std::vector<Node *> TopologicalOrder() const;
+
+ private:
+  std::list<Node> node_storage_;
+  std::map<std::string, Node *> arguments_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index 4e491cd8cd6bb70292a2954577a4903ae6e7ea1c..a57ee119cc8136cb10faaca017e5ed71eca5f5d6 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
   }
 }
 
+bool OpLite::Run() {
+  CHECK(kernel_);
+  SyncInputEvents();
+
+  kernel_->Run();
+
+  RecordOutputEvents();
+  return true;
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 45985e8e575c1d0ed3d56e7e443e77b9e9866e14..0ab0550be08f46f545d8ed177b642311b2e2c97e 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -36,6 +36,11 @@ struct Registry {
   void Touch() {}
 };
 
+namespace mir {
+class Node;
+class SSAGraph;
+}
+
 /**
  * The base class of an light-weight operators, currently just used in inference
  * to eliminate overhead of some operations in current framework.
@@ -71,19 +76,13 @@ class OpLite : public Registry {
   // Inference the outputs' shape.
   virtual bool InferShape() const { return true; }
   // Run this operator.
-  virtual bool Run() {
-    CHECK(kernel_);
-    SyncInputEvents();
-
-    kernel_->Run();
+  virtual bool Run();
 
-    RecordOutputEvents();
-    return true;
+  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+    ExtractInputsAndOutputs(opdesc);
+    return AttachImpl(opdesc, scope);
   }
 
-  // Attach it with the runtime environment.
-  virtual bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
-
   // Human-readable information.
   virtual std::string DebugString() const = 0;
 
@@ -92,9 +91,29 @@ class OpLite : public Registry {
   void PickKernel(const std::vector<Place> &valid_places,
                   KernelStrategy kernel_strategy = KernelStrategy::kStatic);
 
+  const std::list<std::string> &input_names() const { return input_names_; }
+  const std::list<std::string> &output_names() const { return output_names_; }
+
   virtual ~OpLite() = default;
 
 protected:
+  // Attach it with the runtime environment.
+  virtual bool AttachImpl(const framework::OpDesc &opdesc,
+                          lite::Scope *scope) = 0;
+
+  void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
+    for (const auto &item : opdesc.Inputs()) {
+      for (const auto &x : item.second) {
+        input_names_.push_back(x);
+      }
+    }
+    for (const auto &item : opdesc.Outputs()) {
+      for (const auto &x : item.second) {
+        output_names_.push_back(x);
+      }
+    }
+  }
+
   // Specify the kernel to run by default. This will specify the value of
   // `kernel_place_`.
   virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
@@ -113,12 +132,17 @@ class OpLite : public Registry {
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
       const std::vector<Place> &places, const std::string &kernel_type = "");
 
+  friend class mir::Node;
+  friend class mir::SSAGraph;
+
 protected:
   std::unique_ptr<OpContext> op_context_;
   std::unique_ptr<KernelBase> kernel_;
   std::string op_type_;
   std::vector<Place> valid_places_;
   Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
+  std::list<std::string> input_names_;
+  std::list<std::string> output_names_;
 };
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc
index ed8cb29ad806c9bf55e04db3f9ad387cae690b24..a558383f7d232f5e0fa1c24b7e5c4971cd5f50a1 100644
--- a/paddle/fluid/lite/core/type_system.cc
+++ b/paddle/fluid/lite/core/type_system.cc
@@ -13,3 +13,46 @@
 // limitations under the License.
 
 #include "paddle/fluid/lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+
+// ------------------------- GetType specification ----------------------------
+template <>
+const Type* Type::Get<UnsupportedTy>() {
+  static UnsupportedTy x;
+  return &x;
+}
+
+template <>
+const Type* Type::Get<TensorFp32NCHWTy>() {
+  static TensorFp32NCHWTy x(TargetType::kX86);
+  return &x;
+}
+
+template <>
+const Type* Type::Get<UnsupportedTy>(TargetType target) {
+  return Get<UnsupportedTy>();
+}
+
+template <>
+const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
+  switch (target) {
+    case TargetType::kX86:
+      return Get<TensorFp32NCHWTy>();
+    default:
+      LOG(FATAL) << "unsupported target " << TargetToStr(target);
+      return nullptr;
+  }
+}
+
+// ------------------------- end GetType specification ------------------------
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index aabd50bb9243d6461534a8c2e5d6d04d55c445cf..d9107afff8d21714137d12110e302412be7cf42b 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -82,7 +82,7 @@ class DataTypeBase {
  * Datatype with device info considered.
  * NOTE A Type with different device is treated as different DeviceDataType.
  */
-class DeviceDataType : public DataTypeBase {
+class Type : public DataTypeBase {
  public:
   TargetType target() const { return place_.target; }
   PrecisionType precision() const { return place_.precision; }
@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
   const Place& place() const { return place_; }
   const std::string& name() const { return name_; }
 
-  bool operator==(const DeviceDataType& other) {
+  bool operator==(const Type& other) {
     return id_ == other.id() && place_ == other.place();
   }
 
   // Can cast to another type. This is heavily used in MIR, by determine whether
   // is is possible to add a instruction to transform a type to another.
-  virtual bool TypeCastable(const DeviceDataType& type) const {
-    return id_ == type.id();
-  }
+  virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
+
+  // Get a type.
+  template <typename Ty>
+  static const Type* Get();
+
+  template <typename Ty>
+  static const Type* Get(TargetType target = TargetType::kHost);
 
-  virtual ~DeviceDataType() = default;
+  virtual ~Type() = default;
 
  protected:
-  DeviceDataType(ID id, const std::string& name, bool is_tensor,
-                 TargetType target = TargetType::kHost,
-                 PrecisionType precision = PrecisionType::kFloat,
-                 DataLayoutType layout = DataLayoutType::kNCHW)
+  Type(ID id, const std::string& name, bool is_tensor,
+       TargetType target = TargetType::kHost,
+       PrecisionType precision = PrecisionType::kFloat,
+       DataLayoutType layout = DataLayoutType::kNCHW)
       : DataTypeBase(id, is_tensor),
         place_{target, precision, layout},
         name_(name) {}
@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
 };
 
 // -------------------------------- predefined types ---------------------------
-class Void : public DeviceDataType {
+// TODO(Superjomn) make all the Types' constructs protected to make sure there
+// is only one instance across the system.
+class VoidTy : public Type {
+ public:
+  VoidTy() : Type(ID::Void, "Void", false /*is_tensor*/) {}
+};
+class UnsupportedTy : public Type {
  public:
-  Void() : DeviceDataType(ID::Void, "Void", false /*is_tensor*/) {}
+  UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
 };
-class TensorFp32NCHW : public DeviceDataType {
+class TensorFp32NCHWTy : public Type {
  public:
-  TensorFp32NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kFloat,
-                       DataLayoutType::kNCHW) {}
+  TensorFp32NCHWTy(TargetType target)
+      : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
+             PrecisionType::kFloat, DataLayoutType::kNCHW) {}
 };
-class TensorInt8NCHW : public DeviceDataType {
+class TensorInt8NCHWTy : public Type {
  public:
-  TensorInt8NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Int8_NCHW, "TensorInt8NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kInt8,
-                       DataLayoutType::kNCHW) {}
+  TensorInt8NCHWTy(TargetType target)
+      : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
+             PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
-class TensorInt64NCHW : public DeviceDataType {
+class TensorInt64NCHWTy : public Type {
  public:
-  TensorInt64NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Int64_NCHW, "TensorInt64NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kInt8,
-                       DataLayoutType::kNCHW) {}
+  TensorInt64NCHWTy(TargetType target)
+      : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
+             target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
 // ------------------------- end predefined types ---------------------------
diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc
index 42756bd23e3722c511b91e24658032b7f19de756..bbbc13e30ab08d441f9c30e850d72e7b1d67095f 100644
--- a/paddle/fluid/lite/kernels/host/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/host/fc_compute.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/lite/kernels/host/fc_compute.h"
 #include <Eigen/Core>
 #include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/type_system.h"
 
 namespace paddle {
 namespace lite {
@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
-    .BindInput(0, {typeid(paddle::lite::Tensor).hash_code(),
-                   paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}})
+    .BindInput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                      TARGET(kX86))})
+    .BindOutput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                       TARGET(kX86))})
     .Finalize();
diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h
index 71cd815184012be3072c37acb88b970120f72c60..cd8e5064639e0d8dfa69af8c148158833e0bab6c 100644
--- a/paddle/fluid/lite/operators/fc_op.h
+++ b/paddle/fluid/lite/operators/fc_op.h
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#pragma once
+
 #include <string>
 #include <vector>
 #include "paddle/fluid/lite/core/kernel.h"
@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
    */
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
     auto input = op_desc.Input("Input").front();
     auto W = op_desc.Input("W").front();
     auto bias = op_desc.Input("Bias").front();
diff --git a/paddle/fluid/lite/operators/fc_op_test.cc b/paddle/fluid/lite/operators/fc_op_test.cc
index b191469ff3c2f9784fb3c047f33fa949f56f8885..54914b5ab1918992515bbc3d23b64e4df1d1c496 100644
--- a/paddle/fluid/lite/operators/fc_op_test.cc
+++ b/paddle/fluid/lite/operators/fc_op_test.cc
@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
   fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
   fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
 
-  fc.Attach(desc, &scope);
+  fc.AttachImpl(desc, &scope);
   fc.Run();
 
   for (int i = 0; i < 10 * 20; i++) {
diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h
index 392d3e72beca3fa0a18e1a387e36a987b2fd194f..effab7b8cbde9f044081259114ed4ed78f7e6d3b 100644
--- a/paddle/fluid/lite/operators/mul_op.h
+++ b/paddle/fluid/lite/operators/mul_op.h
@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
   bool InferShape() const override;
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
     auto input = op_desc.Input("X").front();
     auto W = op_desc.Input("Y").front();
     auto out = op_desc.Output("Out").front();
diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc
index c46ccd9edea2afb8ad40dadd1fc3409eb4008c51..ea3dea6585de523e6e6f81ad04a236b9108c81fa 100644
--- a/paddle/fluid/lite/operators/relu_op.cc
+++ b/paddle/fluid/lite/operators/relu_op.cc
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
   return true;
 }
 
-bool ReluOp::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool ReluOp::AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) {
   param_.input = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
   param_.output =
diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h
index da8553000b5d59b5445d67dc0f4fb497de556f32..3a73b45a273341d090b7879de1bed7d2e60b4569 100644
--- a/paddle/fluid/lite/operators/relu_op.h
+++ b/paddle/fluid/lite/operators/relu_op.h
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
 
   bool InferShape() const override;
 
-  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
 
   std::string DebugString() const override { return "tanh"; }
diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc
index 1371c01277e47e31ce7d1302114d70850640c36e..eb99fec3f46b3171f64e391cf722fb42ac993edd 100644
--- a/paddle/fluid/lite/operators/scale_op.cc
+++ b/paddle/fluid/lite/operators/scale_op.cc
@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
   }
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
     auto x = op_desc.Input("X").front();
     auto out = op_desc.Output("Out").front();
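For reference, a minimal sketch of how a new optimization pass would plug into the PassManager introduced in this patch. This is not part of the diff: `FuseFcPass` and the "fuse_fc" key are hypothetical names; `Pass`, `SSAGraph`, `PassManager::AddNewPass`, and the registration idiom are taken from demo_pass.cc and pass_manager.h above.

// Sketch only (assumes this patch is applied). FuseFcPass and "fuse_fc"
// are hypothetical; the rest mirrors demo_pass.cc.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"

namespace paddle {
namespace lite {
namespace mir {

class FuseFcPass : public Pass {
 public:
  // Rewrite the graph in place; invoked by PassManager::Run().
  void Apply(std::unique_ptr<SSAGraph>& graph) override {
    // e.g. match mul + elementwise_add Instruct nodes and fuse them into fc.
  }
};

// Same manual-registration idiom as demo_pass.cc: ownership of the raw
// pointer is handed to the global manager, which stores it in a unique_ptr.
bool RegisterFuseFcPass() {
  return PassManager::Global().AddNewPass("fuse_fc", new FuseFcPass);
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

A client would then retrieve the pass with PassManager::Global().LookUp("fuse_fc"), exactly as pass_manager_test.cc does for the "demo" pass, or let PassManager::Run() apply every registered pass to the graph in order.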