提交 af8d4fda 编写于 作者: X Xiao Yu 提交者: TensorFlower Gardener

Ensure we can propagate error properly when setting...

Ensure we can propagate error properly when setting "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE=false".

PiperOrigin-RevId: 262737847
上级 b4bf76a4
......@@ -35,7 +35,7 @@ class DestroyTensorHandleNode : public tensorflow::EagerNode {
Status Run() override {
EnqueueResponse* response = new EnqueueResponse;
eager_client_->StreamingEnqueueAsync(
return eager_client_->StreamingEnqueueAsync(
request_.get(), response, [response](const tensorflow::Status& s) {
if (!s.ok()) {
LOG(WARNING) << "Ignoring an error encountered when deleting "
......@@ -44,7 +44,6 @@ class DestroyTensorHandleNode : public tensorflow::EagerNode {
}
delete response;
});
return Status::OK();
}
void Abort(Status status) override {}
......
......@@ -54,9 +54,9 @@ class EagerClient {
// is invoked and keeps it open until some error condition.
// Similarly to the methods above, the request can be deleted as soon as
// StreamingEnqueueAsync returns.
virtual void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) = 0;
virtual Status StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) = 0;
};
// Simple wrapper class that can be used to retrieve EagerClients.
......
......@@ -157,7 +157,7 @@ Status RemoteCopyNode::StartSend() {
EnqueueResponse* response = new EnqueueResponse;
// If StartRecv fails very quickly, `this` can be destroyed before the
// callback below is executed. So, we can't capture `this`.
eager_client->StreamingEnqueueAsync(
return eager_client->StreamingEnqueueAsync(
&request, response, [response, captured_state](const Status& s) {
captured_state->SetSendStatus(s);
if (!s.ok()) {
......@@ -165,7 +165,6 @@ Status RemoteCopyNode::StartSend() {
}
delete response;
});
return Status::OK();
}
}
......@@ -210,7 +209,7 @@ Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
EnqueueResponse* response = new EnqueueResponse;
const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
Device* recv_device = recv_device_;
eager_client->StreamingEnqueueAsync(
return eager_client->StreamingEnqueueAsync(
&request, response,
[captured_state, response, recv_device](const Status& s) {
if (s.ok()) {
......@@ -228,8 +227,6 @@ Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
}
delete response;
});
return Status::OK();
}
Status RemoteCopyNode::StartRecv() {
......
......@@ -44,7 +44,7 @@ Status RemoteExecuteNode::Run() {
}
VLOG(3) << "Issuing: " << rpc_description;
eager_client_->StreamingEnqueueAsync(
return eager_client_->StreamingEnqueueAsync(
request_.get(), response,
[inputs, retvals, response, device,
rpc_description](const Status& status) {
......@@ -75,7 +75,6 @@ Status RemoteExecuteNode::Run() {
}
delete response;
});
return Status::OK();
}
} // namespace eager
......
......@@ -83,9 +83,9 @@ class GrpcEagerClient : public EagerClient {
}
}
void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
Status StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
if (EnableStreaming()) {
tf_shared_lock l(mu_);
auto it = enqueue_dispatchers_.find(request->context_id());
......@@ -100,6 +100,7 @@ class GrpcEagerClient : public EagerClient {
it = it_and_bool.first;
}
it->second.SendNextRequest(*request, response, std::move(done));
return Status::OK();
} else {
Notification n;
Status status;
......@@ -109,6 +110,7 @@ class GrpcEagerClient : public EagerClient {
});
n.WaitForNotification();
done(status);
return status;
}
}
......
......@@ -84,14 +84,13 @@ class SingleWorkerTest(test.TestCase):
cm.exception.message)
def testMultiDeviceFunctionAmbiguousDevice(self):
self.skipTest('b/139212497')
@def_function.function
def ambiguous_device(i):
with ops.device('cpu:0'):
return i + constant_op.constant([2])
with self.assertRaises(ValueError) as cm:
with self.assertRaises(errors.InvalidArgumentError) as cm:
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
ambiguous_device(constant_op.constant([2])).numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册