Commit 1a67061f authored by Xin Pan

graph to program pass

fix a few other things
Parent: 1f09bc32
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
 endif()
 if (NOT WIN32)
...
@@ -3,6 +3,7 @@ cc_library(graph SRCS graph.cc DEPS node)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
+cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
 cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
 cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
@@ -12,5 +13,6 @@ cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
+cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
 cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
    std::unique_ptr<Graph> graph) const {
  // The destination ProgramDesc is handed in as a non-owned pass attribute.
  ProgramDesc& program = Get<ProgramDesc>("program");

  std::unique_ptr<proto::ProgramDesc> program_pb(
      new proto::ProgramDesc(*program.Proto()));

  auto block = program_pb->mutable_blocks(kRootBlockIndex);
  block->clear_vars();
  // A variable can be represented by several graph nodes; emit each
  // VarDesc only once.
  std::unordered_set<std::string> visited_vars;
  for (ir::Node* n : graph->Nodes()) {
    if (n->NodeType() == ir::Node::Type::kVariable) {
      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
        visited_vars.insert(n->Var()->Name());
        block->add_vars()->MergeFrom(*n->Var()->Proto());
      }
    }
  }

  block->clear_ops();
  // Re-emit the operators in a dependency-respecting order.
  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
  for (ir::Node* n : nodes) {
    if (!n->Op()) {
      continue;
    }
    block->add_ops()->MergeFrom(*n->Op()->Proto());
  }

  // Write the rebuilt proto back into the ProgramDesc wrapper.
  program.CopyFrom(*program_pb);
  return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass);
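For reference, a minimal sketch of how the registered pass is meant to be driven. It mirrors the Compile() helper added further down in this diff; only the wrapper function name is invented here:

#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/program_desc.h"

// Sketch: convert a ProgramDesc to a Graph and back via the new pass.
void GraphToProgram(paddle::framework::ProgramDesc* program) {
  std::unique_ptr<paddle::framework::ir::Graph> graph(
      new paddle::framework::ir::Graph(*program));  // program -> graph
  // ... optimization passes could run on `graph` here ...
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "graph_to_program_pass");
  // The pass reads the "program" attribute and writes the result back.
  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
  pass->Apply(std::move(graph));  // graph -> program
}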
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class GraphToProgramPass : public Pass {
 protected:
  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
namespace paddle {
namespace framework {
namespace ir {} // namespace ir
} // namespace framework
} // namespace paddle
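The test source above appears truncated in this capture (the ir namespace is empty). Purely as a hedged illustration of what graph_to_program_pass_test.cc would plausibly exercise, and not the committed test body:

// Hypothetical sketch: build a one-op program, convert it to a Graph, run
// graph_to_program_pass into a fresh ProgramDesc, and check the op survives.
#include <memory>
#include <utility>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/program_desc.h"

TEST(GraphToProgramPass, RoundTrip) {
  paddle::framework::ProgramDesc prog;
  auto* op = prog.MutableBlock(0)->AppendOp();
  op->SetType("fill_constant");  // any registered op type would do

  std::unique_ptr<paddle::framework::ir::Graph> g(
      new paddle::framework::ir::Graph(prog));
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "graph_to_program_pass");
  paddle::framework::ProgramDesc compiled;
  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled);
  pass->Apply(std::move(g));

  // The rebuilt program should still contain the op.
  EXPECT_EQ(compiled.Block(0).AllOps().size(), 1UL);
}

USE_PASS(graph_to_program_pass);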
@@ -132,7 +132,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
     std::string attr_name = attr.name();
     // The sub_block referred to by the BLOCK attr hasn't been added
     // to ProgramDesc class yet, we skip setting BLOCK attr here.
-    if (attr.type() != proto::AttrType::BLOCK) {
+    // TODO(paddle-dev): Need to fix this to copy the Block attrs as well.
+    if (attr.type() != proto::AttrType::BLOCK &&
+        attr.type() != proto::AttrType::BLOCKS) {
       attrs_[attr_name] = GetAttrValue(attr);
     }
   }
...
@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
   InitFromProto();
 }
+
+void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
+  blocks_.clear();
+  desc_ = desc;
+  InitFromProto();
+}
+
 ProgramDesc::ProgramDesc(const std::string &binary_str) {
   PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
                  "Fail to parse program_desc from binary string.");
...
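The new CopyFrom replaces the wrapped proto wholesale and rebuilds the per-block C++ wrappers via InitFromProto(); this is what lets graph_to_program_pass write its rebuilt proto back into an existing ProgramDesc. A minimal usage sketch (variable names are illustrative):

// Snapshot the proto, mutate it, then push it back into the wrapper.
paddle::framework::proto::ProgramDesc pb(*program.Proto());
pb.mutable_blocks(0)->clear_ops();  // e.g. rewrite the root block
program.CopyFrom(pb);               // blocks_ cleared and re-initialized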
@@ -53,6 +53,8 @@ class ProgramDesc {
   void Flush();
+
+  void CopyFrom(const proto::ProgramDesc &desc);
   proto::ProgramDesc *Proto();
   // The output variable of feed_op is referenced as feed_target.
...
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
 # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
 cc_library(paddle_fluid_api
     SRCS io.cc
-    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
+    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
 get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
...
@@ -138,7 +138,6 @@ std::unique_ptr<framework::ProgramDesc> Load(
   std::unique_ptr<framework::ProgramDesc> main_program(
       new framework::ProgramDesc(program_desc_str));
   LoadPersistables(executor, scope, *main_program, "", param_filename);
-
   return main_program;
 }
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
+
+void Compile(paddle::framework::ProgramDesc* program) {
+  std::unique_ptr<paddle::framework::ir::Graph> g(
+      new paddle::framework::ir::Graph(*program));
+  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
+      "graph_to_program_pass");
+  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
+  pass->Apply(std::move(g));
+}
+
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
   }
+  Compile(inference_program.get());
+
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
                                     "load_program_profiler");
@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
   delete scope;
 }
+
+USE_PASS(graph_to_program_pass);
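USE_PASS is the link-time counterpart of the REGISTER_PASS call in graph_to_program_pass.cc: it references a symbol defined by the registration macro so the static registrar object is not stripped by the linker, and PassRegistry::Instance().Get("graph_to_program_pass") can find the pass at runtime. Roughly, as a simplified sketch rather than Paddle's verbatim macros:

// Simplified sketch of the register/use pattern (assumed, not verbatim).
#define REGISTER_PASS(pass_type, pass_class)                 \
  static ::paddle::framework::ir::PassRegistrar<pass_class>  \
      __pass_registrar_##pass_type##__(#pass_type);          \
  int TouchPassRegistrar_##pass_type() {                     \
    __pass_registrar_##pass_type##__.Touch();                \
    return 0;                                                \
  }

#define USE_PASS(pass_type)                                  \
  extern int TouchPassRegistrar_##pass_type();               \
  static int use_pass_itself_##pass_type##_ =                \
      TouchPassRegistrar_##pass_type();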
@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
         grad->SetInput(framework::GradVarName(output_param), og_names);
       }
     }
+    grad->SetInput("Communicator", {"nccl_com__do_not_change_"});
     grad->SetAttrMap(this->Attrs());
     grad->SetBlockAttr(kParallelBlock, grad_block_[0]);
...
@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
 batch_size = fluid.layers.create_tensor(dtype='int64')
 batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
-# fluid.memory_optimize(fluid.default_main_program(), level=0)
-fluid.release_memory(fluid.default_main_program())
+fluid.memory_optimize(fluid.default_main_program(), level=0)
+# fluid.release_memory(fluid.default_main_program())
 BATCH_SIZE = 16
 PASS_NUM = 1
...
@@ -92,8 +92,8 @@ def main():
     optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
     optimizer.minimize(avg_cost)
-    # fluid.memory_optimize(fluid.default_main_program())
-    fluid.release_memory(fluid.default_main_program())
+    fluid.memory_optimize(fluid.default_main_program())
+    # fluid.release_memory(fluid.default_main_program())
     # fix the order of training data
     train_data = paddle.batch(
...