diff --git a/paddle/fluid/lite/kernels/arm/split_compute_test.cc b/paddle/fluid/lite/kernels/arm/split_compute_test.cc index 808a1e2cdb7724042ffcd1324cf0dc2c5e28f2fc..7d2d95205fc58a092be5714720c92382a9883a17 100644 --- a/paddle/fluid/lite/kernels/arm/split_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/split_compute_test.cc @@ -24,20 +24,10 @@ namespace kernels { namespace arm { void splite_resize_out(const lite::Tensor* din, - std::vector* dout, int axis, int num, - const std::vector& sections) { - for (auto out : *dout) delete out; - dout->clear(); + const std::vector& dout, int axis, + int num, const std::vector& sections) { auto in_dims = din->dims(); int outs_number; - if (num > 0) { - outs_number = num; - } else { - outs_number = sections.size(); - } - for (int i = 0; i < outs_number; i++) { - dout->push_back(new lite::Tensor); - } std::vector outs_dims; outs_dims.reserve(outs_number); @@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din, } for (int j = 0; j < outs_dims.size(); ++j) { - (*dout)[j]->Resize(outs_dims[j]); + dout[j]->Resize(outs_dims[j]); } } @@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) { } int input_offset = 0; - for (auto out : *dout) { + for (auto out : dout) { auto out_dim = out->dims(); std::vector out_strides(out_dim.size()); out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; @@ -128,16 +118,32 @@ TEST(split_arm, compute) { for (int i = 0; i < x.dims().production(); i++) { x_data[i] = i; } - splite_resize_out(&x, &output, axis, num, sections); - splite_resize_out(&x, &output_ref, axis, num, sections); + + for (auto out : output) delete out; + for (auto out : output_ref) delete out; + output.clear(); + output_ref.clear(); + int outs_number; + if (num > 0) { + outs_number = num; + } else { + outs_number = sections.size(); + } + for (int i = 0; i < outs_number; i++) { + output.push_back(new lite::Tensor); + output_ref.push_back(new lite::Tensor); + } + + splite_resize_out(&x, output, axis, num, sections); + splite_resize_out(&x, output_ref, axis, num, sections); param.x = &x; param.axis = axis; param.num = num; - param.sections = §ions; - param.output = &output; + param.sections = sections; + param.output = output; split.SetParam(param); split.Run(); - param.output = &output_ref; + param.output = output_ref; split_compute_ref(param); for (int i = 0; i < output.size(); i++) { float* output_data = output[i]->mutable_data(); diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc index c788e9cf9546a8c058398d71fde7aa4295fe8fbc..299be81fc0b216b9b436866ea76ff4ce04051a96 100644 --- a/paddle/fluid/lite/operators/split_op.cc +++ b/paddle/fluid/lite/operators/split_op.cc @@ -21,7 +21,7 @@ namespace operators { bool SplitOp::CheckShape() const { CHECK_OR_FALSE(param_.x); - CHECK_OR_FALSE(param_.output); + CHECK_GT_OR_FALSE(param_.output.size(), 1UL); auto x_dims = param_.x->dims(); auto x_rank = x_dims.size(); CHECK_OR_FALSE(param_.axis >= -static_cast(x_rank) && @@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.sections = opdesc.GetAttr>("sections"); param_.x = const_cast( &scope->FindVar(opdesc.Input("X").front())->Get()); - auto outs = op_desc.Output("Out"); + auto outs = op_desc.Output("Out").front; for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } @@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { } // namespace lite } // namespace paddle -REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); +REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);