提交 0fb566b7 编写于 作者: S superjomn

rename io_complement_pass to type_target_transform_pass

上级 72b734e4
...@@ -6,7 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) ...@@ -6,7 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_passes cc_library(mir_passes
SRCS static_kernel_pick_pass.cc SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
io_complement_pass.cc type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc graph_visualize_pass.cc
generate_program_pass.cc generate_program_pass.cc
......
...@@ -24,7 +24,7 @@ namespace mir {} // namespace mir ...@@ -24,7 +24,7 @@ namespace mir {} // namespace mir
USE_MIR_PASS(demo); USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(io_complement_pass); USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include <list> #include <list>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set. // Start from inputs of the graph, those should have place set.
std::list<Node*> nodes; std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
...@@ -42,8 +42,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -42,8 +42,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
VLOG(3) << "\n" << Visualize(graph.get()); VLOG(3) << "\n" << Visualize(graph.get());
} }
void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
Node* in) { Node* in) {
// If this input is out of date. // If this input is out of date.
if (inst_node->inlinks.end() == if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
...@@ -68,10 +68,9 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, ...@@ -68,10 +68,9 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
} }
} }
void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, void TypeTargetTransformPass::AddIoCopyInst(
const std::string& var, SSAGraph* graph, const Type& from, const Type& to, const std::string& var, SSAGraph* graph,
Node* inst_node, Node* inst_node, const std::vector<Place>& valid_places) {
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node. // So there will be a new Argument node and a new IoCopy Instruct Node.
...@@ -131,7 +130,8 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, ...@@ -131,7 +130,8 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
graph->CheckValid(); graph->CheckValid();
} }
void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) { void TypeTargetTransformPass::SetValidPlaces(
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()); CHECK(!valid_places.empty());
valid_places_ = valid_places; valid_places_ = valid_places;
} }
...@@ -140,4 +140,5 @@ void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) { ...@@ -140,4 +140,5 @@ void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(io_complement_pass, paddle::lite::mir::IoComplementPass); REGISTER_MIR_PASS(type_target_transform_pass,
paddle::lite::mir::TypeTargetTransformPass);
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -36,7 +38,7 @@ static void UpdateInputTo(framework::proto::OpDesc* desc, ...@@ -36,7 +38,7 @@ static void UpdateInputTo(framework::proto::OpDesc* desc,
* IoComplementPass complement the necessary instruction to make data * IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places. * transferring or transformation between different places.
*/ */
class IoComplementPass : public ProgramPass { class TypeTargetTransformPass : public ProgramPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
...@@ -48,7 +50,7 @@ class IoComplementPass : public ProgramPass { ...@@ -48,7 +50,7 @@ class IoComplementPass : public ProgramPass {
void SetValidPlaces(const std::vector<Place>& valid_places); void SetValidPlaces(const std::vector<Place>& valid_places);
const std::vector<Place>& valid_places() const { return valid_places_; }; const std::vector<Place>& valid_places() const { return valid_places_; }
private: private:
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
......
...@@ -54,7 +54,7 @@ TEST(variable_place_inference_pass, test) { ...@@ -54,7 +54,7 @@ TEST(variable_place_inference_pass, test) {
"argument_type_display_pass", // "argument_type_display_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
"argument_type_display_pass", // "argument_type_display_pass", //
"io_complement_pass", // "type_target_transform_pass", //
}); });
Place prefered_place{ Place prefered_place{
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h" #include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h" #include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
...@@ -48,7 +48,7 @@ class Optimizer { ...@@ -48,7 +48,7 @@ class Optimizer {
"static_kernel_pick_pass", // "static_kernel_pick_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
"argument_type_display_pass", // "argument_type_display_pass", //
"io_complement_pass", // "type_target_transform_pass", //
"argument_type_display_pass", // "argument_type_display_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
"argument_type_display_pass", // "argument_type_display_pass", //
...@@ -83,8 +83,9 @@ class Optimizer { ...@@ -83,8 +83,9 @@ class Optimizer {
} }
void InitIoComplement() { void InitIoComplement() {
auto* pass = mir::PassManager::Global().LookUp<mir::IoComplementPass>( auto* pass =
"io_complement_pass"); mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
"type_target_transform_pass");
CHECK(pass); CHECK(pass);
CHECK(!valid_places_.empty()); CHECK(!valid_places_.empty());
LOG(INFO) << "valid_places.size " << valid_places_.size(); LOG(INFO) << "valid_places.size " << valid_places_.size();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册