提交 666c3bb9 编写于 作者: W Wojciech Uss 提交者: Tao Luo

handle multi-inputs with empty inputs for mkldnn_concat_op (#21827)

test=develop
上级 a9af87ed
...@@ -59,6 +59,16 @@ static const mkldnn::engine& GetMKLDNNEngine( ...@@ -59,6 +59,16 @@ static const mkldnn::engine& GetMKLDNNEngine(
return dev_ctx.GetEngine(); return dev_ctx.GetEngine();
} }
// Collects the non-empty tensors of a multi-input, preserving their order.
//
// @param inputs  Tensor pointers making up a multi-input. Every pointer is
//                dereferenced, so none of them may be null.
// @return        A new vector holding only the pointers whose tensor reports
//                numel() > 0. The tensors themselves are not copied.
//
// The result is returned as a non-const value: a top-level const on a
// by-value return gives callers nothing and prevents moving from the result.
static std::vector<const Tensor*> ReduceMultiInput(
    const std::vector<const Tensor*>& inputs) {
  std::vector<const Tensor*> reduced;
  // At most inputs.size() pointers survive the filter, so reserving up front
  // guarantees a single allocation without creating placeholder elements.
  reduced.reserve(inputs.size());
  std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(reduced),
               [](const Tensor* t) { return t->numel() > 0; });
  return reduced;
}
template <typename T> template <typename T>
class ConcatPrimitiveFactory { class ConcatPrimitiveFactory {
public: public:
...@@ -120,7 +130,7 @@ template <typename T> ...@@ -120,7 +130,7 @@ template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto multi_input = ctx.MultiInput<Tensor>("X"); auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input); EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis"); int concat_axis = ctx.Attr<int>("axis");
......
...@@ -154,8 +154,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -154,8 +154,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto in_stride = framework::stride(in_dims); auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims()); auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims()); auto out_stride = framework::stride(out->dims());
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4
class TestMKLDNNConcatOp(TestConcatOp): class TestMKLDNNConcatOp(TestConcatOp):
...@@ -69,5 +69,22 @@ class TestMKLDNNConcatOp3(TestConcatOp3): ...@@ -69,5 +69,22 @@ class TestMKLDNNConcatOp3(TestConcatOp3):
self.use_mkldnn = True self.use_mkldnn = True
class TestMKLDNNConcatOp4(TestConcatOp4):
    """Variant of TestConcatOp4 that runs with the ``use_mkldnn`` attribute set."""

    def setUp(self):
        """Run the base set-up, then enable MKL-DNN and mark the test CPU-only."""
        super(TestMKLDNNConcatOp4, self).setUp()
        self.attrs["use_mkldnn"] = True
        self._cpu_only = True

    def test_check_output(self):
        """Check the operator output.

        ``check_dygraph`` is only true when ``use_mkldnn`` is off; setUp turns
        it on, so the dygraph check is disabled for this class.
        """
        # TODO(wangzhongpu): support mkldnn op in dygraph mode
        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))

    def test_check_grad(self):
        """Deliberately a no-op: no gradient check is run for this class."""
        # NOTE(review): presumably overrides a gradient check inherited from
        # the base test -- confirm against TestConcatOp4.
        pass

    def init_kernel_type(self):
        """Set the ``use_mkldnn`` flag on the test instance."""
        self.use_mkldnn = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册