提交 f8fcc594 编写于 作者: Z zhupengyang

fix split unit test

上级 924fed01
......@@ -27,7 +27,7 @@ void splite_resize_out(const lite::Tensor* din,
const std::vector<lite::Tensor*>& dout, int axis,
int num, const std::vector<int>& sections) {
auto in_dims = din->dims();
int outs_number;
int outs_number = dout.size();
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
......@@ -118,11 +118,11 @@ TEST(split_arm, compute) {
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
for (auto out : output) delete out;
for (auto out : output_ref) delete out;
output.clear();
output_ref.clear();
int outs_number;
if (num > 0) {
outs_number = num;
......@@ -133,7 +133,6 @@ TEST(split_arm, compute) {
output.push_back(new lite::Tensor);
output_ref.push_back(new lite::Tensor);
}
splite_resize_out(&x, output, axis, num, sections);
splite_resize_out(&x, output_ref, axis, num, sections);
param.x = &x;
......
......@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
bool SplitOp::InferShape() const {
const auto &outs = param_.output;
auto in_dims = param_.x.dims();
auto in_dims = param_.x->dims();
int axis = param_.axis;
int num = param_.num;
const auto &sections = param_.sections;
......@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out").front;
auto outs = opdesc.Output("Out");
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
......
......@@ -23,7 +23,7 @@ namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
class SplitOp : public OpLite {
public:
SplitOp() {}
explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册