Unverified commit 9d996cdd, authored by jakpiase, committed by GitHub

Fix for slice OneDNN kernel in solov2 and ppyolo models (#35706)

* fixed slice error

* added handling of StartsTensor+List and EndsTensor+List

* fix for ppyolo model
Parent e80acff3
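The gist of the fix: starts/ends are now also resolved from runtime tensor inputs (StartsTensor/EndsTensor or their TensorList variants), the destination memory is reordered into a plain format tag derived from the input's rank, and axes listed in decrease_axis are squeezed out of the output shape. Below is a minimal standalone sketch of that squeeze step; the function and variable names are illustrative, not from the patch:

#include <cstdint>
#include <iostream>
#include <vector>

// Drops the axes listed in decrease_axis from the sliced shape; when every
// axis is dropped, the output keeps a single dimension of size 1.
std::vector<int64_t> SqueezeDecreasedAxes(std::vector<int64_t> slice_dims,
                                          const std::vector<int>& decrease_axis) {
  std::vector<int64_t> new_out_dims(slice_dims.size() - decrease_axis.size());
  if (new_out_dims.empty()) {
    new_out_dims.emplace_back(1);
  } else {
    // Mark decreased axes with 0, then copy the surviving dims in order.
    for (const auto& axis : decrease_axis) slice_dims[axis] = 0;
    int i = 0;
    for (const auto& dim : slice_dims) {
      if (dim != 0) new_out_dims[i++] = dim;
    }
  }
  return new_out_dims;
}

int main() {
  // A [1, 1, 8, 7] slice with decrease_axis = {0, 1} squeezes to [8, 7].
  for (int64_t d : SqueezeDecreasedAxes({1, 1, 8, 7}, {0, 1})) {
    std::cout << d << ' ';
  }
  std::cout << '\n';  // prints: 8 7
}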
@@ -12,8 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
static mkldnn::memory::format_tag get_plain_format_tag(
const paddle::framework::Tensor* tensor) {
auto tensor_dims_size = tensor->dims().size();
switch (tensor_dims_size) {
case 1:
return mkldnn::memory::format_tag::a;
case 2:
return mkldnn::memory::format_tag::ab;
case 3:
return mkldnn::memory::format_tag::abc;
case 4:
return mkldnn::memory::format_tag::abcd;
case 5:
return mkldnn::memory::format_tag::abcde;
}
return mkldnn::memory::format_tag::abcdef;
}
namespace paddle {
namespace operators {
@@ -35,7 +56,6 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
auto x_vec_dims = framework::vectorize(x->dims());
auto axes_int = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
@@ -48,8 +68,22 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
ctx.Attr<std::vector<int>>("ends").end());
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
std::vector<int64_t> offsets(x_vec_dims.size(), 0);
std::vector<int64_t> slice_dims(x_vec_dims);
@@ -61,6 +95,8 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
slice_dims[axes[i]] = ends[i] - starts[i];
}
out->Resize(framework::make_ddim(slice_dims));
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends,
x->format(), x_type);
@@ -73,20 +109,35 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        out, slice_dims, 0, get_plain_format_tag(x), ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
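    // Squeeze the axes listed in decrease_axis out of the final output shape.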
std::vector<int64_t> new_out_dims(slice_dims.size() - decrease_axis.size());
if (new_out_dims.size() == 0) {
new_out_dims.emplace_back(1);
} else {
for (const auto& axis : decrease_axis) {
slice_dims[axis] = 0;
}
int i = 0;
for (const auto& slice_dim : slice_dims) {
if (slice_dim != 0) new_out_dims[i++] = slice_dim;
}
}
astream.wait();
out->Resize(framework::make_ddim(new_out_dims));
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
        reorder_dst_memory_p->get_desc().reshape(new_out_dims)));
}
};
template <typename T>
class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
public:
@@ -116,6 +167,20 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
ctx.Attr<std::vector<int>>("ends").end());
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
std::vector<int64_t> offsets(dx_vec_dims.size(), 0);
......
@@ -150,6 +150,46 @@ class TestSlice3DOneDNNOp(TestSliceDecrease1AxisOneDNNOp):
self.out = self.input[:, :, -1]
class TestSliceOneDNNOp_decs_dim_starts_ListTensor(
TestSliceDecrease1AxisOneDNNOp):
def set_inputs(self):
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x1", np.ones((1)).astype('int32') * 2))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
def config(self):
self.input = np.random.random([5, 4, 5]).astype("float32")
self.starts = [1]
self.ends = [3]
self.axes = [2]
self.decrease_axis = []
self.infer_flags = [1, 1, 1]
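        # StartsTensorList feeds the value 2 at runtime, overriding the
        # starts attribute of 1, hence the expected slice [:, :, 2:3].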
self.out = self.input[:, :, 2:3]
class TestSlice4DInferDimsOneDNNOp(TestSliceDecrease1AxisOneDNNOp):
def config(self):
self.input = np.random.random([1, 1, 10, 10]).astype("float32")
self.starts = [1, 2]
self.ends = [9, 9]
self.axes = [2, 3]
self.decrease_axis = [1]
self.infer_flags = [-1, -1]
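        # infer_flags of -1 mark these axes' starts/ends as not inferable
        # until runtime.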
self.out = self.input[:, :, 1:9, 2:9]
class TestSlice4DInferDimsOneDNNOp2(TestSliceDecrease1AxisOneDNNOp):
def config(self):
self.input = np.random.random([1, 1, 10, 10]).astype("float32")
self.starts = [4, 2]
self.ends = [7, 8]
self.axes = [2, 3]
self.decrease_axis = [0, 1]
self.infer_flags = [-1, -1]
self.out = self.input[:, :, 4:7, 2:8]
# BF16 TESTS
def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
......