提交 09b35192 编写于 作者: Z zhupengyang

Merge branch 'zhupy/fix-split-op' into 'incubate/lite'

fix split op

See merge request inference/paddlelite!22
...@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) ...@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
# cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
set(ops_lite set(ops_lite
conv_op_lite conv_op_lite
...@@ -41,7 +41,7 @@ set(ops_lite ...@@ -41,7 +41,7 @@ set(ops_lite
activation_ops_lite activation_ops_lite
dropout_op_lite dropout_op_lite
concat_op_lite concat_op_lite
#split_op_lite split_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......
...@@ -48,7 +48,7 @@ bool SplitOp::InferShape() const { ...@@ -48,7 +48,7 @@ bool SplitOp::InferShape() const {
outs_dims.push_back(dim); outs_dims.push_back(dim);
} }
} else if (sections.size() > 0) { } else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) { for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims; auto dim = in_dims;
dim[axis] = sections[i]; dim[axis] = sections[i];
outs_dims.push_back(dim); outs_dims.push_back(dim);
...@@ -66,9 +66,9 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -66,9 +66,9 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.axis = opdesc.GetAttr<int>("axis"); param_.axis = opdesc.GetAttr<int>("axis");
param_.num = opdesc.GetAttr<int>("num"); param_.num = opdesc.GetAttr<int>("num");
param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>( auto input = opdesc.Input("Input").front();
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = opdesc.Output("Out"); auto outs = opdesc.Output("Out");
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
for (auto var : outs) { for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册