提交 68242bc3 编写于 作者: M Mihai Maruseac

Prevent segfault in `GetSessionHandle{,V2}`.

In eager mode, session state is null.

PiperOrigin-RevId: 332548597
Change-Id: If094812c2e094044220b9ba28f7d7601be042f38
上级 4a436e15
......@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/data_flow_ops.cc.
#include <limits.h>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
......@@ -27,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
......@@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& val = ctx->input(0);
int64 id = ctx->session_state()->GetNewId();
auto session_state = ctx->session_state();
OP_REQUIRES(ctx, session_state != nullptr,
errors::FailedPrecondition(
"GetSessionHandle called on null session state"));
int64 id = session_state->GetNewId();
TensorStore::TensorAndKey tk{val, id, requested_device()};
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
......
......@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.platform import test
......@@ -79,6 +80,13 @@ class RawOpsTest(test.TestCase, parameterized.TestCase):
pad_width=0,
preserve_short_sequences=False))
def testGetSessionHandle(self):
if context.executing_eagerly():
with self.assertRaisesRegex(
errors.FailedPreconditionError,
"GetSessionHandle called on null session state"):
gen_data_flow_ops.GetSessionHandle(value=[1])
if __name__ == "__main__":
ops.enable_eager_execution()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册