未验证 提交 66c7c7eb 编写于 作者: R Ruibiao Chen 提交者: GitHub

Allow manaully set py_reader name in standalone executor (#45898)

* Allow manaully set py_reader name in standalone executor

* Fix CI errors
上级 54e1a7cc
......@@ -43,7 +43,6 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
......@@ -281,11 +280,12 @@ void create_all_ops(const framework::BlockDesc& block,
}
}
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
const VariableNameMap& var_name_map,
VariableScope* var_scope,
Scope* local_scope,
bool enforce_exist = true) {
bool allow_var_not_in_program = false,
bool allow_var_not_in_scope = false) {
VariableValueMap name2var;
VariableIdMap name2id;
for (auto& item : var_name_map) {
......@@ -295,14 +295,10 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
for (auto& var_name : item.second) {
if (!var_scope->HasVar(var_name)) {
// Hot fix for variables used in dataloader, like
// 'lod_tensor_blocking_queue_0' These variables may be created in
// scope, and it is not existed as variable in program.
if (var_name.find(blocking_queue_prefix) != std::string::npos &&
local_scope->FindVar(var_name)) {
if (allow_var_not_in_program && local_scope->FindVar(var_name)) {
VLOG(3) << "Add " << var_name << " to var_scope";
var_scope->AddVar(var_name, nullptr);
} else if (!enforce_exist) {
// skip the non-exist variable: such as recurrent_grad
} else if (allow_var_not_in_scope) {
VLOG(4) << var_name << " don't exist in variable scope, skip it!";
continue;
}
......@@ -449,35 +445,51 @@ void build_op_func_list(const platform::Place& place,
for (size_t i = 0; i < ops.size(); ++i) {
auto op = ops[i].get();
VLOG(6) << "Build OpFuncNode from : " << op->Type();
const std::string& op_type = op->Type();
VLOG(6) << "Build OpFuncNode from : " << op_type;
// Print new executor log if grad op is used.
// It's only for test and will be removed later.
if (!flag_log_is_printed && op->Type().find("_grad") != std::string::npos) {
if (!flag_log_is_printed && op_type.find("_grad") != std::string::npos) {
VLOG(0) << "Standalone Executor is Used.";
flag_log_is_printed = true;
}
auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs();
// Hot fix for variables used in dataloader, like
// 'lod_tensor_blocking_queue_0'. These variables may be created in scope,
// and it is not existed as variable in program.
const std::set<std::string> ops_with_var_not_in_program = {
"create_py_reader"};
const std::set<std::string> ops_with_var_not_in_scope = {
"conditional_block",
"conditional_block_grad",
"recurrent_grad",
"rnn_memory_helper",
"rnn_memory_helper_grad",
"while",
"while_grad"};
bool allow_var_not_in_program = ops_with_var_not_in_program.count(op_type);
bool allow_var_not_in_scope = ops_with_var_not_in_scope.count(op_type);
framework::VariableNameMap& input_name_map = op->Inputs();
VariableValueMap ins_map;
VariableIdMap ins_name2id;
bool enforce_exist = true;
if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
op->Type() == "rnn_memory_helper_grad" ||
op->Type() == "conditional_block" ||
op->Type() == "conditional_block_grad" || op->Type() == "while" ||
op->Type() == "while_grad") {
enforce_exist = false;
}
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);
std::tie(ins_map, ins_name2id) = BuildVariableMap(input_name_map,
var_scope,
local_scope,
allow_var_not_in_program,
allow_var_not_in_scope);
framework::VariableNameMap& output_name_map = op->Outputs();
VariableValueMap outs_map;
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) = build_variable_map(
outputs_names, var_scope, local_scope, enforce_exist);
std::tie(outs_map, outs_name2id) =
BuildVariableMap(output_name_map,
var_scope,
local_scope,
/*allow_var_not_in_program=*/false,
allow_var_not_in_scope);
// step 1: build OpFuncNode
OpFuncNode op_func_node;
......@@ -634,7 +646,7 @@ void build_op_func_list(const platform::Place& place,
if (framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
outputs_names,
output_name_map,
&runtime_context.outputs,
var_scope,
vec_func_list,
......@@ -665,17 +677,17 @@ void build_op_func_list(const platform::Place& place,
}
}
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
framework::InsertCallStackInfo(op_type, op->Attrs(), &ex);
throw std::move(ex);
} catch (platform::EOFException&) {
std::rethrow_exception(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << op->Type() << " raises an exception "
LOG(WARNING) << op_type << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", "
<< ex.what();
std::rethrow_exception(std::current_exception());
} catch (...) {
LOG(WARNING) << op->Type() << " raises an unknown exception";
LOG(WARNING) << op_type << " raises an unknown exception";
std::rethrow_exception(std::current_exception());
}
......
......@@ -95,11 +95,14 @@ def simple_fc_net(in_size,
py_reader = fluid.layers.create_py_reader_by_data(
capacity=queue_capacity,
use_double_buffer=use_double_buffer,
feed_list=[in_data, label])
feed_list=[in_data, label],
name=unique_name.generate('py_reader_name'))
else:
py_reader = fluid.layers.py_reader(capacity=queue_capacity,
py_reader = fluid.layers.py_reader(
capacity=queue_capacity,
shapes=[in_data.shape, label.shape],
dtypes=['float32', 'int64'],
name=unique_name.generate('py_reader_name'),
use_double_buffer=use_double_buffer)
in_data, label = fluid.layers.read_file(py_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册