提交 924fed01 编写于 作者: Z zhupy

fix split unit test

上级 7e714e2b
...@@ -24,20 +24,10 @@ namespace kernels { ...@@ -24,20 +24,10 @@ namespace kernels {
namespace arm { namespace arm {
void splite_resize_out(const lite::Tensor* din, void splite_resize_out(const lite::Tensor* din,
std::vector<lite::Tensor*>* dout, int axis, int num, const std::vector<lite::Tensor*>& dout, int axis,
const std::vector<int>& sections) { int num, const std::vector<int>& sections) {
for (auto out : *dout) delete out;
dout->clear();
auto in_dims = din->dims(); auto in_dims = din->dims();
int outs_number; int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
dout->push_back(new lite::Tensor);
}
std::vector<lite::DDimLite> outs_dims; std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number); outs_dims.reserve(outs_number);
...@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din, ...@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din,
} }
for (int j = 0; j < outs_dims.size(); ++j) { for (int j = 0; j < outs_dims.size(); ++j) {
(*dout)[j]->Resize(outs_dims[j]); dout[j]->Resize(outs_dims[j]);
} }
} }
...@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) { ...@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) {
} }
int input_offset = 0; int input_offset = 0;
for (auto out : *dout) { for (auto out : dout) {
auto out_dim = out->dims(); auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size()); std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
...@@ -128,16 +118,32 @@ TEST(split_arm, compute) { ...@@ -128,16 +118,32 @@ TEST(split_arm, compute) {
for (int i = 0; i < x.dims().production(); i++) { for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i; x_data[i] = i;
} }
splite_resize_out(&x, &output, axis, num, sections);
splite_resize_out(&x, &output_ref, axis, num, sections); for (auto out : output) delete out;
for (auto out : output_ref) delete out;
output.clear();
output_ref.clear();
int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
output.push_back(new lite::Tensor);
output_ref.push_back(new lite::Tensor);
}
splite_resize_out(&x, output, axis, num, sections);
splite_resize_out(&x, output_ref, axis, num, sections);
param.x = &x; param.x = &x;
param.axis = axis; param.axis = axis;
param.num = num; param.num = num;
param.sections = &sections; param.sections = sections;
param.output = &output; param.output = output;
split.SetParam(param); split.SetParam(param);
split.Run(); split.Run();
param.output = &output_ref; param.output = output_ref;
split_compute_ref<float>(param); split_compute_ref<float>(param);
for (int i = 0; i < output.size(); i++) { for (int i = 0; i < output.size(); i++) {
float* output_data = output[i]->mutable_data<float>(); float* output_data = output[i]->mutable_data<float>();
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
bool SplitOp::CheckShape() const { bool SplitOp::CheckShape() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_GT_OR_FALSE(param_.output.size(), 1UL);
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) && CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>( param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out"); auto outs = op_desc.Output("Out").front;
for (auto var : outs) { for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
...@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册