未验证 提交 a72907bb 编写于 作者: J jerrywgz 提交者: GitHub

Enhance concat op to support empty input. (#17015)

* enhance_concat, test=develop
上级 83c4f772
...@@ -37,6 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -37,6 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
if (axis == 0 && ins.size() < 10) { if (axis == 0 && ins.size() < 10) {
size_t output_offset = 0; size_t output_offset = 0;
for (auto* in : ins) { for (auto* in : ins) {
if (!in || in->numel() == 0UL) {
continue;
}
auto in_stride = framework::stride_numel(in->dims()); auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims()); auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
...@@ -45,9 +48,13 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -45,9 +48,13 @@ class ConcatKernel : public framework::OpKernel<T> {
output_offset += in_stride[axis]; output_offset += in_stride[axis];
} }
} else { } else {
std::vector<framework::Tensor> inputs(ins.size()); std::vector<framework::Tensor> inputs;
for (size_t j = 0; j < ins.size(); ++j) { for (size_t j = 0; j < ins.size(); ++j) {
inputs[j] = *ins[j]; if (ins[j] && ins[j]->numel() > 0) {
inputs.push_back(*ins[j]);
} else {
continue;
}
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor; paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
...@@ -82,7 +89,8 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -82,7 +89,8 @@ class ConcatGradKernel : public framework::OpKernel<T> {
// get output tensor that the name is not kEmptyVarName // get output tensor that the name is not kEmptyVarName
std::vector<framework::Tensor*> outputs; std::vector<framework::Tensor*> outputs;
for (size_t j = 0; j < outs.size(); ++j) { for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName) { if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace()); outs[j]->mutable_data<T>(ctx.GetPlace());
outputs.push_back(outs[j]); outputs.push_back(outs[j]);
} else { } else {
......
...@@ -64,5 +64,16 @@ class TestConcatOp3(TestConcatOp): ...@@ -64,5 +64,16 @@ class TestConcatOp3(TestConcatOp):
pass pass
class TestConcatOp4(TestConcatOp):
    """Concat test case whose third input has a zero-sized leading dimension.

    Sets up two random float32 inputs of shape (2, 3, 4, 5) plus one input of
    shape (0, 3, 4, 5), with ``axis`` set to 0. How these attributes are
    consumed is defined by the ``TestConcatOp`` base class (not visible here).
    """

    def init_test_data(self):
        """Assign the inputs ``x0``, ``x1``, ``x2`` and the ``axis`` attribute."""
        populated_shape = (2, 3, 4, 5)
        # Same trailing dimensions as the populated inputs, but no rows at all.
        empty_shape = (0, 3, 4, 5)

        self.x0 = np.random.random(populated_shape).astype('float32')
        self.x1 = np.random.random(populated_shape).astype('float32')
        self.x2 = np.random.random(empty_shape).astype('float32')
        self.axis = 0

    def test_check_grad(self):
        # Deliberate no-op: this method performs no gradient check for this
        # case. NOTE(review): presumably this overrides a gradient test
        # inherited from TestConcatOp -- confirm against the base class.
        pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册