提交 b7eec216 编写于 作者: Y YashasSamaga

generalize axis for concat fusion

上级 dc57707e
......@@ -2525,8 +2525,7 @@ struct Net::Impl : public detail::NetImplBase
// (and so we eliminate the concatenation layer, because the channels
// are concatenated implicitly).
Ptr<ConcatLayer> concatLayer = ld.layerInstance.dynamicCast<ConcatLayer>();
if( !concatLayer.empty() && concatLayer->axis == 1 && !concatLayer->padding &&
ld.outputBlobs.size() == 1 )
if( !concatLayer.empty() && !concatLayer->padding && ld.outputBlobs.size() == 1 )
{
Mat& output = ld.outputBlobs[0];
UMat umat_output;
......@@ -2563,7 +2562,8 @@ struct Net::Impl : public detail::NetImplBase
// the concatenation optimization is applied with batch_size > 1.
// so, for now, we only apply this optimization in the most popular
// case batch_size == 1.
if( output.dims == 4 && output.size[0] == 1 )
int axis = clamp(concatLayer->axis, output.dims);
if( output.total(0, axis) == 1 )
{
size_t i, ninputs = ld.inputBlobsId.size();
std::vector<LayerPin> realinputs(ninputs);
......@@ -2602,14 +2602,14 @@ struct Net::Impl : public detail::NetImplBase
OpenCLBackendWrapper::update(ld.outputBlobsWrappers, umats);
}
#endif
Range chrange[] = { Range::all(), Range::all(), Range::all(), Range::all() };
std::vector<Range> chrange(output.dims, Range::all());
int ofs = 0;
for( i = 0; i < ninputs; i++ )
{
LayerPin pin = realinputs[i];
LayerData* inp_i_data = &layers[pin.lid];
int channels_i = ld.inputBlobs[i]->size[1];
chrange[1] = Range(ofs, ofs + channels_i);
int channels_i = ld.inputBlobs[i]->size[axis];
chrange[axis] = Range(ofs, ofs + channels_i);
printf_(("\toutput %s(%d) to channels (%d, %d)\n", inp_i_data->layerInstance->name.c_str(),
pin.oid, ofs, ofs + channels_i));
ofs += channels_i;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册