“680935cc2f6855e0ca36f1e6a7a0f94cc84cb5eb”上不存在“mobile/src/operators/gru_unit_op.h”
提交 cbd381b0 编写于 作者: 李寅

Merge branch 'master' into 'master'

Optimize reshape op

See merge request !1018
...@@ -23,16 +23,12 @@ template <DeviceType D, class T> ...@@ -23,16 +23,12 @@ template <DeviceType D, class T>
class ReshapeOp : public Operation { class ReshapeOp : public Operation {
public: public:
explicit ReshapeOp(OpConstructContext *context) explicit ReshapeOp(OpConstructContext *context)
: Operation(context) {} : Operation(context),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
int axis = Operation::GetOptionalArg<int>("reshape_axis", 0);
int num_axes = Operation::GetOptionalArg<int>("num_axes", -1);
MACE_CHECK(axis == 0 && num_axes == -1,
"Only support axis = 0 and num_axes = -1");
const Tensor *shape = this->Input(SHAPE); const Tensor *shape = this->Input(SHAPE);
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0); const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
Tensor::MappingGuard shape_guard(shape); Tensor::MappingGuard shape_guard(shape);
...@@ -40,20 +36,16 @@ class ReshapeOp : public Operation { ...@@ -40,20 +36,16 @@ class ReshapeOp : public Operation {
int unknown_idx = -1; int unknown_idx = -1;
index_t product = 1; index_t product = 1;
std::vector<index_t> out_shape; std::vector<index_t> out_shape(num_dims);
index_t n = 0; index_t n = 0;
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
if (shape_data[i] == -1) { if (shape_data[i] == -1) {
MACE_CHECK(unknown_idx == -1, "Only one input size may be -1"); MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
unknown_idx = i; unknown_idx = i;
out_shape.push_back(1); out_shape[i] = 1;
} else if (shape_data[i] == 0) {
MACE_CHECK(shape_data[i] == 0, "Shape should be 0");
out_shape.push_back(input_shape[i]);
product *= input_shape[i];
} else { } else {
MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ", MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
shape_data[i]); shape_data[i]);
if (shape_data[i] == 0) { if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(), MACE_CHECK(i < input->dim_size(),
...@@ -62,7 +54,7 @@ class ReshapeOp : public Operation { ...@@ -62,7 +54,7 @@ class ReshapeOp : public Operation {
} else { } else {
n = shape_data[i]; n = shape_data[i];
} }
out_shape.push_back(n); out_shape[i] = n;
product *= n; product *= n;
} }
} }
...@@ -77,14 +69,13 @@ class ReshapeOp : public Operation { ...@@ -77,14 +69,13 @@ class ReshapeOp : public Operation {
} }
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
// NHWC -> NCHW // NHWC -> NCHW
auto has_df = Operation::GetOptionalArg<int>(
"has_data_format", 0); if (has_df_ && D == DeviceType::CPU
if (has_df && D == DeviceType::CPU
&& out_shape.size() == 4 && shape->is_weight()) { && out_shape.size() == 4 && shape->is_weight()) {
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>( std::vector<index_t> trans_shape = TransposeShape<index_t, index_t>(
out_shape, dst_dims); out_shape, dst_dims);
out_shape = out_shape_gpu; out_shape = trans_shape;
} }
output->ReuseTensorBuffer(*input); output->ReuseTensorBuffer(*input);
...@@ -93,6 +84,9 @@ class ReshapeOp : public Operation { ...@@ -93,6 +84,9 @@ class ReshapeOp : public Operation {
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private:
bool has_df_;
private: private:
MACE_OP_INPUT_TAGS(INPUT, SHAPE); MACE_OP_INPUT_TAGS(INPUT, SHAPE);
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册