diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 31243bad3013bf104aa91e5fe819e4ca0cda6655..d3a4909071e8172b7adec29d1de531b6ef57e4bc 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -249,7 +249,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, true_out_meta->dims = calc_out->dims(); true_out_meta->dtype = calc_out->dtype(); true_out_meta->layout = calc_out->layout(); - // lod and offset no need to be reset + true_out_meta->offset = calc_out->offset(); + // lod no need to be reset // reset holder if needed if (true_out->Holder() != calc_out->Holder()) { true_out->ResetHolder(calc_out->Holder()); diff --git a/paddle/pten/core/dense_tensor_impl.cc b/paddle/pten/core/dense_tensor_impl.cc index a798ebe98b521a0bb11ff1f6127f21584bd39f79..18702c85b08dabb8dd4d2c39e1835695504f1102 100644 --- a/paddle/pten/core/dense_tensor_impl.cc +++ b/paddle/pten/core/dense_tensor_impl.cc @@ -76,13 +76,8 @@ void DenseTensor::set_layout(const paddle::framework::DataLayout layout) { meta_.layout = layout; } +// Note: When you reset holder, you need to ensure the offset is correct void DenseTensor::ResetHolder(const std::shared_ptr& holder) { - PADDLE_ENFORCE_EQ( - meta_.offset, - 0, - paddle::platform::errors::Fatal( - "Only the offset is supported to zero when the holder is reset.")); - if (holder_) { // TODO(zyfncg): The change of static_cast<> in check will recover back // when SetAllocationForOutputTenosr is deleted. @@ -90,7 +85,7 @@ void DenseTensor::ResetHolder(const std::shared_ptr& holder) { // compare with a data with unsigned long type, this will make checking // failed, so it's a temporary solution to deal with this problem. PADDLE_ENFORCE_LE( - numel() * static_cast(SizeOf(dtype())), + numel() * static_cast(SizeOf(dtype())) + meta_.offset, static_cast(holder->size()), paddle::platform::errors::InvalidArgument( "The size of Holder is not enough to store the Tensor.")); diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 42aed28074c4e2b7b8c2e0313c5c2c11fc7c51f0..82364d36922572e9a60bfade0d0cf39eed4b8c81 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -20,6 +20,7 @@ py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) py_test(test_custom_concat SRCS test_custom_concat.py) py_test(test_custom_conj SRCS test_custom_conj.py) py_test(test_custom_linear SRCS test_custom_linear.py) +py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py) # other tests py_test(test_sysconfig SRCS test_sysconfig.py)