diff --git a/src/operators/kernel/central-arm-func/split_arm_func.h b/src/operators/kernel/central-arm-func/split_arm_func.h index 77c1c9ed55e042de91ed41ccbde1b3a01eebe212..24ab2f83a4f3be8b29cb9e33347d639c52f9eea1 100644 --- a/src/operators/kernel/central-arm-func/split_arm_func.h +++ b/src/operators/kernel/central-arm-func/split_arm_func.h @@ -21,8 +21,64 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { +// Strided numel memory copy from src to dst by the specified axis +// +// For example, for a tensor dims [4, 20, 100], the strided numel is +// [8000, 2000, 100] +// +// NOTE: The src and dst tensor should have the same elements +// except the specified axis. +template +inline void StridedNumelCopyWithAxis(int64_t axis, T* dst, + const framework::DDim& dst_stride_numel, + const T* src, + const framework::DDim& src_stride_numel, + int64_t size) { + int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; + int64_t src_after = src_stride_numel[axis]; + int64_t dst_after = dst_stride_numel[axis]; + + PADDLE_MOBILE_ENFORCE(src_stride_numel.size() == dst_stride_numel.size(), + "src and dst tensor should have the same dims size."); + + for (int64_t i = 0; i < axis; ++i) { + if (i < axis) { + PADDLE_MOBILE_ENFORCE(src_stride_numel[i] / src_stride_numel[axis] == + dst_stride_numel[i] / dst_stride_numel[axis], + "src and dst should have the same elements " + "except the specified axis."); + } else if (i == axis) { + continue; + } else { + PADDLE_MOBILE_ENFORCE(src_stride_numel[i] == dst_stride_numel[i], + "src and dst should have the same elements " + "except the specified axis."); + } + } + + for (int64_t i = 0; i < before; ++i) { + memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size); + } +} + template -void SplitCompute(const SplitParam& param) {} +void SplitCompute(const SplitParam& param) { + auto* in = param.InputX(); + auto outs = param.Outs(); + auto in_stride = framework::stride_numel(in->dims()); 
int64_t axis = param.Axis(); + + size_t input_offset = 0; + for (auto& out : outs) { + out->mutable_data(); + auto out_stride = framework::stride_numel(out->dims()); + + StridedNumelCopyWithAxis(axis, out->data(), out_stride, + in->data() + input_offset, in_stride, + out_stride[axis]); + input_offset += out_stride[axis]; + } +} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/op_param.h b/src/operators/op_param.h index b52160a5da311e60ec5f0472fe396335d9c25be3..614f1536a12434945e23b71aa9724be638732370 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -245,6 +245,12 @@ class OpParam { return GetVarValue("Out", outputs, scope); } + template + static vector OutMultiFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetMultiVarValue("Out", outputs, scope); + } + template static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) { return GetVarValue("Y", outputs, scope); @@ -2291,16 +2297,29 @@ class SplitParam : public OpParam { SplitParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + outs_ = OutMultiFrom(outputs, scope); axis = GetAttr("axis", attrs); + num = GetAttr("num", attrs); + sections = GetAttr>("sections", attrs); + + // for (int i = 0; i < outs_.size(); ++i) { + // out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable()); + // } } const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + std::vector Outs() const { return outs_; } + int Axis() const { return axis; } + int Num() const { return num; } + std::vector Sections() const { return sections; } + // std::vector OutTs() const { return out_ts_; } private: RType *input_x_; - RType *out_; + std::vector outs_; int axis; + int num; + std::vector sections; + // std::vector out_ts_; }; #endif diff --git a/src/operators/split_op.cpp 
b/src/operators/split_op.cpp index df4c2b276cfd90fcf6fe3d29d4f80ae667ee17e1..4f33122976beb214f588f8647637166a6c4e84cd 100644 --- a/src/operators/split_op.cpp +++ b/src/operators/split_op.cpp @@ -18,9 +18,62 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { + template void SplitOp::InferShape() const { - this->param_.Out()->Resize(this->param_.InputX()->dims()); + PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr, + "Input(X) of SplitOp should not be null."); + // std::string str; + // str.size() + const auto &outs = this->param_.Outs(); + PADDLE_MOBILE_ENFORCE(outs.size() >= 1UL, + "Outputs(Out) of SplitOp should not be empty."); + + auto in_dims = this->param_.InputX()->dims(); + size_t axis = static_cast(this->param_.Axis()); + size_t num = static_cast(this->param_.Num()); + + const auto &sections = this->param_.Sections(); + + const size_t outs_number = outs.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); + + if (num > 0) { + int64_t in_axis_dim = in_dims[axis]; + PADDLE_MOBILE_ENFORCE(in_axis_dim % num == 0, + "tensor split does not result" + " in an equal division"); + size_t out_axis_dim = in_axis_dim / num; + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + PADDLE_MOBILE_ENFORCE(sections.size() == outs_number, + "tensor split sections size" + "should be equal to output size."); + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = sections[i]; + outs_dims.push_back(dim); + } + } + + PADDLE_MOBILE_ENFORCE(outs_dims.size() == outs.size(), + "length==dims.size() must be true!"); + for (int j = 0; j < outs_dims.size(); ++j) { + outs[j]->Resize(outs_dims[j]); + } + + // todo lod impl + // if (axis != 0) { + // // Only pass LoD when not splitting along the first dim. 
+ // for (size_t i = 0; i < outs_number; ++i) { + // ctx->ShareLoD("X", "Out", 0, i); + // } + // } } } // namespace operators diff --git a/src/operators/split_op.h b/src/operators/split_op.h index 278b66d76cb5ed6cc02588602c64897260190bd1..f7d60b37441e77c5d47ac6040404535a841bcf8e 100644 --- a/src/operators/split_op.h +++ b/src/operators/split_op.h @@ -44,7 +44,6 @@ class SplitOp : public framework::OperatorWithKernel< operators::SplitKernel>::OperatorWithKernel; void InferShape() const override; }; - } // namespace operators } // namespace paddle_mobile