提交 36c2a9af 编写于 作者: X Xin Pan

pass builder allow cutomize pass in python.

上级 eb1aeb17
......@@ -150,11 +150,10 @@ else()
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph build_strategy
fast_threaded_ssa_graph_executor)
endif() # NOT WIN32
cc_library(prune SRCS prune.cc DEPS framework_proto)
......
......@@ -54,3 +54,8 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
# device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass,
fuse_elewise_add_act_pass)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/details/build_strategy.h"
#include <string>
#include <tuple>
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace details {
class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
// Apply a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
// Apply op fusion.
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass =
ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
graph = fuse_elewise_add_act_pass->Apply(std::move(graph));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
}
// Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
// Apply a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
multi_devices_print_pass->SetNotOwned<const std::string>(
"debug_graphviz_path", &strategy_.debug_graphviz_path_);
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
}
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
}
std::unique_ptr<ir::Graph> Build(
const ProgramDesc &main_program,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes,
#ifdef PADDLE_WITH_CUDA
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
#endif
// Convert the program to graph.
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
for (std::shared_ptr<ir::Pass> &pass : AllPasses()) {
if (pass->Type() == "multi_devices_pass") {
pass->SetNotOwned<const std::vector<platform::Place>>("places",
&places);
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
&param_names);
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
}
graph = pass->Apply(std::move(graph));
}
return graph;
}
private:
BuildStrategy strategy_;
};
ir::PassBuilder *BuildStrategy::CreatePassBuilder() const {
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
return pass_builder_.get();
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes,
#ifdef PADDLE_WITH_CUDA
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
#endif
if (!pass_builder_) {
CreatePassBuilder();
}
// std::unique_ptr<ir::Graph> graph;
ParallelExecutorPassBuilder *builder =
reinterpret_cast<ParallelExecutorPassBuilder *>(pass_builder_.get());
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<ir::Graph> graph =
builder->Build(main_program, places, loss_var_name, param_names,
local_scopes, use_cuda, nccl_ctxs);
#else
std::unique_ptr<ir::Graph> graph = builder->Build(
main_program, places, loss_var_name, param_names, local_scopes, use_cuda);
#endif
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
......@@ -15,11 +15,25 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
class ParallelExecutorPassBuilder;
struct BuildStrategy;
struct BuildStrategy {
// ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and
// kReduce, for CPU and GPU. If you use kAllReduce, different threads
......@@ -57,6 +71,24 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool enable_data_balance_{false};
ir::PassBuilder *CreatePassBuilder() const;
std::unique_ptr<ir::Graph> Apply(
const ProgramDesc &main_program,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes,
#ifdef PADDLE_WITH_CUDA
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const;
#else
const bool use_cuda) const;
#endif
private:
// TODO(panyx0718): This should probably be unique_ptr.
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
};
} // namespace details
......
......@@ -41,6 +41,8 @@ cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
......@@ -19,7 +19,6 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(!applied_, "Pass can only Apply() once.");
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
......
......@@ -42,6 +42,8 @@ class Pass {
attr_dels_.clear();
}
std::string Type() const { return type_; }
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set.
......@@ -68,13 +70,13 @@ class Pass {
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr;
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
LOG(FATAL) << "Calling virtual Pass not implemented.";
}
private:
template <typename PassType>
......@@ -89,7 +91,10 @@ class Pass {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterType(const std::string &type) { type_ = type; }
mutable bool applied_{false};
std::string type_;
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_;
......@@ -143,10 +148,11 @@ struct PassRegistrar : public Registrar {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(
pass_type, [this]() -> std::unique_ptr<Pass> {
pass_type, [this, pass_type]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
pass->RegisterType(pass_type);
return pass;
});
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass_builder.h"
namespace paddle {
namespace framework {
namespace ir {
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
auto pass = ir::PassRegistry::Instance().Get(pass_type);
passes_.emplace_back(pass.release());
return passes_.back();
}
void PassBuilder::RemovePass(size_t idx) {
PADDLE_ENFORCE(passes_.size() > idx);
passes_.erase(passes_.begin() + idx);
}
std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
const std::string& pass_type) {
PADDLE_ENFORCE(passes_.size() >= idx);
std::shared_ptr<Pass> pass(
ir::PassRegistry::Instance().Get(pass_type).release());
passes_.insert(passes_.begin() + idx, std::move(pass));
return passes_[idx];
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class PassBuilder {
public:
PassBuilder() {}
virtual ~PassBuilder() {}
std::shared_ptr<Pass> AppendPass(const std::string& pass_type);
std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);
void RemovePass(size_t idx);
std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }
protected:
std::vector<std::shared_ptr<Pass>> passes_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -19,15 +19,13 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -35,80 +33,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes, const bool use_cuda,
#ifdef PADDLE_WITH_CUDA
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
#else
const BuildStrategy &strategy) {
#endif
// Convert the program to graph.
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
// Apply op fusion.
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass =
ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
graph = fuse_elewise_add_act_pass->Apply(std::move(graph));
// Apply a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
graph = viz_pass->Apply(std::move(graph));
}
}
// Convert graph to run on multi-devices.
auto multi_devices_pass =
ir::PassRegistry::Instance().Get("multi_devices_pass");
multi_devices_pass->SetNotOwned<const std::vector<platform::Place>>("places",
&places);
multi_devices_pass->SetNotOwned<const std::string>("loss_var_name",
&loss_var_name);
multi_devices_pass->SetNotOwned<const std::unordered_set<std::string>>(
"params", &param_names);
multi_devices_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
multi_devices_pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
graph = multi_devices_pass->Apply(std::move(graph));
// Apply a graph print pass to record a graph with device info.
if (!strategy.debug_graphviz_path_.empty()) {
auto multi_devices_print_pass =
ir::PassRegistry::Instance().Get("multi_devices_print_pass");
multi_devices_print_pass->SetNotOwned<const std::string>(
"debug_graphviz_path", &strategy.debug_graphviz_path_);
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
graph = multi_devices_print_pass->Apply(std::move(graph));
}
// Verify that the graph is correct for multi-device executor.
auto multi_devices_check_pass =
ir::PassRegistry::Instance().Get("multi_devices_check_pass");
graph = multi_devices_check_pass->Apply(std::move(graph));
return graph;
}
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
......@@ -199,10 +123,9 @@ ParallelExecutor::ParallelExecutor(
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy,
member_->nccl_ctxs_.get());
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
......@@ -228,9 +151,9 @@ ParallelExecutor::ParallelExecutor(
}
}
#else
std::unique_ptr<ir::Graph> graph = ApplyParallelExecutorPass(
main_program, member_->places_, loss_var_name, params,
member_->local_scopes_, member_->use_cuda_, build_strategy);
std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_);
#endif
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
......@@ -373,12 +296,6 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework
} // namespace paddle
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(reference_count_pass);
#endif
......@@ -14,14 +14,14 @@ limitations under the License. */
#pragma once
#include <paddle/fluid/framework/details/build_strategy.h>
#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method)
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc)
if(NOT WIN32)
list(APPEND PYBIND_DEPS parallel_executor profiler)
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
......@@ -595,6 +596,28 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("set_str", [](ir::Pass &self, const std::string &name,
const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
});
py::class_<ir::PassBuilder> pb(m, "PassBuilder");
pb.def(py::init())
.def("append_pass",
[](ir::PassBuilder &self,
const std::string &pass_type) -> std::shared_ptr<ir::Pass> {
return self.AppendPass(pass_type);
})
.def("all_passes", [](ir::PassBuilder &self) { return self.AllPasses(); })
.def("insert_pass",
[](ir::PassBuilder &self, size_t idx, const std::string &pass_type) {
return self.InsertPass(idx, pass_type);
})
.def("remove_pass",
[](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });
// -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy");
......@@ -677,7 +700,10 @@ All parameter, weight, gradient are variables in Paddle.
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
});
})
.def("create_pass_builder",
[](BuildStrategy &self) { return *self.CreatePassBuilder(); },
py::return_value_policy::reference);
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import unittest
import os
import sys
import math
def simple_fc_net():
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestPassBuilder(unittest.TestCase):
def check_network_convergence(self, use_cuda, build_strategy=None):
os.environ['CPU_NUM'] = str(4)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = simple_fc_net()
test_program = main.clone(for_test=True)
opt = fluid.optimizer.SGD(learning_rate=0.001)
opt.minimize(loss)
batch_size = 32
image = np.random.normal(size=(batch_size, 784)).astype('float32')
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
loss_name=loss.name,
main_program=main,
build_strategy=build_strategy)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=test_program,
share_vars_from=train_exe,
build_strategy=build_strategy)
for i in range(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)):
sys.exit("got NaN loss, testing failed.")
avg_train_loss_val = np.array(train_loss).mean()
if math.isnan(float(avg_train_loss_val)):
sys.exit("got NaN loss, training failed.")
self.assertTrue(
np.allclose(
train_loss, test_loss, atol=1e-8),
"Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss))
def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy()
pass_builder = build_strategy.create_pass_builder()
viz_pass = pass_builder.append_pass("graph_viz_pass")
all_passes = pass_builder.all_passes()
pass_builder.insert_pass(len(all_passes), "graph_viz_pass")
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
viz_pass.set_str("graph_viz_path", "/tmp/viz_pass")
self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),
build_strategy=build_strategy)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册