Unverified commit 79e758c6, authored by Zeng Jinle and committed by GitHub

add fix op run order pass (#34427)

* add fix op run order pass

* add ut for fix_op_run_order

* fix ci error

* improve coverage

* improve coverage again and fix cpu test case

* follow some comments
Parent commit: 9d985ca1
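The option added by this commit is exposed to Python as paddle.static.BuildStrategy.fix_op_run_order (see the proto, pybind and unit-test changes below). A minimal usage sketch mirroring the unit test in this commit; the toy fc network and CPU place are placeholders, not part of the change:

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    paddle.optimizer.SGD(learning_rate=1e-3).minimize(loss)

# Enable the deterministic op run order introduced by this commit.
build_strategy = paddle.static.BuildStrategy()
build_strategy.fix_op_run_order = True
compiled_prog = paddle.static.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name,
    build_strategy=build_strategy,
    places=[paddle.CPUPlace()])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
exe.run(compiled_prog,
        feed={"x": np.random.rand(4, 8).astype("float32")},
        fetch_list=[loss])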
......@@ -134,7 +134,8 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
modify_op_lock_and_record_event_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass)
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass)
if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
......
......@@ -100,6 +100,9 @@ struct BuildStrategy {
// while running.
bool cache_runtime_context_{false};
// Fix the op run order.
bool fix_op_run_order_{false};
// Operator fusion
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
......
......@@ -19,6 +19,7 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include <algorithm>
namespace paddle {
namespace framework {
......@@ -177,6 +178,16 @@ void EagerDeletionOpHandle::ClearGarbages(
#endif
}
std::vector<std::string> EagerDeletionOpHandle::VarsToDelete() const {
std::vector<std::string> var_names;
var_names.reserve(var_infos_.size());
for (auto &info : var_infos_) {
var_names.emplace_back(info->Name());
}
std::sort(var_names.begin(), var_names.end());
return var_names;
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -64,6 +64,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
size_t GetScopeIdx() const { return scope_idx_; }
std::vector<std::string> VarsToDelete() const;
protected:
void RunImpl() override;
......
......@@ -40,9 +40,14 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
places_(places),
graph_(graph),
fetch_ctxs_(places),
pool_(strategy.num_threads_),
// add one more thread for generating op_deps
prepare_pool_(1) {
if (ir::IsTopologySortOperationsUnique(*graph_)) {
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
strategy_.num_threads_ = 1;
}
pool_.reset(new ::ThreadPool(strategy.num_threads_));
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
......@@ -223,7 +228,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_;
this->pool_.enqueue([=] {
this->pool_->enqueue([=] {
std::deque<OpHandleBase *> op_queue;
op_queue.push_front(op);
......
......@@ -60,7 +60,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
atomic_op_deps_;
ExceptionHolder exception_;
::ThreadPool pool_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<OpHandleBase *> traced_ops_;
......
......@@ -111,6 +111,7 @@ message BuildStrategy {
optional bool fuse_bn_add_act_ops = 10 [ default = true ];
optional bool enable_auto_fusion = 11 [ default = false ];
optional bool enable_addto = 12 [ default = false ];
optional bool fix_op_run_order = 13 [ default = false ];
}
message ExecutionStrategy {
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -254,8 +255,15 @@ class CoalesceGradTensorPass : public ir::Pass {
const std::unordered_map<std::string, std::vector<ir::Node *>> &vars_info,
const details::ParamsAndGrads &params_grads,
details::GroupParamsAndGrads *group_params_grads) const {
if (GetFuseParameterMemorySize() == 0) {
group_params_grads->resize(1);
auto &result_param_grads = (*group_params_grads)[0];
result_param_grads = params_grads;
std::sort(result_param_grads.begin(), result_param_grads.end());
} else {
SetGroupAccordingToLayers(vars_info, params_grads, group_params_grads);
SetGroupAccordingToMemorySize(vars_info, group_params_grads);
}
if (!IsUnifiedDtype(params_grads, vars_info)) {
ReGroupByDtype(vars_info, group_params_grads);
}
......
......@@ -143,6 +143,32 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
return ret;
}
bool IsTopologySortOperationsUnique(const Graph &graph) {
auto nodes = TopologySortOperations(graph);
size_t n = nodes.size();
for (size_t i = 1; i < n; ++i) {
auto *prev_op = nodes[i - 1];
auto *cur_op = nodes[i];
std::unordered_set<Node *> prev_op_outputs;
for (auto *output : prev_op->outputs) {
prev_op_outputs.insert(output);
}
bool found = false;
for (auto *input : cur_op->inputs) {
if (prev_op_outputs.count(input) > 0) {
found = true;
break;
}
}
if (!found) {
return false;
}
}
return true;
}
// Build operator inlink edge table.
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
BuildOperationAdjList(const Graph &graph) {
......
......@@ -57,6 +57,9 @@ size_t GraphNum(const Graph &graph);
// `graph` cannot contain circle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
// Check whether the topological order of graph ops is unique
bool IsTopologySortOperationsUnique(const Graph &graph);
// Topological sort, but try to DFS.
std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph);
......
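The check above uses a standard DAG property: the topological order is unique exactly when every pair of consecutive ops in one valid order is connected by a direct edge (i.e. the DAG contains a Hamiltonian path). A small self-contained Python sketch of the same idea, using a plain adjacency-list graph rather than Paddle's ir::Graph:

from graphlib import TopologicalSorter  # Python 3.9+

def is_topo_order_unique(adj):
    """adj maps each node to the set of nodes it points to."""
    # graphlib expects predecessor sets, so invert the adjacency list.
    preds = {node: set() for node in adj}
    for src, dsts in adj.items():
        for dst in dsts:
            preds.setdefault(dst, set()).add(src)
    order = list(TopologicalSorter(preds).static_order())
    # Unique iff each consecutive pair is joined by a direct edge.
    return all(order[i + 1] in adj.get(order[i], ())
               for i in range(len(order) - 1))

# a -> b -> c has a unique order; an unconnected node makes it ambiguous.
assert is_topo_order_unique({"a": {"b"}, "b": {"c"}, "c": set()})
assert not is_topo_order_unique({"a": {"b"}, "b": set(), "c": set()})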
......@@ -18,3 +18,4 @@ cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph gr
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
cc_library(backward_optimizer_op_deps_pass SRCS backward_optimizer_op_deps_pass.cc DEPS graph graph_helper pass)
cc_library(add_reader_dependency_pass SRCS add_reader_dependency_pass.cc DEPS graph graph_helper pass)
cc_library(fix_op_run_order_pass SRCS fix_op_run_order_pass.cc DEPS graph graph_helper multi_devices_helper pass op_handle_base eager_deletion_op_handle)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
namespace ir {
static std::string kSep(1, static_cast<char>(1)); // NOLINT
// NOTE: VariableNameMap is sorted!
static std::string VarNameMapToString(const VariableNameMap &var_map) {
std::vector<std::string> tmp_strs;
tmp_strs.reserve(var_map.size());
for (auto &pair : var_map) {
auto str = pair.first + kSep + string::join_strings(pair.second, kSep);
tmp_strs.emplace_back(std::move(str));
}
return string::join_strings(tmp_strs, kSep);
}
static std::string OpDescToString(const OpDesc &op) {
return "OpDesc" + kSep + op.Type() + kSep + VarNameMapToString(op.Inputs()) +
kSep + VarNameMapToString(op.Outputs());
}
static std::string VarHandleListToString(
const std::vector<details::VarHandleBase *> &vars) {
std::vector<std::string> valid_vars;
valid_vars.reserve(vars.size());
for (auto *v : vars) {
auto *valid_var = dynamic_cast<details::VarHandle *>(v);
if (valid_var != nullptr) {
valid_vars.emplace_back(valid_var->Name());
}
}
std::sort(valid_vars.begin(), valid_vars.end());
return string::join_strings(valid_vars, kSep);
}
static std::string EagerDeletionOpHandleToString(
const details::EagerDeletionOpHandle &op);
static std::string OpHandleToString(const details::OpHandleBase &op);
static std::string EagerDeletionOpHandleToString(
const details::EagerDeletionOpHandle &op) {
auto vars_to_delete = op.VarsToDelete();
std::unordered_set<details::OpHandleBase *> prev_ops;
std::vector<std::string> prev_op_strs;
prev_op_strs.reserve(op.Inputs().size());
for (auto *var : op.Inputs()) {
auto *prev_op = var->GeneratedOp();
if (prev_op == nullptr) continue;
prev_op_strs.push_back(OpHandleToString(*prev_op));
}
std::sort(prev_op_strs.begin(), prev_op_strs.end());
// NOTE: gc op does not have any valid input/output vars
return "OpHandleBase" + kSep + op.Name() + kSep +
string::join_strings(vars_to_delete, kSep) + kSep +
string::join_strings(prev_op_strs, kSep);
}
static std::string OpHandleToString(const details::OpHandleBase &op) {
// NOTE: gc op does not have any valid input/output vars
auto gc_op = dynamic_cast<const details::EagerDeletionOpHandle *>(&op);
if (gc_op) {
return EagerDeletionOpHandleToString(*gc_op);
}
return "OpHandleBase" + kSep + op.Name() + kSep +
VarHandleListToString(op.Inputs()) + kSep +
VarHandleListToString(op.Outputs());
}
static void AddSequentialDepsForSortedOps(
Graph *graph, const std::vector<details::OpHandleBase *> &sorted_ops) {
size_t n = sorted_ops.size();
for (size_t i = 1; i < n; ++i) {
auto *prev_op = sorted_ops[i - 1];
auto *cur_op = sorted_ops[i];
auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars).emplace(dep_var);
prev_op->AddOutput(dep_var);
cur_op->AddInput(dep_var);
}
}
class FixOpRunOrderPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
const auto &program = graph->OriginProgram();
std::unordered_map<std::string, size_t> op_to_idx;
size_t i = 0;
for (auto *op_desc : program.Block(0).AllOps()) {
auto op_desc_str = OpDescToString(*op_desc);
PADDLE_ENFORCE_EQ(
op_to_idx.emplace(op_desc_str, i).second, true,
platform::errors::PermissionDenied(
"FixOpRunOrderPass cannot handle OpDesc with same "
"type, inputs and outputs yet, error string repr: %s",
op_desc_str));
++i;
}
// a map to record: "Node" -> "Node Index"
std::unordered_map<Node *, size_t> node_to_idx;
// a map to record found "Node Index"
std::unordered_set<size_t> found_node_indices;
// a map to record the new OpDesc created by other Passes. These ops do
// not exist in the origin program
std::map<std::string, Node *> new_op_desc_nodes;
// a map to record the new OpHandle created by other Passes. These ops do
// not have an OpDesc and do not exist in the origin program
std::map<std::string, Node *> new_op_handle_nodes;
// Step 1: handle the unchanged OpDesc, and record new OpDesc/OpHandle
auto op_handles = FilterByNodeWrapper<details::OpHandleBase>(*graph);
for (auto *op_handle : op_handles) {
auto *node = op_handle->Node();
if (node->Op() == nullptr) {
auto node_str = OpHandleToString(*op_handle);
PADDLE_ENFORCE_EQ(new_op_handle_nodes.emplace(node_str, node).second,
true,
platform::errors::PermissionDenied(
"FixOpRunOrderPass cannot OpHandle with same "
"inputs and outputs yet, error repr: %s",
node_str));
continue;
}
auto node_str = OpDescToString(*(node->Op()));
auto iter = op_to_idx.find(node_str);
if (iter != op_to_idx.end()) {
size_t idx = iter->second;
PADDLE_ENFORCE_EQ(
found_node_indices.count(idx), 0,
platform::errors::PermissionDenied(
"FixOpRunOrderPass cannot handle OpDesc with same "
"type, inputs and outputs yet, error repr: %s",
node_str));
found_node_indices.insert(idx);
node_to_idx[node] = idx;
} else {
PADDLE_ENFORCE_EQ(
new_op_desc_nodes.emplace(node_str, node).second, true,
platform::errors::PermissionDenied(
"FixOpRunOrderPass cannot handle OpDesc with same "
"type, inputs and outputs yet, error repr: %s",
node_str));
}
}
VLOG(10) << "Found unchanged OpDesc " << node_to_idx.size()
<< ", new OpDesc " << new_op_desc_nodes.size() << ", new OpHandle "
<< new_op_handle_nodes.size();
// Step 2: assign node index to new OpDesc
size_t node_id_offset = op_to_idx.size();
for (auto &pair : new_op_desc_nodes) {
node_to_idx[pair.second] = node_id_offset;
++node_id_offset;
}
// Step 3: assign node index to new OpHandle
for (auto &pair : new_op_handle_nodes) {
node_to_idx[pair.second] = node_id_offset;
++node_id_offset;
}
// Step 4: sort unchanged OpDesc/new OpDesc/new OpHandle by topological
// order and node index
OpGraphView graph_view(op_handles);
auto comp = [&node_to_idx](details::OpHandleBase *op1,
details::OpHandleBase *op2) {
auto priority1 = static_cast<int>(op1->GetPriority());
auto priority2 = static_cast<int>(op2->GetPriority());
if (priority1 != priority2) {
return priority1 < priority2;
}
return node_to_idx.at(op1->Node()) < node_to_idx.at(op2->Node());
};
std::vector<details::OpHandleBase *> sorted_ops;
sorted_ops.reserve(op_handles.size());
std::queue<details::OpHandleBase *> q;
std::vector<details::OpHandleBase *> tmp_ops;
auto op_deps = graph_view.GetPrecedingDepNum();
// Get ready ops first
for (auto iter = op_deps.begin(); iter != op_deps.end();) {
if (iter->second != 0) {
++iter;
continue;
}
tmp_ops.push_back(iter->first);
op_deps.erase(iter++);
}
// Sort ready ops by node index
std::sort(tmp_ops.begin(), tmp_ops.end(), comp);
for (auto *op : tmp_ops) {
q.push(op);
}
while (!q.empty()) {
auto *cur_op = q.front();
q.pop();
sorted_ops.push_back(cur_op);
auto &pending_ops = graph_view.PendingOps(cur_op);
tmp_ops.clear();
for (auto *pending_op : pending_ops) {
if (--op_deps.at(pending_op) == 0) {
op_deps.erase(pending_op);
tmp_ops.push_back(pending_op);
}
}
// sort next ready ops by node index
std::sort(tmp_ops.begin(), tmp_ops.end(), comp);
for (auto *op : tmp_ops) {
q.push(op);
}
}
PADDLE_ENFORCE_EQ(
sorted_ops.size(), op_handles.size(),
platform::errors::PermissionDenied("There are unvisited ops"));
if (VLOG_IS_ON(10)) {
// print op order to debug
std::vector<size_t> sorted_ops_indices;
sorted_ops_indices.reserve(sorted_ops.size());
for (auto *op : sorted_ops) {
sorted_ops_indices.push_back(node_to_idx.at(op->Node()));
}
VLOG(10) << "Fix op order: "
<< string::join_strings(sorted_ops_indices, ',');
}
// Step 5: add sequential deps for ops to guarantee there is only one
// toposort order
AddSequentialDepsForSortedOps(graph, sorted_ops);
PADDLE_ENFORCE_EQ(IsTopologySortOperationsUnique(*graph), true,
platform::errors::PermissionDenied(
"The topological order must be unique "
"after FixOpRunOrderPass is applied"));
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fix_op_run_order_pass, paddle::framework::ir::FixOpRunOrderPass);
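In summary, the pass assigns every op handle a stable index (program order for unchanged OpDescs, then lexicographically sorted string keys for ops introduced by earlier passes), runs a topological sort that releases ready ops ordered by priority and then by that index, and finally chains the sorted ops with control-dependency variables so that only one topological order remains. A minimal Python sketch of the deterministic topological sort, with a toy edge list standing in for OpGraphView (op names and indices are illustrative only):

from collections import deque

def deterministic_topo_sort(ops, edges, index):
    """ops: op ids; edges: (src, dst) dependencies; index: stable per-op
    rank (here, position in the original program; the real pass compares
    op priority first and falls back to this index)."""
    dep_count = {op: 0 for op in ops}
    successors = {op: [] for op in ops}
    for src, dst in edges:
        dep_count[dst] += 1
        successors[src].append(dst)

    # Seed the queue with ready ops in stable-index order.
    queue = deque(sorted((op for op, d in dep_count.items() if d == 0),
                         key=index.get))
    order = []
    while queue:
        cur = queue.popleft()
        order.append(cur)
        released = []
        for nxt in successors[cur]:
            dep_count[nxt] -= 1
            if dep_count[nxt] == 0:
                released.append(nxt)
        # Newly ready ops are also enqueued in stable-index order.
        queue.extend(sorted(released, key=index.get))
    assert len(order) == len(dep_count), "graph has a cycle"
    return order

# Two independent branches are always scheduled the same way.
ops = ["read", "fc", "relu", "loss"]
edges = [("read", "fc"), ("read", "relu"), ("fc", "loss"), ("relu", "loss")]
print(deterministic_topo_sort(ops, edges, {op: i for i, op in enumerate(ops)}))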
......@@ -104,6 +104,13 @@ class ParallelExecutorPrivate {
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
void ApplyFixOpRunOrderPass(ir::Graph *graph) {
if (build_strategy_.fix_op_run_order_) {
auto pass = ir::PassRegistry::Instance().Get("fix_op_run_order_pass");
pass->Apply(graph);
}
}
/**
* NOTE(zengjinle): the fed variables of users should not be reused,
* because users may feed them into another network. Changing the fed
......@@ -1462,6 +1469,10 @@ std::vector<ir::Graph *> ParallelExecutor::CreateSSAGraphExecutor(
auto possible_inference_graphs =
details::TrySeparateToMultipleSingleDeviceGraphs(graph);
if (!possible_inference_graphs.empty()) {
for (auto &g : possible_inference_graphs) {
member_->ApplyFixOpRunOrderPass(g.get());
}
VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase";
auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
......@@ -1474,6 +1485,9 @@ std::vector<ir::Graph *> ParallelExecutor::CreateSSAGraphExecutor(
member_->executor_.reset(pg_exe);
member_->inference_executor_ = pg_exe;
} else {
if (member_->places_.size() == 1) {
member_->ApplyFixOpRunOrderPass(graph);
}
LOG_IF(WARNING, details::HasKeepLastReadOp(*graph))
<< "drop_last=False for DataLoader is not supported in training "
"network. It is automatically turned to drop_last=True.";
......@@ -1560,3 +1574,4 @@ USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(inplace_addto_op_pass);
USE_PASS(fix_op_run_order_pass);
......@@ -29,6 +29,10 @@ NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
NCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2304
NCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2703
NCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP)
#endif
......
......@@ -64,6 +64,11 @@ NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
NCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2304
#define NCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion);
NCCL_RAND_ROUTINE_EACH_AFTER_2304(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif
#if NCCL_VERSION_CODE >= 2703
#define NCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
__macro(ncclSend); \
......
......@@ -467,6 +467,19 @@ static void AssertStaticGraphAndDygraphGradMakerNoDiff() {
string::join_strings(ops, ',')));
}
#ifdef PADDLE_WITH_NCCL
static int GetNCCLVersion() {
#if NCCL_VERSION_CODE >= 2304
int ver;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetVersion(&ver));
return ver;
#else
PADDLE_THROW(platform::errors::External(
"Cannot get NCCL version successfully when nccl version < 2.3.4"));
#endif
}
#endif
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
......@@ -496,6 +509,14 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("cudnn_version", &platform::CudnnVersion);
#endif
#ifdef PADDLE_WITH_NCCL
m.def("nccl_version", &GetNCCLVersion);
#endif
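A hedged usage sketch of the new nccl_version binding, assuming it is reachable through paddle.fluid.core like the other bindings in this module and is only meaningful in CUDA builds linked against NCCL >= 2.3.4:

import paddle
from paddle.fluid import core

# Assumption: the binding is exposed on the core module, like cudnn_version.
if paddle.is_compiled_with_cuda():
    print(core.nccl_version())  # version code, e.g. 2708 for NCCL 2.7.8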
m.def("wait_device", [](const platform::Place &place) {
platform::DeviceContextPool::Instance().Get(place)->Wait();
});
m.def("from_dlpack", [](py::capsule *dltensor) {
DLManagedTensor *dmt = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(dltensor->ptr(), "dltensor"));
......@@ -1796,17 +1817,17 @@ All parameter, weight, gradient are variables in Paddle.
.def("__str__", string::to_string<const platform::Place &>);
py::class_<OperatorBase>(m, "Operator")
.def_static(
"create",
.def_static("create",
[](py::bytes protobin) {
proto::OpDesc desc;
PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true,
PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin),
true,
platform::errors::InvalidArgument(
"Cannot parse user input to OpDesc"));
PADDLE_ENFORCE_EQ(
desc.IsInitialized(), true,
PADDLE_ENFORCE_EQ(desc.IsInitialized(), true,
platform::errors::InvalidArgument(
"The provided OpDesc is not initialized, the reason is: %s",
"The provided OpDesc is not "
"initialized, the reason is: %s",
desc.InitializationErrorString()));
return OpRegistry::CreateOp(desc);
})
......@@ -2928,8 +2949,8 @@ All parameter, weight, gradient are variables in Paddle.
self.memory_optimize_ = (py_obj == Py_True);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"BuildStrategy.memory_optimize must be set to None, False or "
"True"));
"BuildStrategy.memory_optimize must be set to None, False "
"or True"));
}
},
R"DOC((bool, optional): memory opitimize aims to save total memory
......@@ -3003,6 +3024,12 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &mkldnn_enabled_op_types) {
self.mkldnn_enabled_op_types_ = mkldnn_enabled_op_types;
})
.def_property(
"fix_op_run_order",
[](const BuildStrategy &self) { return self.fix_op_run_order_; },
[](BuildStrategy &self, bool fix_op_run_order) {
self.fix_op_run_order_ = fix_op_run_order;
})
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true);
......
......@@ -38,7 +38,8 @@ void format_string_append(std::string& str, const char* fmt, // NOLINT
CHECK_GE(len, 0);
size_t oldlen = str.length();
str.resize(oldlen + len + 1);
CHECK(snprintf(&str[oldlen], (size_t)len + 1, fmt, args...) == len);
CHECK(snprintf(&str[oldlen], (size_t)len + 1, fmt, args...) == // NOLINT
len);
str.resize(oldlen + len);
}
......@@ -127,7 +128,24 @@ template <class Container>
std::string join_strings(const Container& strs, char delim) {
std::string str;
int i = 0;
size_t i = 0;
for (auto& elem : strs) {
if (i > 0) {
str += delim;
}
str += boost::lexical_cast<std::string>(elem);
++i;
}
return str;
}
template <class Container>
std::string join_strings(const Container& strs, const std::string& delim) {
std::string str;
size_t i = 0;
for (auto& elem : strs) {
if (i > 0) {
str += delim;
......
......@@ -688,6 +688,7 @@ add_subdirectory(ir)
if (WITH_TESTING)
set_property(TEST test_parallel_executor_mnist PROPERTY ENVIRONMENT GLOG_vmodule=all_reduce_deps_pass=10)
set_property(TEST test_parallel_executor_fix_op_run_order PROPERTY ENVIRONMENT GLOG_vmodule=fix_op_run_order_pass=10)
endif()
set_tests_properties(test_parallel_executor_test_while_train test_parallel_executor_mnist
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
from paddle.vision.models import resnet50
from paddle.nn import CrossEntropyLoss
class TestFixOpRunOrder(unittest.TestCase):
def setUp(self):
paddle.enable_static()
paddle.seed(1)
paddle.framework.random._manual_program_seed(1)
if paddle.is_compiled_with_cuda():
fluid.set_flags({'FLAGS_cudnn_deterministic': 1})
def get_place(self):
return paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
def get_feed(self):
batch_size = 32
image = np.random.random([batch_size, 3, 224, 224]).astype('float32')
label = np.random.randint(0, 1000, [batch_size, 1]).astype('int64')
return {"image": image, "label": label}
def create_model(self, fix_op_run_order):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
scope = paddle.static.Scope()
with paddle.static.program_guard(main_prog, startup_prog):
image = paddle.static.data(
name="image", shape=[None, 3, 224, 224], dtype="float32")
label = paddle.static.data(
name="label", shape=[None, 1], dtype="int64")
model = resnet50()
pred = model(image)
loss_fn = CrossEntropyLoss()
loss = loss_fn(pred, label)
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
build_strategy = paddle.static.BuildStrategy()
build_strategy.fix_op_run_order = fix_op_run_order
build_strategy.fuse_bn_act_ops = True
build_strategy.fuse_bn_add_act_ops = True
main_prog = paddle.static.CompiledProgram(main_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
places=[self.get_place()])
exe = paddle.static.Executor(self.get_place())
with paddle.static.scope_guard(scope):
exe.run(startup_prog)
return main_prog, scope, loss
def run_and_fetch_loss(self, main_prog, scope, loss, feed):
with paddle.static.scope_guard(scope):
exe = paddle.static.Executor(self.get_place())
loss_value = exe.run(main_prog, feed=feed, fetch_list=[loss])[0]
return loss_value
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
main1, scope1, loss1 = self.create_model(True)
main2, scope2, loss2 = self.create_model(False)
for i in range(10):
feed = self.get_feed()
loss_val1 = self.run_and_fetch_loss(main1, scope1, loss1, feed)
loss_val2 = self.run_and_fetch_loss(main2, scope2, loss2, feed)
self.assertEqual(loss_val1, loss_val2)
if __name__ == "__main__":
unittest.main()