提交 666c3bb9 编写于 作者: W Wojciech Uss 提交者: Tao Luo

handle multi-inputs with empty inputs for mkldnn_concat_op (#21827)

test=develop
上级 a9af87ed
......@@ -59,6 +59,16 @@ static const mkldnn::engine& GetMKLDNNEngine(
return dev_ctx.GetEngine();
}
// From a multi-input, gather only nonempty inputs
static const std::vector<const Tensor*> ReduceMultiInput(
const std::vector<const Tensor*>& inputs) {
std::vector<const Tensor*> reduced(inputs.size());
auto end_it = std::copy_if(inputs.begin(), inputs.end(), reduced.begin(),
[](const Tensor* t) { return t->numel() > 0; });
reduced.resize(std::distance(reduced.begin(), end_it));
return reduced;
}
template <typename T>
class ConcatPrimitiveFactory {
public:
......@@ -120,7 +130,7 @@ template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto multi_input = ctx.MultiInput<Tensor>("X");
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis");
......
......@@ -154,8 +154,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims());
......
......@@ -15,7 +15,7 @@
from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3
from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4
class TestMKLDNNConcatOp(TestConcatOp):
......@@ -42,7 +42,7 @@ class TestMKLDNNConcatOp2(TestConcatOp2):
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
......@@ -59,7 +59,24 @@ class TestMKLDNNConcatOp3(TestConcatOp3):
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
pass
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNConcatOp4(TestConcatOp4):
def setUp(self):
super(TestMKLDNNConcatOp4, self).setUp()
self.attrs["use_mkldnn"] = True
self._cpu_only = True
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def test_check_grad(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册