提交 5f0d7166 编写于 作者: T Tensor Tang

Merge branch 'develop-split' into 'develop'

fix split op and arm unit test

See merge request paddlepaddle/paddlelitearmbackend!15
...@@ -52,10 +52,10 @@ void split_cpy<float>(const float* din, float* dout, int num) { ...@@ -52,10 +52,10 @@ void split_cpy<float>(const float* din, float* dout, int num) {
} }
template <> template <>
void split<float>(const float* din, std::vector<lite::Tensor*>* dout, void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
const int axis, const std::vector<int>& in_strides) { const int axis, const std::vector<int>& in_strides) {
int input_offset = 0; int input_offset = 0;
for (auto out : *dout) { for (auto out : dout) {
auto out_dim = out->dims(); auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size()); std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
......
...@@ -26,7 +26,7 @@ template <typename T> ...@@ -26,7 +26,7 @@ template <typename T>
void split_cpy(const T* din, T* dout, int num); void split_cpy(const T* din, T* dout, int num);
template <typename T> template <typename T>
void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis, void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
const std::vector<int>& in_strides); const std::vector<int>& in_strides);
} // namespace math } // namespace math
......
...@@ -24,7 +24,7 @@ namespace arm { ...@@ -24,7 +24,7 @@ namespace arm {
void SplitCompute::Run() { void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>(); auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>(); const float* din = param.x->data<float>();
auto* dout = param.output; auto& dout = param.output;
auto in_dim = param.x->dims(); auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size()); std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1]; in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
......
...@@ -24,20 +24,10 @@ namespace kernels { ...@@ -24,20 +24,10 @@ namespace kernels {
namespace arm { namespace arm {
void splite_resize_out(const lite::Tensor* din, void splite_resize_out(const lite::Tensor* din,
std::vector<lite::Tensor*>* dout, int axis, int num, const std::vector<lite::Tensor*>& dout, int axis,
const std::vector<int>& sections) { int num, const std::vector<int>& sections) {
for (auto out : *dout) delete out;
dout->clear();
auto in_dims = din->dims(); auto in_dims = din->dims();
int outs_number; int outs_number = dout.size();
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
dout->push_back(new lite::Tensor);
}
std::vector<lite::DDimLite> outs_dims; std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number); outs_dims.reserve(outs_number);
...@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din, ...@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din,
} }
for (int j = 0; j < outs_dims.size(); ++j) { for (int j = 0; j < outs_dims.size(); ++j) {
(*dout)[j]->Resize(outs_dims[j]); dout[j]->Resize(outs_dims[j]);
} }
} }
...@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) { ...@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) {
} }
int input_offset = 0; int input_offset = 0;
for (auto out : *dout) { for (auto out : dout) {
auto out_dim = out->dims(); auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size()); std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
...@@ -128,16 +118,31 @@ TEST(split_arm, compute) { ...@@ -128,16 +118,31 @@ TEST(split_arm, compute) {
for (int i = 0; i < x.dims().production(); i++) { for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i; x_data[i] = i;
} }
splite_resize_out(&x, &output, axis, num, sections); for (auto out : output) delete out;
splite_resize_out(&x, &output_ref, axis, num, sections); for (auto out : output_ref) delete out;
output.clear();
output_ref.clear();
int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
output.push_back(new lite::Tensor);
output_ref.push_back(new lite::Tensor);
}
splite_resize_out(&x, output, axis, num, sections);
splite_resize_out(&x, output_ref, axis, num, sections);
param.x = &x; param.x = &x;
param.axis = axis; param.axis = axis;
param.num = num; param.num = num;
param.sections = &sections; param.sections = sections;
param.output = &output; param.output = output;
split.SetParam(param); split.SetParam(param);
split.Run(); split.Run();
param.output = &output_ref; param.output = output_ref;
split_compute_ref<float>(param); split_compute_ref<float>(param);
for (int i = 0; i < output.size(); i++) { for (int i = 0; i < output.size(); i++) {
float* output_data = output[i]->mutable_data<float>(); float* output_data = output[i]->mutable_data<float>();
......
...@@ -178,10 +178,10 @@ struct DropoutParam { ...@@ -178,10 +178,10 @@ struct DropoutParam {
// For Split op // For Split op
struct SplitParam { struct SplitParam {
lite::Tensor* x{}; lite::Tensor* x{};
std::vector<lite::Tensor*>* output{}; std::vector<lite::Tensor*> output{};
int axis{-1}; int axis{-1};
int num{0}; int num{0};
std::vector<int>* sections; std::vector<int> sections;
}; };
/// ----------------------- element wise operators ---------------------- /// ----------------------- element wise operators ----------------------
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
bool SplitOp::CheckShape() const { bool SplitOp::CheckShape() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_GT_OR_FALSE(param_.output.size(), 1UL);
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) && CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
...@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const { ...@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
bool SplitOp::InferShape() const { bool SplitOp::InferShape() const {
const auto &outs = param_.output; const auto &outs = param_.output;
auto in_dims = param_.x.dims(); auto in_dims = param_.x->dims();
int axis = param_.axis; int axis = param_.axis;
int num = param_.num; int num = param_.num;
const auto &sections = param_.sections; const auto &sections = param_.sections;
...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>( param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out"); auto outs = opdesc.Output("Out");
for (auto var : outs) { for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
...@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class SoftmaxOp : public OpLite { class SplitOp : public OpLite {
public: public:
SplitOp() {} SplitOp() {}
explicit SplitOp(const std::string &op_type) : OpLite(op_type) {} explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册