未验证 提交 1d78681d 编写于 作者: W WangZhen 提交者: GitHub

Support 0 shapes input Tensor for MKL slice (#45930)

Support 0 shapes input Tensor for MKL slice kernel
上级 f48b1264
......@@ -171,7 +171,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
out->set_mem_desc(out_mem_desc);
out->Resize(in.dims());
if ((in.mem_desc() != out->mem_desc()) || always_copy) {
// Note(0x45f): Using initialized() to support slice Tensors
// with shapes like [0, 0, 0].
if (in.initialized() && ((in.mem_desc() != out->mem_desc()) || always_copy)) {
void* in_data = GetDataFromTensor(in, in_type);
platform::ReorderMKLDNNHandler handler(
......
......@@ -72,11 +72,19 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
ends[i] = ends[i] < 0 ? x_vec_dims[axes[i]] + ends[i]
: std::min(ends[i], x_vec_dims[axes[i]]);
offsets[axes[i]] = starts[i];
slice_dims[axes[i]] = ends[i] - starts[i];
slice_dims[axes[i]] =
std::max(static_cast<int64_t>(0), ends[i] - starts[i]);
}
out->Resize(phi::make_ddim(slice_dims));
// Note(0x45f): To support slice Tensors with shapes like [0, 0, 0].
if (!x->initialized()) {
out->mutable_data(x->place(), x->dtype());
out->set_layout(experimental::DataLayout::kMKLDNN);
return;
}
dnnl::memory::data_type x_type =
framework::ToMKLDNNDataType(framework::TransToProtoVarType(x->dtype()));
......
......@@ -83,7 +83,9 @@ void innerTransDataLayoutFromOneDNN(DataLayout in_layout,
out->set_mem_desc(out_mem_desc);
out->Resize(in.dims());
if ((in.mem_desc() != out->mem_desc()) || always_copy) {
// Note(0x45f): Using initialized() to support slice Tensors
// with shapes like [0, 0, 0].
if (in.initialized() && ((in.mem_desc() != out->mem_desc()) || always_copy)) {
void* in_data = GetDataFromTensor(in, in_type);
ReorderOneDNNHandler handler(in_tz, in.dtype(), in_type, cpu_engine);
......
......@@ -82,6 +82,33 @@ class TestCase2(TestSliceOp):
self.out = self.input[-3:3, 0:100, :, 2:-1]
class TestSliceZerosShapeTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags,
'use_mkldnn': True
}
def config(self):
self.input = np.random.random([0, 0, 0]).astype("float32")
self.starts = [1]
self.ends = [2]
self.axes = [0]
self.infer_flags = []
self.out = self.input[1:2]
def test_check_output(self):
self.check_output_with_place(paddle.CPUPlace())
# 1.2 with attr(decrease)
class TestSliceOp_decs_dim(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册