Commit 9f6f5166 authored by xiebaiyuan

split op impl

Parent: fc280daf
@@ -21,8 +21,64 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
// Strided numel memory copy from src to dst by the specified axis
//
// For example, for a tensor with dims [4, 20, 100], the strided numel is
// [8000, 2000, 100]
//
// NOTE: The src and dst tensors must have the same shape
// except along the specified axis.
template <typename T>
inline void StridedNumelCopyWithAxis(int64_t axis, T* dst,
const framework::DDim& dst_stride_numel,
const T* src,
const framework::DDim& src_stride_numel,
int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
PADDLE_MOBILE_ENFORCE(src_stride_numel.size() == dst_stride_numel.size(),
"src and dst tensor should have the same dims size.");
  // Check that every dimension other than `axis` matches between src and dst.
  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
    if (i < axis) {
      PADDLE_MOBILE_ENFORCE(src_stride_numel[i] / src_stride_numel[axis] ==
                                dst_stride_numel[i] / dst_stride_numel[axis],
                            "src and dst should have the same elements "
                            "except the specified axis.");
    } else if (i == axis) {
      continue;
    } else {
      PADDLE_MOBILE_ENFORCE(src_stride_numel[i] == dst_stride_numel[i],
                            "src and dst should have the same elements "
                            "except the specified axis.");
    }
  }
for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
}
}
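For reference, here is a minimal standalone sketch of the stride-numel convention this helper relies on, assuming row-major layout; `StrideNumel` is an illustrative stand-in for `framework::stride_numel`, not code from this commit:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// stride_numel[i] = number of elements in the sub-tensor spanned by
// dims[i..n), i.e. a suffix product of the dimensions.
std::vector<int64_t> StrideNumel(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size());
  int64_t acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    strides[i] = acc;
  }
  return strides;
}

int main() {
  // Matches the example in the comment above: [4, 20, 100] -> [8000, 2000, 100].
  for (int64_t s : StrideNumel({4, 20, 100})) std::cout << s << ' ';
  std::cout << '\n';
}
```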
template <typename P>
void SplitCompute(const SplitParam<CPU>& param) {
auto* in = param.InputX();
auto outs = param.Outs();
auto in_stride = framework::stride_numel(in->dims());
int64_t axis = param.Axis();
size_t input_offset = 0;
for (auto& out : outs) {
out->mutable_data<float>();
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<float>(axis, out->data<float>(), out_stride,
in->data<float>() + input_offset, in_stride,
out_stride[axis]);
input_offset += out_stride[axis];
}
}
}  // namespace operators
}  // namespace paddle_mobile
...
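To see what the kernel's copy loop does, here is a self-contained toy version that splits a row-major [2, 4] buffer along axis 1 into two [2, 2] outputs using the same before/after/offset arithmetic; it is a sketch using plain `memcpy`, not the paddle_mobile `memory::Copy` path:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

int main() {
  const float in[2 * 4] = {0, 1, 2, 3, 4, 5, 6, 7};  // dims [2, 4]
  float out0[2 * 2], out1[2 * 2];                    // dims [2, 2] each
  float* outs[2] = {out0, out1};

  // stride numel: input [8, 4], outputs [4, 2].
  const int64_t before = 8 / 4;  // slices preceding the split axis
  const int64_t src_after = 4;   // input elements per slice
  const int64_t dst_after = 2;   // output elements per slice
  int64_t input_offset = 0;
  for (float* dst : outs) {
    for (int64_t i = 0; i < before; ++i) {
      std::memcpy(dst + i * dst_after, in + input_offset + i * src_after,
                  sizeof(float) * dst_after);
    }
    input_offset += dst_after;   // advance past this output's columns
  }
  // out0 == {0, 1, 4, 5}, out1 == {2, 3, 6, 7}
  for (float v : out1) std::cout << v << ' ';  // prints: 2 3 6 7
}
```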
@@ -245,6 +245,12 @@ class OpParam {
    return GetVarValue<T>("Out", outputs, scope);
  }
template <typename T>
static vector<T *> OutMultiFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetMultiVarValue<T>("Out", outputs, scope);
}
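For context, a minimal sketch of what a multi-value accessor like `OutMultiFrom` does: it resolves every variable name bound to the "Out" slot rather than just the first one. The `Scope`/`Tensor` types below are simplified stand-ins, not paddle_mobile's real classes:

```cpp
#include <map>
#include <string>
#include <vector>

struct Tensor {};  // stand-in for the framework tensor type
struct Scope {     // stand-in for the framework scope
  std::map<std::string, Tensor> vars;
  Tensor* Find(const std::string& name) { return &vars[name]; }
};
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

// One tensor per variable name bound to `key` (e.g. "Out" for split,
// which has a variable number of outputs).
std::vector<Tensor*> GetMultiVarValue(const std::string& key,
                                      const VariableNameMap& names,
                                      Scope& scope) {
  std::vector<Tensor*> result;
  for (const auto& name : names.at(key)) result.push_back(scope.Find(name));
  return result;
}
```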
  template <typename T>
  static T *OutputYFrom(const VariableNameMap &outputs, const Scope &scope) {
    return GetVarValue<T>("Y", outputs, scope);
@@ -2291,16 +2297,29 @@ class SplitParam : public OpParam {
  SplitParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
             const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<GType>(inputs, scope);
    outs_ = OutMultiFrom<GType>(outputs, scope);
    axis = GetAttr<int>("axis", attrs);
num = GetAttr<int>("num", attrs);
sections = GetAttr<std::vector<int>>("sections", attrs);
// for (int i = 0; i < outs_.size(); ++i) {
// out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable());
// }
  }
  const RType *InputX() const { return input_x_; }
  std::vector<GType *> Outs() const { return outs_; }
int Axis() const { return axis; }
int Num() const { return num; }
std::vector<int> Sections() const { return sections; }
// std::vector<GType> OutTs() const { return out_ts_; }
 private:
  RType *input_x_;
  std::vector<GType *> outs_;
  int axis;
int num;
std::vector<int> sections;
// std::vector<GType> out_ts_;
};
#endif
...
@@ -18,9 +18,62 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {

template <typename DeviceType, typename T>
void SplitOp<DeviceType, T>::InferShape() const {
  PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr,
                        "Input(X) of SplitOp should not be null.");
const auto &outs = this->param_.Outs();
PADDLE_MOBILE_ENFORCE(outs.size() >= 1UL,
"Outputs(Out) of SplitOp should not be empty.");
auto in_dims = this->param_.InputX()->dims();
size_t axis = static_cast<size_t>(this->param_.Axis());
size_t num = static_cast<size_t>(this->param_.Num());
const auto &sections = this->param_.Sections();
const size_t outs_number = outs.size();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
int64_t in_axis_dim = in_dims[axis];
PADDLE_MOBILE_ENFORCE(in_axis_dim % num == 0,
"tensor split does not result"
" in an equal division");
size_t out_axis_dim = in_axis_dim / num;
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
    PADDLE_MOBILE_ENFORCE(sections.size() == outs_number,
                          "tensor split sections size "
                          "should be equal to output size.");
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
}
}
  PADDLE_MOBILE_ENFORCE(outs_dims.size() == outs.size(),
                        "outs_dims.size() must equal the number of outputs.");
  for (size_t j = 0; j < outs_dims.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
  // TODO: LoD impl
  // if (axis != 0) {
  //   // Only pass LoD when not splitting along the first dim.
  //   for (size_t i = 0; i < outs_number; ++i) {
  //     ctx->ShareLoD("X", "Out", 0, i);
  //   }
  // }
}
}  // namespace operators
...
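A worked example of the two branches of `InferShape` above, with illustrative numbers (not from the commit): for input dims [4, 20, 100] and axis = 1, `num = 4` yields four outputs of [4, 5, 100], while `sections = {4, 6, 10}` yields [4, 4, 100], [4, 6, 100], and [4, 10, 100]; in both cases the sizes along the axis sum to 20:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> in_dims = {4, 20, 100};
  const int axis = 1;

  // num mode: equal division into 4 parts along `axis`.
  const int64_t num = 4;
  assert(in_dims[axis] % num == 0);  // the "equal division" check above
  std::cout << "num mode: 4 x [4, " << in_dims[axis] / num << ", 100]\n";

  // sections mode: explicit sizes along `axis`.
  for (int s : {4, 6, 10})
    std::cout << "sections mode: [4, " << s << ", 100]\n";
}
```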
@@ -44,7 +44,6 @@ class SplitOp : public framework::OperatorWithKernel<
          operators::SplitKernel<DeviceType, T>>::OperatorWithKernel;
  void InferShape() const override;
};
}  // namespace operators
}  // namespace paddle_mobile
...