提交 47409b72 编写于 作者: Z zhupengyang 提交者: Tensor Tang

fix split op and arm unit test

上级 519ef7f7
......@@ -52,10 +52,10 @@ void split_cpy<float>(const float* din, float* dout, int num) {
}
template <>
void split<float>(const float* din, std::vector<lite::Tensor*>* dout,
void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
const int axis, const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : *dout) {
for (auto out : dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
......
......@@ -26,7 +26,7 @@ template <typename T>
void split_cpy(const T* din, T* dout, int num);
template <typename T>
void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
const std::vector<int>& in_strides);
} // namespace math
......
......@@ -24,7 +24,7 @@ namespace arm {
void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>();
auto* dout = param.output;
auto& dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
......
......@@ -24,20 +24,10 @@ namespace kernels {
namespace arm {
void splite_resize_out(const lite::Tensor* din,
std::vector<lite::Tensor*>* dout, int axis, int num,
const std::vector<int>& sections) {
for (auto out : *dout) delete out;
dout->clear();
const std::vector<lite::Tensor*>& dout, int axis,
int num, const std::vector<int>& sections) {
auto in_dims = din->dims();
int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
dout->push_back(new lite::Tensor);
}
int outs_number = dout.size();
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
......@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din,
}
for (int j = 0; j < outs_dims.size(); ++j) {
(*dout)[j]->Resize(outs_dims[j]);
dout[j]->Resize(outs_dims[j]);
}
}
......@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) {
}
int input_offset = 0;
for (auto out : *dout) {
for (auto out : dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
......@@ -128,16 +118,31 @@ TEST(split_arm, compute) {
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
splite_resize_out(&x, &output, axis, num, sections);
splite_resize_out(&x, &output_ref, axis, num, sections);
for (auto out : output) delete out;
for (auto out : output_ref) delete out;
output.clear();
output_ref.clear();
int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
output.push_back(new lite::Tensor);
output_ref.push_back(new lite::Tensor);
}
splite_resize_out(&x, output, axis, num, sections);
splite_resize_out(&x, output_ref, axis, num, sections);
param.x = &x;
param.axis = axis;
param.num = num;
param.sections = &sections;
param.output = &output;
param.sections = sections;
param.output = output;
split.SetParam(param);
split.Run();
param.output = &output_ref;
param.output = output_ref;
split_compute_ref<float>(param);
for (int i = 0; i < output.size(); i++) {
float* output_data = output[i]->mutable_data<float>();
......
......@@ -178,10 +178,10 @@ struct DropoutParam {
// For Split op
struct SplitParam {
lite::Tensor* x{};
std::vector<lite::Tensor*>* output{};
std::vector<lite::Tensor*> output{};
int axis{-1};
int num{0};
std::vector<int>* sections;
std::vector<int> sections;
};
/// ----------------------- element wise operators ----------------------
......
......@@ -21,7 +21,7 @@ namespace operators {
bool SplitOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
CHECK_GT_OR_FALSE(param_.output.size(), 1UL);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
......@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
bool SplitOp::InferShape() const {
const auto &outs = param_.output;
auto in_dims = param_.x.dims();
auto in_dims = param_.x->dims();
int axis = param_.axis;
int num = param_.num;
const auto &sections = param_.sections;
......@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out");
auto outs = opdesc.Output("Out");
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
......@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
......@@ -23,7 +23,7 @@ namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
class SplitOp : public OpLite {
public:
SplitOp() {}
explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册