未验证 提交 66c7c7eb 编写于 作者: R Ruibiao Chen 提交者: GitHub

Allow manually setting py_reader name in standalone executor (#45898)

* Allow manually setting py_reader name in standalone executor

* Fix CI errors
上级 54e1a7cc
...@@ -43,7 +43,6 @@ namespace interpreter { ...@@ -43,7 +43,6 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>; using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2; constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions( const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) { size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
...@@ -281,11 +280,12 @@ void create_all_ops(const framework::BlockDesc& block, ...@@ -281,11 +280,12 @@ void create_all_ops(const framework::BlockDesc& block,
} }
} }
std::tuple<VariableValueMap, VariableIdMap> build_variable_map( std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
const VariableNameMap& var_name_map, const VariableNameMap& var_name_map,
VariableScope* var_scope, VariableScope* var_scope,
Scope* local_scope, Scope* local_scope,
bool enforce_exist = true) { bool allow_var_not_in_program = false,
bool allow_var_not_in_scope = false) {
VariableValueMap name2var; VariableValueMap name2var;
VariableIdMap name2id; VariableIdMap name2id;
for (auto& item : var_name_map) { for (auto& item : var_name_map) {
...@@ -295,14 +295,10 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map( ...@@ -295,14 +295,10 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
for (auto& var_name : item.second) { for (auto& var_name : item.second) {
if (!var_scope->HasVar(var_name)) { if (!var_scope->HasVar(var_name)) {
// Hot fix for variables used in dataloader, like if (allow_var_not_in_program && local_scope->FindVar(var_name)) {
// 'lod_tensor_blocking_queue_0' These variables may be created in VLOG(3) << "Add " << var_name << " to var_scope";
// scope, and it is not existed as variable in program.
if (var_name.find(blocking_queue_prefix) != std::string::npos &&
local_scope->FindVar(var_name)) {
var_scope->AddVar(var_name, nullptr); var_scope->AddVar(var_name, nullptr);
} else if (!enforce_exist) { } else if (allow_var_not_in_scope) {
// skip the non-exist variable: such as recurrent_grad
VLOG(4) << var_name << " don't exist in variable scope, skip it!"; VLOG(4) << var_name << " don't exist in variable scope, skip it!";
continue; continue;
} }
...@@ -449,35 +445,51 @@ void build_op_func_list(const platform::Place& place, ...@@ -449,35 +445,51 @@ void build_op_func_list(const platform::Place& place,
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
auto op = ops[i].get(); auto op = ops[i].get();
VLOG(6) << "Build OpFuncNode from : " << op->Type(); const std::string& op_type = op->Type();
VLOG(6) << "Build OpFuncNode from : " << op_type;
// Print new executor log if grad op is used. // Print new executor log if grad op is used.
// It's only for test and will be removed later. // It's only for test and will be removed later.
if (!flag_log_is_printed && op->Type().find("_grad") != std::string::npos) { if (!flag_log_is_printed && op_type.find("_grad") != std::string::npos) {
VLOG(0) << "Standalone Executor is Used."; VLOG(0) << "Standalone Executor is Used.";
flag_log_is_printed = true; flag_log_is_printed = true;
} }
auto inputs_names = op->Inputs(); // Hot fix for variables used in dataloader, like
auto outputs_names = op->Outputs(); // 'lod_tensor_blocking_queue_0'. These variables may be created in scope,
// and it is not existed as variable in program.
const std::set<std::string> ops_with_var_not_in_program = {
"create_py_reader"};
const std::set<std::string> ops_with_var_not_in_scope = {
"conditional_block",
"conditional_block_grad",
"recurrent_grad",
"rnn_memory_helper",
"rnn_memory_helper_grad",
"while",
"while_grad"};
bool allow_var_not_in_program = ops_with_var_not_in_program.count(op_type);
bool allow_var_not_in_scope = ops_with_var_not_in_scope.count(op_type);
framework::VariableNameMap& input_name_map = op->Inputs();
VariableValueMap ins_map; VariableValueMap ins_map;
VariableIdMap ins_name2id; VariableIdMap ins_name2id;
bool enforce_exist = true; std::tie(ins_map, ins_name2id) = BuildVariableMap(input_name_map,
if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" || var_scope,
op->Type() == "rnn_memory_helper_grad" || local_scope,
op->Type() == "conditional_block" || allow_var_not_in_program,
op->Type() == "conditional_block_grad" || op->Type() == "while" || allow_var_not_in_scope);
op->Type() == "while_grad") {
enforce_exist = false;
}
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);
framework::VariableNameMap& output_name_map = op->Outputs();
VariableValueMap outs_map; VariableValueMap outs_map;
VariableIdMap outs_name2id; VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) = build_variable_map( std::tie(outs_map, outs_name2id) =
outputs_names, var_scope, local_scope, enforce_exist); BuildVariableMap(output_name_map,
var_scope,
local_scope,
/*allow_var_not_in_program=*/false,
allow_var_not_in_scope);
// step 1: build OpFuncNode // step 1: build OpFuncNode
OpFuncNode op_func_node; OpFuncNode op_func_node;
...@@ -634,7 +646,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -634,7 +646,7 @@ void build_op_func_list(const platform::Place& place,
if (framework::IsComplexType(kernel_type.data_type_)) { if (framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node, interpreter::HandleComplexGradToRealGrad(op_func_node,
place, place,
outputs_names, output_name_map,
&runtime_context.outputs, &runtime_context.outputs,
var_scope, var_scope,
vec_func_list, vec_func_list,
...@@ -665,17 +677,17 @@ void build_op_func_list(const platform::Place& place, ...@@ -665,17 +677,17 @@ void build_op_func_list(const platform::Place& place,
} }
} }
} catch (platform::EnforceNotMet& ex) { } catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex); framework::InsertCallStackInfo(op_type, op->Attrs(), &ex);
throw std::move(ex); throw std::move(ex);
} catch (platform::EOFException&) { } catch (platform::EOFException&) {
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
} catch (std::exception& ex) { } catch (std::exception& ex) {
LOG(WARNING) << op->Type() << " raises an exception " LOG(WARNING) << op_type << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", " << platform::demangle(typeid(ex).name()) << ", "
<< ex.what(); << ex.what();
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
} catch (...) { } catch (...) {
LOG(WARNING) << op->Type() << " raises an unknown exception"; LOG(WARNING) << op_type << " raises an unknown exception";
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
} }
......
...@@ -95,12 +95,15 @@ def simple_fc_net(in_size, ...@@ -95,12 +95,15 @@ def simple_fc_net(in_size,
py_reader = fluid.layers.create_py_reader_by_data( py_reader = fluid.layers.create_py_reader_by_data(
capacity=queue_capacity, capacity=queue_capacity,
use_double_buffer=use_double_buffer, use_double_buffer=use_double_buffer,
feed_list=[in_data, label]) feed_list=[in_data, label],
name=unique_name.generate('py_reader_name'))
else: else:
py_reader = fluid.layers.py_reader(capacity=queue_capacity, py_reader = fluid.layers.py_reader(
shapes=[in_data.shape, label.shape], capacity=queue_capacity,
dtypes=['float32', 'int64'], shapes=[in_data.shape, label.shape],
use_double_buffer=use_double_buffer) dtypes=['float32', 'int64'],
name=unique_name.generate('py_reader_name'),
use_double_buffer=use_double_buffer)
in_data, label = fluid.layers.read_file(py_reader) in_data, label = fluid.layers.read_file(py_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册