提交 ba06a75f 编写于 作者: J Jay Shi 提交者: TensorFlower Gardener

[tf.data] Add unit test to test the environment variable settings in...

[tf.data] Add unit test to test the environment variable settings in `ObtainOptimizations` function.

PiperOrigin-RevId: 325903283
Change-Id: I45742135cf4afe7c45c29afd9585b0445c07dd0a
上级 1e336d3e
......@@ -906,13 +906,38 @@ bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) {
}
std::vector<tstring> SelectOptimizations(
const string& job_name, const string& opt_ins_raw,
const string& opt_outs_raw,
const string& job_name,
const absl::flat_hash_map<string, uint64>& live_experiments,
const std::vector<tstring>& optimizations_enabled,
const std::vector<tstring>& optimizations_disabled,
const std::vector<tstring>& optimizations_default,
std::function<uint64(const string&)> hash_func) {
std::vector<tstring> optimizations;
if (job_name.empty()) {
// If `job_name` is empty, apply the enabled and default optimizations
// directly.
optimizations.insert(optimizations.end(), optimizations_enabled.begin(),
optimizations_enabled.end());
optimizations.insert(optimizations.end(), optimizations_default.begin(),
optimizations_default.end());
return optimizations;
}
// If `job_name` is non-empty, we determine which optimizations to apply to
// this job based on the enable/disable settings from tf.data.Options, the
// opt in/out settings from environment variables, and rollout condition from
// `live_experiments`.
const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
string opt_ins_raw;
if (opt_ins_raw_cs != nullptr) {
opt_ins_raw = string(opt_ins_raw_cs);
}
string opt_outs_raw;
if (opt_outs_raw_cs != nullptr) {
opt_outs_raw = string(opt_outs_raw_cs);
}
// Creates a set of optimizations.
absl::flat_hash_set<tstring> optimizations_set;
......@@ -1018,7 +1043,6 @@ std::vector<tstring> SelectOptimizations(
}
}
std::vector<tstring> optimizations;
optimizations.insert(optimizations.end(), optimizations_set.begin(),
optimizations_set.end());
return optimizations;
......
......@@ -304,12 +304,11 @@ class DummyResourceOp : public OpKernel {
// MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match);
// Based on `optimizations_enabled`, `optimizations_disabled`, and
// `optimizations_disabled`, returns the list of optimizations that will be
// Based on `job_name`, `optimizations_enabled`, `optimizations_disabled` and
// `optimizations_default`, returns the list of optimizations that will be
// applied.
std::vector<tstring> SelectOptimizations(
const string& job_name, const string& opt_ins_raw,
const string& opt_outs_raw,
const string& job_name,
const absl::flat_hash_map<string, uint64>& live_experiments,
const std::vector<tstring>& optimizations_enabled,
const std::vector<tstring>& optimizations_disabled,
......
......@@ -1138,18 +1138,15 @@ class SelectOptimizationsHashTest : public ::testing::TestWithParam<uint64> {};
TEST_P(SelectOptimizationsHashTest, DatasetUtils) {
const uint64 hash_result = GetParam();
string job_name = "job";
const string opt_ins_raw = "";
const string opt_outs_raw = "";
auto hash_func = [hash_result](const string& str) { return hash_result; };
absl::flat_hash_map<string, uint64> live_experiments = {
{"exp1", 0}, {"exp2", 20}, {"exp3", 33}, {"exp4", 45},
{"exp5", 67}, {"exp6", 88}, {"exp7", 100}};
std::vector<tstring> optimizations_enabled, optimizations_disabled,
optimizations_default;
std::vector<tstring> optimizations =
SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
std::vector<tstring> optimizations = SelectOptimizations(
job_name, live_experiments, optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
int tested_times = 0;
switch (hash_result) {
......@@ -1182,48 +1179,60 @@ class SelectOptimizationsOptTest
: public ::testing::TestWithParam<std::tuple<string, string>> {};
TEST_P(SelectOptimizationsOptTest, DatasetUtils) {
const string opt_ins = std::get<0>(GetParam());
const string opt_outs = std::get<1>(GetParam());
if (!opt_ins.empty()) {
setenv("TF_DATA_EXPERIMENT_OPT_IN", opt_ins.c_str(), 1);
}
if (!opt_outs.empty()) {
setenv("TF_DATA_EXPERIMENT_OPT_OUT", opt_outs.c_str(), 1);
}
string job_name = "job";
const string opt_ins_raw = std::get<0>(GetParam());
const string opt_outs_raw = std::get<1>(GetParam());
auto hash_func = [](const string& str) { return 50; };
absl::flat_hash_map<string, uint64> live_experiments = {
{"exp1", 0}, {"exp2", 25}, {"exp3", 50}, {"exp4", 75}, {"exp5", 100}};
std::vector<tstring> optimizations_enabled, optimizations_disabled,
optimizations_default;
std::vector<tstring> optimizations =
SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
std::vector<tstring> optimizations = SelectOptimizations(
job_name, live_experiments, optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
int tested_times = 0;
if (opt_outs_raw == "all") {
if (opt_outs == "all") {
EXPECT_THAT(optimizations, UnorderedElementsAre());
tested_times++;
} else if (opt_outs_raw.empty()) {
if (opt_ins_raw == "all") {
} else if (opt_outs.empty()) {
if (opt_ins == "all") {
EXPECT_THAT(optimizations,
UnorderedElementsAre("exp1", "exp2", "exp3", "exp4", "exp5"));
tested_times++;
} else if (opt_ins_raw.empty()) {
} else if (opt_ins.empty()) {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp4", "exp5"));
tested_times++;
} else if (opt_ins_raw == "exp2,exp4") {
} else if (opt_ins == "exp2,exp4") {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp4", "exp5"));
tested_times++;
}
} else if (opt_outs_raw == "exp1,exp5") {
if (opt_ins_raw == "all") {
} else if (opt_outs == "exp1,exp5") {
if (opt_ins == "all") {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp3", "exp4"));
tested_times++;
} else if (opt_ins_raw.empty()) {
} else if (opt_ins.empty()) {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp4"));
tested_times++;
} else if (opt_ins_raw == "exp2,exp4") {
} else if (opt_ins == "exp2,exp4") {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp4"));
tested_times++;
}
}
EXPECT_EQ(tested_times, 1);
if (!opt_ins.empty()) {
unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
}
if (!opt_outs.empty()) {
unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
}
}
INSTANTIATE_TEST_SUITE_P(
......@@ -1235,10 +1244,16 @@ class SelectOptimizationsConflictTest
: public ::testing::TestWithParam<std::tuple<string, string, uint64>> {};
TEST_P(SelectOptimizationsConflictTest, DatasetUtils) {
string job_name = "job";
const string opt_ins_raw = std::get<0>(GetParam());
const string opt_outs_raw = std::get<1>(GetParam());
const string opt_ins = std::get<0>(GetParam());
const string opt_outs = std::get<1>(GetParam());
const uint64 hash_result = std::get<2>(GetParam());
if (!opt_ins.empty()) {
setenv("TF_DATA_EXPERIMENT_OPT_IN", opt_ins.c_str(), 1);
}
if (!opt_outs.empty()) {
setenv("TF_DATA_EXPERIMENT_OPT_OUT", opt_outs.c_str(), 1);
}
string job_name = "job";
auto hash_func = [hash_result](const string& str) { return hash_result; };
absl::flat_hash_map<string, uint64> live_experiments = {
{"exp1", 20}, {"exp2", 30}, {"exp3", 40},
......@@ -1246,21 +1261,27 @@ TEST_P(SelectOptimizationsConflictTest, DatasetUtils) {
std::vector<tstring> optimizations_enabled = {"exp1", "exp4"},
optimizations_disabled = {"exp2", "exp5"},
optimizations_default = {"exp3", "exp6"};
std::vector<tstring> optimizations =
SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
std::vector<tstring> optimizations = SelectOptimizations(
job_name, live_experiments, optimizations_enabled, optimizations_disabled,
optimizations_default, hash_func);
int tested_times = 0;
if (opt_outs_raw.empty()) {
if (opt_outs.empty()) {
EXPECT_THAT(optimizations,
UnorderedElementsAre("exp1", "exp3", "exp4", "exp6"));
tested_times++;
} else if (opt_outs_raw == "exp1,exp3") {
} else if (opt_outs == "exp1,exp3") {
EXPECT_THAT(optimizations, UnorderedElementsAre("exp1", "exp4", "exp6"));
tested_times++;
}
EXPECT_EQ(tested_times, 1);
if (!opt_ins.empty()) {
unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
}
if (!opt_outs.empty()) {
unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
}
}
INSTANTIATE_TEST_SUITE_P(Test, SelectOptimizationsConflictTest,
......@@ -1268,6 +1289,66 @@ INSTANTIATE_TEST_SUITE_P(Test, SelectOptimizationsConflictTest,
::testing::Values("", "exp1,exp3"),
::testing::Values(10, 50, 90)));
// Fixture parameterized on (job name, value for TF_DATA_EXPERIMENT_OPT_IN,
// value for TF_DATA_EXPERIMENT_OPT_OUT). An empty string for either of the
// last two means "the environment variable is not set".
class SelectOptimizationsJobTest
    : public ::testing::TestWithParam<std::tuple<string, string, string>> {};

// Checks the result of `SelectOptimizations` for an empty and a non-empty job
// name, combined with opt-in / opt-out settings that are passed through the
// environment variables rather than through function arguments.
TEST_P(SelectOptimizationsJobTest, DatasetUtils) {
  const string job_name = std::get<0>(GetParam());
  const string opt_ins = std::get<1>(GetParam());
  const string opt_outs = std::get<2>(GetParam());

  // Make the environment match the test parameters exactly. An empty
  // parameter must clear the variable instead of leaving it untouched;
  // otherwise a value inherited from the process environment would be picked
  // up by `SelectOptimizations` and the expectations below would not hold.
  if (opt_ins.empty()) {
    unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
  } else {
    setenv("TF_DATA_EXPERIMENT_OPT_IN", opt_ins.c_str(), 1);
  }
  if (opt_outs.empty()) {
    unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
  } else {
    setenv("TF_DATA_EXPERIMENT_OPT_OUT", opt_outs.c_str(), 1);
  }

  std::vector<tstring> optimizations_enabled = {"exp4"}, optimizations_disabled,
                       optimizations_default = {"exp2"};
  // Experiment name -> rollout value; "exp1" uses 0, "exp2" and "exp3" use 100.
  absl::flat_hash_map<string, uint64> live_experiments = {
      {"exp1", 0}, {"exp2", 100}, {"exp3", 100}};
  auto hash_func = [](const string& str) { return Hash64(str); };
  std::vector<tstring> optimizations = SelectOptimizations(
      job_name, live_experiments, optimizations_enabled, optimizations_disabled,
      optimizations_default, hash_func);

  // Exactly one branch below must match each parameter combination;
  // `tested_times` guards against a combination silently matching none.
  int tested_times = 0;
  if (job_name.empty()) {
    // With an empty job name only the enabled and default optimizations are
    // expected, regardless of the environment variables.
    EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp4"));
    tested_times++;
  } else if (opt_ins.empty()) {
    if (opt_outs.empty()) {
      EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp3", "exp4"));
      tested_times++;
    } else if (opt_outs == "exp2,exp3") {
      EXPECT_THAT(optimizations, UnorderedElementsAre("exp4"));
      tested_times++;
    }
  } else if (opt_ins == "exp1") {
    if (opt_outs.empty()) {
      EXPECT_THAT(optimizations,
                  UnorderedElementsAre("exp1", "exp2", "exp3", "exp4"));
      tested_times++;
    } else if (opt_outs == "exp2,exp3") {
      EXPECT_THAT(optimizations, UnorderedElementsAre("exp1", "exp4"));
      tested_times++;
    }
  }
  EXPECT_EQ(tested_times, 1);

  // Always clear both variables so nothing leaks into later tests that run in
  // the same process.
  unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
  unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
}

// Runs the test over every combination of {no job, "job"} x {no opt-in,
// "exp1"} x {no opt-out, "exp2,exp3"}.
INSTANTIATE_TEST_SUITE_P(Test, SelectOptimizationsJobTest,
                         ::testing::Combine(::testing::Values("", "job"),
                                            ::testing::Values("", "exp1"),
                                            ::testing::Values("",
                                                              "exp2,exp3")));
} // namespace
} // namespace data
} // namespace tensorflow
......@@ -80,48 +80,27 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
&optimizations_default));
string job_name = port::JobName();
if (job_name.empty()) {
// If `job_name` is empty, apply the enabled and default optimizations
// directly.
optimizations.insert(optimizations.end(), optimizations_enabled.begin(),
optimizations_enabled.end());
optimizations.insert(optimizations.end(), optimizations_default.begin(),
optimizations_default.end());
} else {
// The map that stores the experiment names and for how much percentage
// of the jobs, the experiments will be randomly turned on.
//
// This is currently empty; we have no live experiments yet.
absl::flat_hash_map<string, uint64> live_experiments;
const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
string opt_ins_raw;
if (opt_ins_raw_cs != nullptr) {
opt_ins_raw = string(opt_ins_raw_cs);
}
string opt_outs_raw;
if (opt_outs_raw_cs != nullptr) {
opt_outs_raw = string(opt_outs_raw_cs);
}
auto hash_func = [](const string& str) { return Hash64(str); };
optimizations = SelectOptimizations(
job_name, opt_ins_raw, opt_outs_raw, live_experiments,
optimizations_enabled, optimizations_disabled, optimizations_default,
hash_func);
// Log the experiments that will be applied.
if (!live_experiments.empty() && VLOG_IS_ON(1)) {
VLOG(1) << "The input pipeline is subject to tf.data experiment. "
"Please see `go/tf-data-experiments` for more details.";
for (auto& pair : live_experiments) {
string experiment = pair.first;
if (std::find(optimizations.begin(), optimizations.end(),
experiment) != optimizations.end()) {
VLOG(1) << "The experiment \"" << experiment << "\" is applied.";
metrics::RecordTFDataExperiment(experiment);
}
// The map that stores the experiment names and for how much percentage
// of the jobs, the experiments will be randomly turned on.
//
// This is currently empty; we have no live experiments yet.
absl::flat_hash_map<string, uint64> live_experiments;
auto hash_func = [](const string& str) { return Hash64(str); };
optimizations = SelectOptimizations(
job_name, live_experiments, optimizations_enabled,
optimizations_disabled, optimizations_default, hash_func);
// Log and record the experiments that will be applied.
if (!job_name.empty() && !live_experiments.empty()) {
VLOG(1) << "The input pipeline is subject to tf.data experiment. "
"Please see `go/tf-data-experiments` for more details.";
for (auto& pair : live_experiments) {
string experiment = pair.first;
if (std::find(optimizations.begin(), optimizations.end(), experiment) !=
optimizations.end()) {
VLOG(1) << "The experiment \"" << experiment << "\" is applied.";
metrics::RecordTFDataExperiment(experiment);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册