Commit b43e49fa authored by Xin Pan

fix

Parent afc603c1
@@ -63,7 +63,7 @@ paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.Executi
 paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.GradientScaleStrategy, arg0: int) -> None
 paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.core.ReduceStrategy, arg0: int) -> None
 paddle.fluid.BuildStrategy.__init__ __init__(self: paddle.fluid.core.BuildStrategy) -> None
-paddle.fluid.BuildStrategy.create_pass_builder create_pass_builder(self: paddle.fluid.core.BuildStrategy) -> paddle.fluid.core.PassBuilder
+paddle.fluid.BuildStrategy.create_passes_from_srategy create_passes_from_srategy(self: paddle.fluid.core.BuildStrategy) -> paddle.fluid.core.PassBuilder
 paddle.fluid.create_lod_tensor ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.create_random_int_lodtensor ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.io.save_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None))
......
@@ -14,9 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/build_strategy.h"
-#include <string>
-#include <tuple>
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -71,46 +68,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPass("multi_devices_check_pass");
   }
-
-  std::unique_ptr<ir::Graph> Build(
-      const ProgramDesc &main_program,
-      const std::vector<platform::Place> &places,
-      const std::string &loss_var_name,
-      const std::unordered_set<std::string> &param_names,
-      const std::vector<Scope *> &local_scopes,
-#ifdef PADDLE_WITH_CUDA
-      const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
-#else
-      const bool use_cuda) const {
-#endif
-    // Convert the program to graph.
-    std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
-
-    for (std::shared_ptr<ir::Pass> &pass : AllPasses()) {
-      if (pass->Type() == "multi_devices_pass") {
-        pass->SetNotOwned<const std::vector<platform::Place>>("places",
-                                                              &places);
-        pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
-        pass->SetNotOwned<const std::unordered_set<std::string>>("params",
-                                                                 &param_names);
-        pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
-                                                      &local_scopes);
-#ifdef PADDLE_WITH_CUDA
-        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
-        pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
-#endif
-      }
-      graph = pass->Apply(std::move(graph));
-    }
-    return graph;
-  }
-
  private:
  BuildStrategy strategy_;
 };

-ir::PassBuilder *BuildStrategy::CreatePassBuilder() const {
+std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
+    const {
   pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
-  return pass_builder_.get();
+  return pass_builder_;
 }

 std::unique_ptr<ir::Graph> BuildStrategy::Apply(
@@ -123,20 +88,33 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
 #else
                 const bool use_cuda) const {
 #endif
+  // Create a default one if not intialized by user.
   if (!pass_builder_) {
-    CreatePassBuilder();
+    CreatePassesFromStrategy();
   }
-  // std::unique_ptr<ir::Graph> graph;
-  ParallelExecutorPassBuilder *builder =
-      reinterpret_cast<ParallelExecutorPassBuilder *>(pass_builder_.get());
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
+
+  for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
+    if (pass->Type() == "multi_devices_pass") {
+      pass->Erase("places");
+      pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
+      pass->Erase("loss_var_name");
+      pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
+      pass->Erase("params");
+      pass->SetNotOwned<const std::unordered_set<std::string>>("params",
+                                                               &param_names);
+      pass->Erase("local_scopes");
+      pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
+                                                    &local_scopes);
 #ifdef PADDLE_WITH_CUDA
-  std::unique_ptr<ir::Graph> graph =
-      builder->Build(main_program, places, loss_var_name, param_names,
-                     local_scopes, use_cuda, nccl_ctxs);
-#else
-  std::unique_ptr<ir::Graph> graph = builder->Build(
-      main_program, places, loss_var_name, param_names, local_scopes, use_cuda);
+      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      pass->Erase("nccl_ctxs");
+      pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+    }
+    graph = pass->Apply(std::move(graph));
+  }
   return graph;
 }
 } // namespace details
......
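Note on the return-type change above: handing back std::shared_ptr<ir::PassBuilder> instead of the old raw ir::PassBuilder * matters because the builder is later exposed to Python and must outlive any reset of pass_builder_. The following is a minimal standalone sketch of that ownership difference, not Paddle code; the Builder type is hypothetical and only illustrates shared vs. raw ownership:

    #include <cassert>
    #include <memory>

    struct Builder {
      int num_passes = 0;
    };

    int main() {
      // Shared ownership: the caller's handle keeps the object alive
      // even after the original owner drops or replaces its copy.
      std::shared_ptr<Builder> owner = std::make_shared<Builder>();
      std::shared_ptr<Builder> handle = owner;  // what a shared_ptr-returning API hands out

      owner.reset(new Builder());  // owner moves on to a fresh builder
      handle->num_passes = 3;      // still valid: handle co-owns the old builder
      assert(handle->num_passes == 3);

      // A raw pointer taken from the unique owner (the old
      // create_pass_builder style) would dangle after the same reset.
      return 0;
    }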
@@ -31,9 +31,6 @@ namespace paddle {
 namespace framework {
 namespace details {

-class ParallelExecutorPassBuilder;
-struct BuildStrategy;
-
 struct BuildStrategy {
   // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and
   // kReduce, for CPU and GPU. If you use kAllReduce, different threads
@@ -72,7 +69,7 @@ struct BuildStrategy {
   bool enable_data_balance_{false};

-  ir::PassBuilder *CreatePassBuilder() const;
+  std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;

   std::unique_ptr<ir::Graph> Apply(
       const ProgramDesc &main_program,
@@ -87,7 +84,6 @@ struct BuildStrategy {
 #endif

  private:
-  // TODO(panyx0718): This should probably be unique_ptr.
   mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
 };
......
@@ -54,6 +54,21 @@ class Pass {
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }

+  bool Has(const std::string &attr_name) const {
+    return attrs_.find(attr_name) != attrs_.end();
+  }
+
+  void Erase(const std::string &attr_name) {
+    if (!Has(attr_name)) {
+      return;
+    }
+    if (attr_dels_.find(attr_name) != attr_dels_.end()) {
+      attr_dels_[attr_name]();
+      attr_dels_.erase(attr_name);
+    }
+    attrs_.erase(attr_name);
+  }
+
   // Set a pointer to the attribute. Pass takes ownership of the attribute.
   template <typename AttrType>
   void Set(const std::string &attr_name, AttrType *attr) {
@@ -70,6 +85,8 @@ class Pass {
   // should delete the attribute.
   template <typename AttrType>
   void SetNotOwned(const std::string &attr_name, AttrType *attr) {
+    PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
+                   attr_name);
     attrs_[attr_name] = attr;
   }
......
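The new Has/Erase pair is what lets BuildStrategy::Apply re-set pass attributes on every call instead of tripping the PADDLE_ENFORCE now added to SetNotOwned. Below is a simplified, self-contained sketch of the same protocol; AttrStore is a toy stand-in written for illustration, not the actual Pass class:

    #include <cassert>
    #include <map>
    #include <string>

    // Toy attribute store mirroring the Has / Erase / SetNotOwned protocol.
    class AttrStore {
     public:
      bool Has(const std::string &name) const { return attrs_.count(name) != 0; }

      void Erase(const std::string &name) { attrs_.erase(name); }

      void SetNotOwned(const std::string &name, void *attr) {
        assert(!Has(name) && "attribute already set");  // mirrors PADDLE_ENFORCE
        attrs_[name] = attr;
      }

     private:
      std::map<std::string, void *> attrs_;
    };

    int main() {
      AttrStore store;
      int places_a = 1, places_b = 2;

      store.SetNotOwned("places", &places_a);
      // Setting again without Erase would trip the check; Erase first makes
      // repeated Apply-style configuration safe.
      store.Erase("places");
      store.SetNotOwned("places", &places_b);
      assert(store.Has("places"));
      return 0;
    }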
@@ -82,12 +82,10 @@ TEST(PassTest, TestPassAttrCheck) {
   ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
   ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);

-  try {
-    graph = pass->Apply(std::move(graph));
-  } catch (paddle::platform::EnforceNotMet e) {
-    exception = std::string(e.what());
-  }
-  ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos);
+  // Allow apply more than once.
+  graph.reset(new Graph(prog));
+  graph->Set<int>("test_graph_attr", new int);
+  graph = pass->Apply(std::move(graph));

   pass = PassRegistry::Instance().Get("test_pass");
   pass->SetNotOwned<int>("test_pass_attr", &val);
......
@@ -603,7 +603,8 @@ All parameter, weight, gradient are variables in Paddle.
             self.Set<std::string>(name, new std::string(attr));
           });

-  py::class_<ir::PassBuilder> pb(m, "PassBuilder");
+  py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
+      m, "PassBuilder");
   pb.def(py::init())
       .def("append_pass",
            [](ir::PassBuilder &self,
@@ -701,9 +702,10 @@ All parameter, weight, gradient are variables in Paddle.
           [](BuildStrategy &self, bool b) {
             self.fuse_elewise_add_act_ops_ = b;
           })
-      .def("create_pass_builder",
-           [](BuildStrategy &self) { return *self.CreatePassBuilder(); },
-           py::return_value_policy::reference);
+      .def("create_passes_from_srategy",
+           [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
+             return self.CreatePassesFromStrategy();
+           });

   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &,
......
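Declaring std::shared_ptr<ir::PassBuilder> as the holder type in the py::class_ binding above is what allows the lambda to return the shared_ptr from CreatePassesFromStrategy by value, without the earlier py::return_value_policy::reference workaround. A minimal pybind11 sketch of the pattern follows; the module, Builder, and Strategy names are illustrative only, not Paddle's:

    #include <pybind11/pybind11.h>
    #include <memory>

    namespace py = pybind11;

    struct Builder {
      int size() const { return 0; }
    };

    struct Strategy {
      std::shared_ptr<Builder> CreateBuilder() {
        builder_ = std::make_shared<Builder>();
        return builder_;  // ownership shared with the Python side
      }
      std::shared_ptr<Builder> builder_;
    };

    PYBIND11_MODULE(example, m) {
      // With std::shared_ptr<Builder> as the holder, pybind11 keeps the C++
      // object alive for as long as the Python handle exists.
      py::class_<Builder, std::shared_ptr<Builder>>(m, "Builder")
          .def("size", &Builder::size);

      py::class_<Strategy>(m, "Strategy")
          .def(py::init<>())
          .def("create_builder",
               [](Strategy &self) -> std::shared_ptr<Builder> {
                 return self.CreateBuilder();
               });
    }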
@@ -94,16 +94,27 @@ class TestPassBuilder(unittest.TestCase):
     def test_parallel_testing_with_new_strategy(self):
         build_strategy = fluid.BuildStrategy()

-        pass_builder = build_strategy.create_pass_builder()
+        pass_builder = build_strategy.create_passes_from_srategy()
+        origin_len = len(pass_builder.all_passes())
+
         viz_pass = pass_builder.append_pass("graph_viz_pass")
-        all_passes = pass_builder.all_passes()
-        pass_builder.insert_pass(len(all_passes), "graph_viz_pass")
+        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
+
+        pass_builder.insert_pass(
+            len(pass_builder.all_passes()), "graph_viz_pass")
+        self.assertEqual(origin_len + 2, len(pass_builder.all_passes()))
+
         pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
-        viz_pass.set_str("graph_viz_path", "/tmp/viz_pass")
+        self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
+        viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")
+
         self.check_network_convergence(
             use_cuda=core.is_compiled_with_cuda(),
             build_strategy=build_strategy)
+        try:
+            os.stat("/tmp/test_viz_pass")
+        except os.error:
+            self.assertFalse(True)


 if __name__ == '__main__':
......