提交 56a81579 编写于 作者: L lichao18

Optimize reshape op

上级 0894c8e9
......@@ -23,16 +23,12 @@ template <DeviceType D, class T>
class ReshapeOp : public Operation {
public:
explicit ReshapeOp(OpConstructContext *context)
: Operation(context) {}
: Operation(context),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
const std::vector<index_t> &input_shape = input->shape();
int axis = Operation::GetOptionalArg<int>("reshape_axis", 0);
int num_axes = Operation::GetOptionalArg<int>("num_axes", -1);
MACE_CHECK(axis == 0 && num_axes == -1,
"Only support axis = 0 and num_axes = -1");
const Tensor *shape = this->Input(SHAPE);
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
Tensor::MappingGuard shape_guard(shape);
......@@ -40,20 +36,16 @@ class ReshapeOp : public Operation {
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
std::vector<index_t> out_shape(num_dims);
index_t n = 0;
for (int i = 0; i < num_dims; ++i) {
if (shape_data[i] == -1) {
MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
unknown_idx = i;
out_shape.push_back(1);
} else if (shape_data[i] == 0) {
MACE_CHECK(shape_data[i] == 0, "Shape should be 0");
out_shape.push_back(input_shape[i]);
product *= input_shape[i];
out_shape[i] = 1;
} else {
MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ",
MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
shape_data[i]);
if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(),
......@@ -62,7 +54,7 @@ class ReshapeOp : public Operation {
} else {
n = shape_data[i];
}
out_shape.push_back(n);
out_shape[i] = n;
product *= n;
}
}
......@@ -77,14 +69,13 @@ class ReshapeOp : public Operation {
}
Tensor *output = this->Output(OUTPUT);
// NHWC -> NCHW
auto has_df = Operation::GetOptionalArg<int>(
"has_data_format", 0);
if (has_df && D == DeviceType::CPU
if (has_df_ && D == DeviceType::CPU
&& out_shape.size() == 4 && shape->is_weight()) {
std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>(
std::vector<index_t> trans_shape = TransposeShape<index_t, index_t>(
out_shape, dst_dims);
out_shape = out_shape_gpu;
out_shape = trans_shape;
}
output->ReuseTensorBuffer(*input);
......@@ -93,6 +84,9 @@ class ReshapeOp : public Operation {
return MaceStatus::MACE_SUCCESS;
}
private:
bool has_df_;
private:
MACE_OP_INPUT_TAGS(INPUT, SHAPE);
MACE_OP_OUTPUT_TAGS(OUTPUT);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册