未验证 提交 91b074a2 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Fix slice bug of custom op (#39393)

* fix slice bug of cusstom op

* add offset in check
上级 f810d755
...@@ -249,7 +249,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -249,7 +249,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
true_out_meta->dims = calc_out->dims(); true_out_meta->dims = calc_out->dims();
true_out_meta->dtype = calc_out->dtype(); true_out_meta->dtype = calc_out->dtype();
true_out_meta->layout = calc_out->layout(); true_out_meta->layout = calc_out->layout();
// lod and offset no need to be reset true_out_meta->offset = calc_out->offset();
// lod no need to be reset
// reset holder if needed // reset holder if needed
if (true_out->Holder() != calc_out->Holder()) { if (true_out->Holder() != calc_out->Holder()) {
true_out->ResetHolder(calc_out->Holder()); true_out->ResetHolder(calc_out->Holder());
......
...@@ -76,13 +76,8 @@ void DenseTensor::set_layout(const paddle::framework::DataLayout layout) { ...@@ -76,13 +76,8 @@ void DenseTensor::set_layout(const paddle::framework::DataLayout layout) {
meta_.layout = layout; meta_.layout = layout;
} }
// Note: When you reset holder, you need to ensure the offset is correct
void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) { void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) {
PADDLE_ENFORCE_EQ(
meta_.offset,
0,
paddle::platform::errors::Fatal(
"Only the offset is supported to zero when the holder is reset."));
if (holder_) { if (holder_) {
// TODO(zyfncg): The change of static_cast<> in check will recover back // TODO(zyfncg): The change of static_cast<> in check will recover back
// when SetAllocationForOutputTenosr is deleted. // when SetAllocationForOutputTenosr is deleted.
...@@ -90,7 +85,7 @@ void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) { ...@@ -90,7 +85,7 @@ void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) {
// compare with a data with unsigned long type, this will make checking // compare with a data with unsigned long type, this will make checking
// failed, so it's a temporary solution to deal with this problem. // failed, so it's a temporary solution to deal with this problem.
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
numel() * static_cast<int64_t>(SizeOf(dtype())), numel() * static_cast<int64_t>(SizeOf(dtype())) + meta_.offset,
static_cast<int64_t>(holder->size()), static_cast<int64_t>(holder->size()),
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The size of Holder is not enough to store the Tensor.")); "The size of Holder is not enough to store the Tensor."));
......
...@@ -20,6 +20,7 @@ py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) ...@@ -20,6 +20,7 @@ py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
py_test(test_custom_concat SRCS test_custom_concat.py) py_test(test_custom_concat SRCS test_custom_concat.py)
py_test(test_custom_conj SRCS test_custom_conj.py) py_test(test_custom_conj SRCS test_custom_conj.py)
py_test(test_custom_linear SRCS test_custom_linear.py) py_test(test_custom_linear SRCS test_custom_linear.py)
py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py)
# other tests # other tests
py_test(test_sysconfig SRCS test_sysconfig.py) py_test(test_sysconfig SRCS test_sysconfig.py)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册