提交 fc280daf 编写于 作者: xiebaiyuan's avatar xiebaiyuan

flatten op impl

上级 4c191812
......@@ -18,10 +18,48 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
static std::vector<int32_t> GetOutputShape(const int axis,
const framework::DDim &in_dims) {
int64_t outer = 1, inner = 1;
for (int i = 0; i < in_dims.size(); ++i) {
if (i < axis) {
outer *= in_dims[i];
} else {
inner *= in_dims[i];
}
}
std::vector<int32_t> out_shape(2);
out_shape[0] = static_cast<int>(outer);
out_shape[1] = static_cast<int>(inner);
return out_shape;
}
template <typename DeviceType, typename T>
void FlattenOp<DeviceType, T>::InferShape() const {
// todo check
this->param_.Out()->Resize(this->param_.InputX()->dims());
PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr,
"Input (X) of Flatten op should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output (Output) of Flatten op should not be null.");
auto &axis = this->param_.Axis();
PADDLE_MOBILE_ENFORCE(axis >= 0,
"The axis should be greater than or equal to 0.");
auto &in_dims = this->param_.InputX()->dims();
// const auto &in_dims = ctx->GetInputDim("X");
PADDLE_MOBILE_ENFORCE(
axis <= in_dims.size(),
"The axis should be less than or equal to input tensor's rank.");
const auto &out_dims = GetOutputShape(axis, in_dims);
this->param_.Out()->Resize(in_dims);
// todo supprot lodtensor
// if (in_dims[0] == out_dims[0]) {
// // Only pass LoD when the first dimension of output and Input(X)
// // are the same.
// ctx->ShareLoD("X", "Out");
// }
}
} // namespace operators
......
......@@ -22,7 +22,11 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void FlattenCompute(const FlattenParam<CPU>& param) {}
void FlattenCompute(const FlattenParam<CPU> &param) {
param.Out()->mutable_data<float>();
framework::TensorCopy(*param.InputX(), param.Out());
param.Out()->Resize(param.Out()->dims());
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -2268,13 +2268,16 @@ class FlattenParam : public OpParam {
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
axis = GetAttr<int>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
const int &Axis() const { return axis; }
private:
RType *input_x_;
RType *out_;
int axis;
};
#endif
......@@ -2289,6 +2292,7 @@ class SplitParam : public OpParam {
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
axis = GetAttr<int>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
......@@ -2296,6 +2300,7 @@ class SplitParam : public OpParam {
private:
RType *input_x_;
RType *out_;
int axis;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册