未验证 提交 9d996cdd 编写于 作者: J jakpiase 提交者: GitHub

Fix for slice OneDNN kernel in solov2 and ppyolo models (#35706)

* fixed slice error

* added handling of StartsTensor+List and EndsTensor+List

* fix for ppyolo model
上级 e80acff3
...@@ -12,8 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
// Returns the oneDNN "plain" (dense, row-major) memory format tag that
// matches the rank of the given tensor: a/ab/abc/abcd/abcde for ranks 1-5.
// Any other rank (notably 6) falls back to abcdef, mirroring the behavior
// callers rely on when reordering into a plain layout.
static mkldnn::memory::format_tag get_plain_format_tag(
    const paddle::framework::Tensor* tensor) {
  const auto rank = tensor->dims().size();
  if (rank >= 1 && rank <= 5) {
    // Index [rank - 1] selects the plain tag for that dimensionality.
    static const mkldnn::memory::format_tag kPlainTags[] = {
        mkldnn::memory::format_tag::a, mkldnn::memory::format_tag::ab,
        mkldnn::memory::format_tag::abc, mkldnn::memory::format_tag::abcd,
        mkldnn::memory::format_tag::abcde};
    return kPlainTags[rank - 1];
  }
  return mkldnn::memory::format_tag::abcdef;
}
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -35,7 +56,6 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -35,7 +56,6 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto x_vec_dims = framework::vectorize(x->dims()); auto x_vec_dims = framework::vectorize(x->dims());
auto out_vec_dims = framework::vectorize(out->dims());
auto axes_int = ctx.Attr<std::vector<int>>("axes"); auto axes_int = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts"); auto starts_int = ctx.Attr<std::vector<int>>("starts");
...@@ -48,8 +68,22 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -48,8 +68,22 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(), std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
ctx.Attr<std::vector<int>>("ends").end()); ctx.Attr<std::vector<int>>("ends").end());
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
std::vector<int64_t> offsets(x_vec_dims.size(), 0); std::vector<int64_t> offsets(x_vec_dims.size(), 0);
std::vector<int64_t> slice_dims(x_vec_dims); std::vector<int64_t> slice_dims(x_vec_dims);
...@@ -61,6 +95,8 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -61,6 +95,8 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
slice_dims[axes[i]] = ends[i] - starts[i]; slice_dims[axes[i]] = ends[i] - starts[i];
} }
out->Resize(framework::make_ddim(slice_dims));
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends, auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends,
x->format(), x_type); x->format(), x_type);
...@@ -73,20 +109,35 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -73,20 +109,35 @@ class SliceMKLDNNKernel : public framework::OpKernel<T> {
auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
reorder_src_memory_p); reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, slice_dims, 0, x->format(), ctx.GetPlace()); out, slice_dims, 0, get_plain_format_tag(x), ctx.GetPlace());
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
astream.wait();
std::vector<int64_t> new_out_dims(slice_dims.size() - decrease_axis.size());
if (new_out_dims.size() == 0) {
new_out_dims.emplace_back(1);
} else {
for (const auto& axis : decrease_axis) {
slice_dims[axis] = 0;
}
int i = 0;
for (const auto& slice_dim : slice_dims) {
if (slice_dim != 0) new_out_dims[i++] = slice_dim;
}
}
astream.wait();
out->Resize(framework::make_ddim(new_out_dims));
out->set_layout(framework::DataLayout::kMKLDNN); out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat( out->set_format(platform::GetMKLDNNFormat(
reorder_dst_memory_p->get_desc().reshape(out_vec_dims))); reorder_dst_memory_p->get_desc().reshape(new_out_dims)));
} }
}; };
template <typename T> template <typename T>
class SliceGradMKLDNNKernel : public framework::OpKernel<T> { class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
public: public:
...@@ -116,6 +167,20 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -116,6 +167,20 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(), std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
ctx.Attr<std::vector<int>>("ends").end()); ctx.Attr<std::vector<int>>("ends").end());
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
std::vector<int64_t> offsets(dx_vec_dims.size(), 0); std::vector<int64_t> offsets(dx_vec_dims.size(), 0);
...@@ -172,4 +237,4 @@ REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace, ...@@ -172,4 +237,4 @@ REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace,
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace,
ops::SliceGradMKLDNNKernel<float>, ops::SliceGradMKLDNNKernel<float>,
ops::SliceGradMKLDNNKernel<paddle::platform::bfloat16>); ops::SliceGradMKLDNNKernel<paddle::platform::bfloat16>);
\ No newline at end of file
...@@ -150,6 +150,46 @@ class TestSlice3DOneDNNOp(TestSliceDecrease1AxisOneDNNOp): ...@@ -150,6 +150,46 @@ class TestSlice3DOneDNNOp(TestSliceDecrease1AxisOneDNNOp):
self.out = self.input[:, :, -1] self.out = self.input[:, :, -1]
class TestSliceOneDNNOp_decs_dim_starts_ListTensor(
        TestSliceDecrease1AxisOneDNNOp):
    """Exercises the StartsTensorList input path: the start index is fed
    as a list of int32 tensors instead of the `starts` attribute."""

    def set_inputs(self):
        # One (name, tensor) pair per entry in self.starts; each tensor
        # carries the runtime start index 2, overriding the attribute.
        starts_tensor = [("x1", np.ones((1)).astype('int32') * 2)
                         for _ in self.starts]
        self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}

    def config(self):
        self.axes = [2]
        self.starts = [1]
        self.ends = [3]
        self.decrease_axis = []
        self.infer_flags = [1, 1, 1]
        self.input = np.random.random([5, 4, 5]).astype("float32")
        # Expected result uses start 2 (from the tensor list), not starts=[1].
        self.out = self.input[:, :, 2:3]
class TestSlice4DInferDimsOneDNNOp(TestSliceDecrease1AxisOneDNNOp):
    """Slices a 4-D input along axes 2 and 3 with decrease_axis=[1] and
    runtime-inferred dims (infer_flags = -1)."""

    def config(self):
        self.axes = [2, 3]
        self.starts = [1, 2]
        self.ends = [9, 9]
        self.decrease_axis = [1]
        self.infer_flags = [-1, -1]
        self.input = np.random.random([1, 1, 10, 10]).astype("float32")
        self.out = self.input[:, :, 1:9, 2:9]
class TestSlice4DInferDimsOneDNNOp2(TestSliceDecrease1AxisOneDNNOp):
    """Like TestSlice4DInferDimsOneDNNOp but decreases both leading axes
    (decrease_axis=[0, 1]) with a different slice window."""

    def config(self):
        self.axes = [2, 3]
        self.starts = [4, 2]
        self.ends = [7, 8]
        self.decrease_axis = [0, 1]
        self.infer_flags = [-1, -1]
        self.input = np.random.random([1, 1, 10, 10]).astype("float32")
        self.out = self.input[:, :, 4:7, 2:8]
# BF16 TESTS # BF16 TESTS
def create_bf16_test_class(parent): def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16() @OpTestTool.skip_if_not_cpu_bf16()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册