diff --git a/paddle/fluid/lite/arm/math/split.cc b/paddle/fluid/lite/arm/math/split.cc index 6dd6de6242e806947dfc630fd8f2a4dd03c89335..bf8d50590ff89c451347e33a289391b8d929e5b6 100644 --- a/paddle/fluid/lite/arm/math/split.cc +++ b/paddle/fluid/lite/arm/math/split.cc @@ -52,10 +52,10 @@ void split_cpy(const float* din, float* dout, int num) { } template <> -void split(const float* din, std::vector* dout, +void split(const float* din, const std::vector& dout, const int axis, const std::vector& in_strides) { int input_offset = 0; - for (auto out : *dout) { + for (auto out : dout) { auto out_dim = out->dims(); std::vector out_strides(out_dim.size()); out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; diff --git a/paddle/fluid/lite/arm/math/split.h b/paddle/fluid/lite/arm/math/split.h index 9b5651d81ffa75362fcc39db82157c56548917c0..643214e174c3ede02f430ee4ded7cee097ba0afc 100644 --- a/paddle/fluid/lite/arm/math/split.h +++ b/paddle/fluid/lite/arm/math/split.h @@ -26,7 +26,7 @@ template void split_cpy(const T* din, T* dout, int num); template -void split(const T* din, std::vector* dout, const int axis, +void split(const T* din, const std::vector& dout, const int axis, const std::vector& in_strides); } // namespace math diff --git a/paddle/fluid/lite/kernels/arm/split_compute.cc b/paddle/fluid/lite/kernels/arm/split_compute.cc index 9da69894592e146c9191eb9da38d8d481cf287a7..3c2416bd6907199e6e83baf65c428b675462f271 100644 --- a/paddle/fluid/lite/kernels/arm/split_compute.cc +++ b/paddle/fluid/lite/kernels/arm/split_compute.cc @@ -24,7 +24,7 @@ namespace arm { void SplitCompute::Run() { auto& param = Param(); const float* din = param.x->data(); - auto* dout = param.output; + auto& dout = param.output; auto in_dim = param.x->dims(); std::vector in_strides(in_dim.size()); in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1]; diff --git a/paddle/fluid/lite/kernels/arm/split_compute_test.cc b/paddle/fluid/lite/kernels/arm/split_compute_test.cc index 808a1e2cdb7724042ffcd1324cf0dc2c5e28f2fc..39632bee8decfe875f0adb3c2717d58e593c400b 100644 --- a/paddle/fluid/lite/kernels/arm/split_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/split_compute_test.cc @@ -24,20 +24,10 @@ namespace kernels { namespace arm { void splite_resize_out(const lite::Tensor* din, - std::vector* dout, int axis, int num, - const std::vector& sections) { - for (auto out : *dout) delete out; - dout->clear(); + const std::vector& dout, int axis, + int num, const std::vector& sections) { auto in_dims = din->dims(); - int outs_number; - if (num > 0) { - outs_number = num; - } else { - outs_number = sections.size(); - } - for (int i = 0; i < outs_number; i++) { - dout->push_back(new lite::Tensor); - } + int outs_number = dout.size(); std::vector outs_dims; outs_dims.reserve(outs_number); @@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din, } for (int j = 0; j < outs_dims.size(); ++j) { - (*dout)[j]->Resize(outs_dims[j]); + dout[j]->Resize(outs_dims[j]); } } @@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) { } int input_offset = 0; - for (auto out : *dout) { + for (auto out : dout) { auto out_dim = out->dims(); std::vector out_strides(out_dim.size()); out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; @@ -128,16 +118,31 @@ TEST(split_arm, compute) { for (int i = 0; i < x.dims().production(); i++) { x_data[i] = i; } - splite_resize_out(&x, &output, axis, num, sections); - splite_resize_out(&x, &output_ref, axis, num, sections); + for (auto out : output) delete out; + for (auto out : output_ref) delete out; + output.clear(); + output_ref.clear(); + + int outs_number; + if (num > 0) { + outs_number = num; + } else { + outs_number = sections.size(); + } + for (int i = 0; i < outs_number; i++) { + output.push_back(new lite::Tensor); + output_ref.push_back(new lite::Tensor); + } + splite_resize_out(&x, output, axis, num, sections); + splite_resize_out(&x, output_ref, axis, num, sections); param.x = &x; param.axis = axis; param.num = num; - param.sections = §ions; - param.output = &output; + param.sections = sections; + param.output = output; split.SetParam(param); split.Run(); - param.output = &output_ref; + param.output = output_ref; split_compute_ref(param); for (int i = 0; i < output.size(); i++) { float* output_data = output[i]->mutable_data(); diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 5bc925f35a8b95c23033eafaff82fdcf44ac3e9a..a81e590b5a34db70c0b90759b4bd18b7d8d27cad 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -178,10 +178,10 @@ struct DropoutParam { // For Split op struct SplitParam { lite::Tensor* x{}; - std::vector* output{}; + std::vector output{}; int axis{-1}; int num{0}; - std::vector* sections; + std::vector sections; }; /// ----------------------- element wise operators ---------------------- diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc index c788e9cf9546a8c058398d71fde7aa4295fe8fbc..9b4b7662ab7ba7228ee215bf051601150e2b6bb7 100644 --- a/paddle/fluid/lite/operators/split_op.cc +++ b/paddle/fluid/lite/operators/split_op.cc @@ -21,7 +21,7 @@ namespace operators { bool SplitOp::CheckShape() const { CHECK_OR_FALSE(param_.x); - CHECK_OR_FALSE(param_.output); + CHECK_GT_OR_FALSE(param_.output.size(), 1UL); auto x_dims = param_.x->dims(); auto x_rank = x_dims.size(); CHECK_OR_FALSE(param_.axis >= -static_cast(x_rank) && @@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const { bool SplitOp::InferShape() const { const auto &outs = param_.output; - auto in_dims = param_.x.dims(); + auto in_dims = param_.x->dims(); int axis = param_.axis; int num = param_.num; const auto §ions = param_.sections; @@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.sections = opdesc.GetAttr>("sections"); param_.x = const_cast( &scope->FindVar(opdesc.Input("X").front())->Get()); - auto outs = op_desc.Output("Out"); + auto outs = opdesc.Output("Out"); for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } @@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { } // namespace lite } // namespace paddle -REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); +REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp); diff --git a/paddle/fluid/lite/operators/split_op.h b/paddle/fluid/lite/operators/split_op.h index 177c44171e6e67214f820f04e801be6c01df01cc..20dc4b1028c27f4efab558694285e44d46182ef8 100644 --- a/paddle/fluid/lite/operators/split_op.h +++ b/paddle/fluid/lite/operators/split_op.h @@ -23,7 +23,7 @@ namespace paddle { namespace lite { namespace operators { -class SoftmaxOp : public OpLite { +class SplitOp : public OpLite { public: SplitOp() {} explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}