提交 7e714e2b 编写于 作者: Z zhupy

fix split arm kernel

上级 02029900
......@@ -52,10 +52,10 @@ void split_cpy<float>(const float* din, float* dout, int num) {
}
template <>
void split<float>(const float* din, std::vector<lite::Tensor*>* dout,
void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
const int axis, const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : *dout) {
for (auto out : dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
......
......@@ -26,7 +26,7 @@ template <typename T>
void split_cpy(const T* din, T* dout, int num);
template <typename T>
void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
const std::vector<int>& in_strides);
} // namespace math
......
......@@ -24,7 +24,7 @@ namespace arm {
void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>();
auto* dout = param.output;
auto& dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
......
......@@ -177,10 +177,10 @@ struct DropoutParam {
// For Split op
struct SplitParam {
lite::Tensor* x{};
std::vector<lite::Tensor*>* output{};
std::vector<lite::Tensor*> output{};
int axis{-1};
int num{0};
std::vector<int>* sections;
std::vector<int> sections;
};
/// ----------------------- element wise operators ----------------------
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册