From aed3fc8e89b6505fd1277cf834008d11f358a2fc Mon Sep 17 00:00:00 2001 From: zhupy <1165938320@qq.com> Date: Mon, 17 Jun 2019 11:45:20 +0000 Subject: [PATCH] fix split op test=develop --- paddle/fluid/lite/operators/CMakeLists.txt | 4 ++-- paddle/fluid/lite/operators/split_op.cc | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 536fcb75e..1a7ac5bd5 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) -# cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) +cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) set(ops_lite conv_op_lite @@ -41,7 +41,7 @@ set(ops_lite activation_ops_lite dropout_op_lite concat_op_lite - #split_op_lite + split_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc index 587682763..1f220819d 100644 --- a/paddle/fluid/lite/operators/split_op.cc +++ b/paddle/fluid/lite/operators/split_op.cc @@ -48,7 +48,7 @@ bool SplitOp::InferShape() const { outs_dims.push_back(dim); } } else if (sections.size() > 0) { - for (size_t i = 0; i < outs_number; ++i) { + for (int i = 0; i < outs_number; ++i) { auto dim = in_dims; dim[axis] = sections[i]; outs_dims.push_back(dim); @@ -66,9 +66,9 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.axis = opdesc.GetAttr("axis"); param_.num = opdesc.GetAttr("num"); param_.sections = opdesc.GetAttr>("sections"); - param_.x = const_cast( - &scope->FindVar(opdesc.Input("X").front())->Get()); + auto input = opdesc.Input("Input").front(); auto outs = opdesc.Output("Out"); + param_.x = scope->FindVar(input)->GetMutable(); for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } -- GitLab