提交 aed3fc8e 编写于 作者: Z zhupy 提交者: zhupengyang

fix split op

test=develop
上级 f116bb9e
......@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
# cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
set(ops_lite
conv_op_lite
......@@ -41,7 +41,7 @@ set(ops_lite
activation_ops_lite
dropout_op_lite
concat_op_lite
#split_op_lite
split_op_lite
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......
......@@ -48,7 +48,7 @@ bool SplitOp::InferShape() const {
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) {
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
......@@ -66,9 +66,9 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.axis = opdesc.GetAttr<int>("axis");
param_.num = opdesc.GetAttr<int>("num");
param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto input = opdesc.Input("Input").front();
auto outs = opdesc.Output("Out");
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册