diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 7b7bc73ef1e1c11954895682177a973a490d5cc2..1a2a0d3697df6f0a9d33a6761688b80a5e9cca08 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -6,7 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_passes SRCS static_kernel_pick_pass.cc variable_place_inference_pass.cc - io_complement_pass.cc + type_target_transform_pass.cc io_copy_kernel_pick_pass.cc graph_visualize_pass.cc generate_program_pass.cc diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index d81cdd7d01e9972a51a37d4f6a918451ee03e144..60e53257ba01006e71095faa62b083d47e894c60 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -24,7 +24,7 @@ namespace mir {} // namespace mir USE_MIR_PASS(demo); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); -USE_MIR_PASS(io_complement_pass); +USE_MIR_PASS(type_target_transform_pass); USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc similarity index 86% rename from paddle/fluid/lite/core/mir/io_complement_pass.cc rename to paddle/fluid/lite/core/mir/type_target_transform_pass.cc index d01d08faefa431cd90bb22b1e8452a8018d7ed59..34762cf40c52b1a69ea00036f9120da4224b1513 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.cc +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/mir/io_complement_pass.h" +#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h" #include #include #include @@ -23,7 +23,7 @@ namespace paddle { namespace lite { namespace mir { -void IoComplementPass::Apply(std::unique_ptr& graph) { +void TypeTargetTransformPass::Apply(std::unique_ptr& graph) { // Start from inputs of the graph, those should have place set. std::list nodes; for (auto& node : graph->mutable_nodes()) { @@ -42,8 +42,8 @@ void IoComplementPass::Apply(std::unique_ptr& graph) { VLOG(3) << "\n" << Visualize(graph.get()); } -void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, - Node* in) { +void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, + Node* in) { // If this input is out of date. if (inst_node->inlinks.end() == std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) @@ -68,10 +68,9 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, } } -void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, - const std::string& var, SSAGraph* graph, - Node* inst_node, - const std::vector& valid_places) { +void TypeTargetTransformPass::AddIoCopyInst( + const Type& from, const Type& to, const std::string& var, SSAGraph* graph, + Node* inst_node, const std::vector& valid_places) { CHECK(!valid_places.empty()) << "valid_place should be set"; // var -> new_transform_op -> new_var -> inst // So there will be a new Argument node and a new IoCopy Instruct Node. @@ -131,7 +130,8 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, graph->CheckValid(); } -void IoComplementPass::SetValidPlaces(const std::vector& valid_places) { +void TypeTargetTransformPass::SetValidPlaces( + const std::vector& valid_places) { CHECK(!valid_places.empty()); valid_places_ = valid_places; } @@ -140,4 +140,5 @@ void IoComplementPass::SetValidPlaces(const std::vector& valid_places) { } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(io_complement_pass, paddle::lite::mir::IoComplementPass); +REGISTER_MIR_PASS(type_target_transform_pass, + paddle::lite::mir::TypeTargetTransformPass); diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.h b/paddle/fluid/lite/core/mir/type_target_transform_pass.h similarity index 94% rename from paddle/fluid/lite/core/mir/io_complement_pass.h rename to paddle/fluid/lite/core/mir/type_target_transform_pass.h index b1ae1846263d7016cbd9b08080175667943c8407..4f3f0c1c148f1dbe486600856968553cbc1af439 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.h +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -36,7 +38,7 @@ static void UpdateInputTo(framework::proto::OpDesc* desc, * IoComplementPass complement the necessary instruction to make data * transferring or transformation between different places. */ -class IoComplementPass : public ProgramPass { +class TypeTargetTransformPass : public ProgramPass { public: void Apply(std::unique_ptr& graph) override; @@ -48,7 +50,7 @@ class IoComplementPass : public ProgramPass { void SetValidPlaces(const std::vector& valid_places); - const std::vector& valid_places() const { return valid_places_; }; + const std::vector& valid_places() const { return valid_places_; } private: std::vector valid_places_; diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc index 76d8f8491729bfa8ce7ba7e4eaf270b0cf18807b..8a71bbd2a3959dfe07d18f48410187b6e31cc358 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc @@ -54,7 +54,7 @@ TEST(variable_place_inference_pass, test) { "argument_type_display_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // - "io_complement_pass", // + "type_target_transform_pass", // }); Place prefered_place{ diff --git a/paddle/fluid/lite/core/optimizer.cc b/paddle/fluid/lite/core/optimizer.cc index 96f3a05352cdb443d0426a246cf2a2a758310897..bb9fb5fe06760f2f7078b893157f6f1f65d058a8 100644 --- a/paddle/fluid/lite/core/optimizer.cc +++ b/paddle/fluid/lite/core/optimizer.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/fluid/lite/core/optimizer.h" -#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" +#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index c72bd740547a2037c4d2eb645f723ba2bc0bb7ed..4da5331fda8644c81224c672f86b03903105e1e6 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -16,10 +16,10 @@ #include #include #include "paddle/fluid/lite/core/mir/generate_program_pass.h" -#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" +#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h" #include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/types.h" @@ -48,7 +48,7 @@ class Optimizer { "static_kernel_pick_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // - "io_complement_pass", // + "type_target_transform_pass", // "argument_type_display_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // @@ -83,8 +83,9 @@ class Optimizer { } void InitIoComplement() { - auto* pass = mir::PassManager::Global().LookUp( - "io_complement_pass"); + auto* pass = + mir::PassManager::Global().LookUp( + "type_target_transform_pass"); CHECK(pass); CHECK(!valid_places_.empty()); LOG(INFO) << "valid_places.size " << valid_places_.size();