Commit 75018c79, authored by Dong Lin, committed by TensorFlower Gardener

Allow user to pass custom threadpool via Session::Run()

PiperOrigin-RevId: 286228455
Change-Id: Id85aef40d98edfce4a93b4e9ab2eda304f54b865
Parent commit: 3883b177
......@@ -233,6 +233,7 @@ cc_library_with_android_deps(
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:protos_all_cc",
],
)
......
......@@ -127,6 +127,33 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
// Runs the requested fetches/targets using a caller-supplied threadpool
// configuration.
//
// Translates the structured feed/fetch/target arguments into the
// string-keyed form the underlying Session expects, extends the session's
// graph if new nodes were added since the last run, and forwards
// `threadpool_options` to Session::Run so execution can use the
// user-provided threadpools.
Status ClientSession::Run(
    const RunOptions& run_options, const FeedType& inputs,
    const std::vector<Output>& fetch_outputs,
    const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) const {
  // Materialize feeds as (tensor name, value) pairs, surfacing any error
  // that was recorded while the feed value was being constructed.
  std::vector<std::pair<string, Tensor>> feed_pairs;
  feed_pairs.reserve(inputs.size());
  for (const auto& input : inputs) {
    TF_RETURN_IF_ERROR(input.second.status);
    feed_pairs.emplace_back(input.first.name(), input.second.tensor);
  }
  // Names of tensors whose values the caller wants returned.
  std::vector<string> fetch_names;
  fetch_names.reserve(fetch_outputs.size());
  for (const auto& fetch : fetch_outputs) {
    fetch_names.push_back(fetch.name());
  }
  // Names of ops to run purely for their side effects.
  std::vector<string> target_names;
  target_names.reserve(run_outputs.size());
  for (const auto& target : run_outputs) {
    target_names.push_back(target.node()->name());
  }
  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->Run(run_options, feed_pairs, fetch_names,
                               target_names, outputs, run_metadata,
                               threadpool_options);
}
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
......
......@@ -93,6 +93,14 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
/// Same as above. Additionally allows user to provide custom threadpool
/// implementation via ThreadPoolOptions.
/// NOTE(review): the threadpools referenced by `threadpool_options` appear
/// to be borrowed, not owned, by the session — they must outlive this call;
/// confirm against the Session::Run contract.
Status Run(const RunOptions& run_options, const FeedType& inputs,
const std::vector<Output>& fetch_outputs,
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) const;
/// \brief A handle to a subgraph, created with
/// `ClientSession::MakeCallable()`.
typedef int64 CallableHandle;
......
......@@ -112,7 +112,7 @@ TEST(ClientSessionTest, Extend) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
}
TEST(ClientSessionTest, MultiThreaded) {
TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) {
Scope root = Scope::NewRootScope();
auto a = Add(root, {1, 2}, {3, 4});
auto b = Mul(root, {1, 2}, {3, 4});
......@@ -138,6 +138,49 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
// Verifies that ClientSession::Run() can execute concurrently while routing
// work through caller-supplied inter-op/intra-op threadpools.
TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) {
  Scope root = Scope::NewRootScope();
  int num_threads = 3;
  auto a = Add(root, {1, 2}, {3, 4});
  auto b = Mul(root, {1, 2}, {3, 4});
  ClientSession session(root);
  auto inter_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(inter_op_threadpool->GetNumScheduleCalled(), 0);
  auto intra_op_threadpool =
      absl::make_unique<CustomThreadPoolImpl>(num_threads);
  ASSERT_EQ(intra_op_threadpool->GetNumScheduleCalled(), 0);
  tensorflow::thread::ThreadPoolOptions threadPoolOptions;
  threadPoolOptions.inter_op_threadpool = inter_op_threadpool.get();
  threadPoolOptions.intra_op_threadpool = intra_op_threadpool.get();
  {
    thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
    // BUG FIX: these Run() calls previously passed a default-constructed
    // thread::ThreadPoolOptions(), so the custom threadpools configured
    // above were never exercised despite the test's name. Pass the
    // configured options instead. threadPoolOptions and the custom pools
    // outlive this scoped thread_pool (whose destructor joins), so
    // capturing them by reference is safe.
    thread_pool.Schedule([&session, &threadPoolOptions, a]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {a}, {},
                               &outputs, nullptr, threadPoolOptions));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({4, 6}, {2}));
    });
    thread_pool.Schedule([&session, &threadPoolOptions, b]() {
      std::vector<Tensor> outputs;
      TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {b}, {},
                               &outputs, nullptr, threadPoolOptions));
      test::ExpectTensorEqual<int>(outputs[0],
                                   test::AsTensor<int>({3, 8}, {2}));
    });
  }
  auto c = Sub(root, b, a);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(RunOptions(), ClientSession::FeedType{}, {c}, {},
                           &outputs, nullptr, threadPoolOptions));
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
  // The custom inter-op pool must have been asked to schedule work at least
  // once; otherwise the custom threadpool path was not taken.
  ASSERT_GT(inter_op_threadpool->GetNumScheduleCalled(), 0);
}
TEST(ClientSessionTest, CallableWithDefaultThreadPool) {
Scope root = Scope::NewRootScope();
auto a = Placeholder(root, DT_INT32);
......
......@@ -793,6 +793,17 @@ Status DirectSession::Run(const RunOptions& run_options,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
return Run(run_options, inputs, output_names, target_nodes, outputs,
run_metadata, thread::ThreadPoolOptions());
}
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) {
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
direct_session_runs->GetCell()->IncrementBy(1);
......@@ -852,7 +863,7 @@ Status DirectSession::Run(const RunOptions& run_options,
TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
executors_and_keys, run_metadata,
thread::ThreadPoolOptions()));
threadpool_options));
// Receive outputs.
if (outputs) {
......
......@@ -84,6 +84,14 @@ class DirectSession : public Session {
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) override;
// Like the Run() overload above, but additionally accepts
// `threadpool_options`, letting the caller supply custom threadpool
// implementations for this invocation (forwarded through to RunInternal).
// NOTE: Experimental and subject to change.
::tensorflow::Status Run(
const ::tensorflow::RunOptions& run_options,
const NamedTensorList& inputs, const std::vector<string>& output_names,
const std::vector<string>& target_nodes, std::vector<Tensor>* outputs,
RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) override;
// NOTE: PRunSetup and PRun are added to support partial execution. This
// feature is experimental and subject to change.
::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
......
......@@ -174,6 +174,19 @@ class Session {
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata);
/// \brief Like `Run` with `RunOptions` proto, but allows user to provide
/// custom threadpool implementation via ThreadPoolOptions.
/// The base-class default does not execute anything: it returns an
/// `Unimplemented` error so Session implementations written before this
/// API existed keep compiling and fail loudly at runtime. Implementations
/// that support caller-supplied threadpools (e.g. DirectSession) override
/// this method.
/// NOTE: This API is still experimental and may change.
virtual Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const thread::ThreadPoolOptions& threadpool_options) {
return errors::Unimplemented(
"Run with threadpool is not supported for this session.");
}
/// \brief Sets up a graph for partial execution. All future feeds and
/// fetches are specified by `input_names` and `output_names`. Returns
/// `handle` that can be used to perform a sequence of partial feeds and
......@@ -245,7 +258,8 @@ class Session {
}
/// \brief Invokes the subgraph named by `handle` with the given options and
/// input tensors.
/// input tensors. User can provide custom threadpool implementation via
/// threadpool_options.
///
/// The order of tensors in `feed_tensors` must and `fetch_tensors` will
/// match the order of names in `CallableOptions::feed()` and
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment.