提交 f8fcc594 编写于 作者: Z zhupengyang

fix split unit test

上级 924fed01
...@@ -27,7 +27,7 @@ void splite_resize_out(const lite::Tensor* din, ...@@ -27,7 +27,7 @@ void splite_resize_out(const lite::Tensor* din,
const std::vector<lite::Tensor*>& dout, int axis, const std::vector<lite::Tensor*>& dout, int axis,
int num, const std::vector<int>& sections) { int num, const std::vector<int>& sections) {
auto in_dims = din->dims(); auto in_dims = din->dims();
int outs_number; int outs_number = dout.size();
std::vector<lite::DDimLite> outs_dims; std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number); outs_dims.reserve(outs_number);
...@@ -118,11 +118,11 @@ TEST(split_arm, compute) { ...@@ -118,11 +118,11 @@ TEST(split_arm, compute) {
for (int i = 0; i < x.dims().production(); i++) { for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i; x_data[i] = i;
} }
for (auto out : output) delete out; for (auto out : output) delete out;
for (auto out : output_ref) delete out; for (auto out : output_ref) delete out;
output.clear(); output.clear();
output_ref.clear(); output_ref.clear();
int outs_number; int outs_number;
if (num > 0) { if (num > 0) {
outs_number = num; outs_number = num;
...@@ -133,7 +133,6 @@ TEST(split_arm, compute) { ...@@ -133,7 +133,6 @@ TEST(split_arm, compute) {
output.push_back(new lite::Tensor); output.push_back(new lite::Tensor);
output_ref.push_back(new lite::Tensor); output_ref.push_back(new lite::Tensor);
} }
splite_resize_out(&x, output, axis, num, sections); splite_resize_out(&x, output, axis, num, sections);
splite_resize_out(&x, output_ref, axis, num, sections); splite_resize_out(&x, output_ref, axis, num, sections);
param.x = &x; param.x = &x;
......
...@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const { ...@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
bool SplitOp::InferShape() const { bool SplitOp::InferShape() const {
const auto &outs = param_.output; const auto &outs = param_.output;
auto in_dims = param_.x.dims(); auto in_dims = param_.x->dims();
int axis = param_.axis; int axis = param_.axis;
int num = param_.num; int num = param_.num;
const auto &sections = param_.sections; const auto &sections = param_.sections;
...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>( param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out").front; auto outs = opdesc.Output("Out");
for (auto var : outs) { for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class SoftmaxOp : public OpLite { class SplitOp : public OpLite {
public: public:
SplitOp() {} SplitOp() {}
explicit SplitOp(const std::string &op_type) : OpLite(op_type) {} explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册