From febfd7d602f569cf8abe408a50e10b8c12bff065 Mon Sep 17 00:00:00 2001
From: huzhiqiang <912790387@qq.com>
Date: Fri, 6 Sep 2019 17:08:10 +0800
Subject: [PATCH] modify reshape2 OP test=dvelop (#1963)

modify reshape2 OP to add shape_tensor input
---
 lite/kernels/host/reshape_compute.cc |  6 +++
 lite/operators/reshape_op.cc         | 55 +++++++++++++++++++++-------
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc
index a5934999cd..72aa02782b 100644
--- a/lite/kernels/host/reshape_compute.cc
+++ b/lite/kernels/host/reshape_compute.cc
@@ -66,6 +66,9 @@
     .BindInput("X",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
@@ -86,6 +89,9 @@
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(
                     TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc
index 0e7059d66d..2aa8217e92 100644
--- a/lite/operators/reshape_op.cc
+++ b/lite/operators/reshape_op.cc
@@ -14,6 +14,7 @@
 
 #include "lite/operators/reshape_op.h"
 #include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
 
 namespace paddle {
 namespace lite {
@@ -43,24 +44,52 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
   param_.output = output_var->GetMutable<lite::Tensor>();
   std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
-  if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") !=
-      input_arg_names.end()) {
-    if (opdesc.Input("Shape").size() > 0) {
-      auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
-      if (actual_shape_var != nullptr) {
-        param_.actual_shape = const_cast<lite::Tensor *>(
-            &(actual_shape_var->Get<lite::Tensor>()));
-      }
-    }
-  }
-  param_.shape = (opdesc.GetAttr<std::vector<int>>("shape"));
   if (opdesc.HasAttr("inplace")) {
     param_.inplace = opdesc.GetAttr<bool>("inplace");
   }
   CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
   CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
-  CHECK(!param_.shape.empty())
-      << "The shape information must be set by Attr(shape).";
+
+  if (opdesc.HasInput("ShapeTensor") &&
+      opdesc.Input("ShapeTensor").size() > 0) {
+    auto inputs = opdesc.Input("ShapeTensor");
+    for (auto var : inputs) {
+      lite::Tensor *datatensor =
+          scope->FindVar(var)->GetMutable<lite::Tensor>();
+      param_.shape.push_back(datatensor->mutable_data<int>()[0]);
+    }
+    const std::vector<int> shape_vector = param_.shape;
+    lite::Tensor *shape_tensor = new lite::Tensor;
+    shape_tensor->Resize(DDim({shape_vector.size()}));
+    int *data_shape = shape_tensor->mutable_data<int>();
+    for (int i = 0; i < shape_vector.size(); i++) {
+      data_shape[i] = shape_vector[i];
+    }
+    param_.actual_shape = shape_tensor;
+    return true;
+  } else if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
+    auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
+    if (actual_shape_var != nullptr) {
+      param_.actual_shape =
+          const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
+      int length = param_.actual_shape->dims().production();
+      int *shape_list =
+          actual_shape_var->GetMutable<lite::Tensor>()->mutable_data<int>();
+      param_.shape.assign(shape_list, shape_list + length);
+    }
+    return true;
+  } else {
+    param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
+    CHECK(!param_.shape.empty())
+        << "The shape information must be set by Attr(shape).";
+    const std::vector<int> shape_vector = param_.shape;
+    lite::Tensor *shape_tensor = new lite::Tensor;
+    shape_tensor->Resize(DDim({shape_vector.size()}));
+    int *data_shape = shape_tensor->mutable_data<int>();
+    for (int i = 0; i < shape_vector.size(); i++) {
+      data_shape[i] = shape_vector[i];
+    }
+    param_.actual_shape = shape_tensor;
+  }
   return true;
 }
-- 
GitLab
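For context (not part of the patch): the rewritten AttachImpl resolves the target shape from three sources in priority order -- the repeated ShapeTensor inputs first, then the single Shape input, then the static shape attribute. Below is a minimal, self-contained C++ sketch of that precedence logic. ResolveShape and its parameter names are illustrative only, and plain std::vector stands in for the int data of a lite::Tensor; this is not Paddle-Lite code.

#include <cassert>
#include <vector>

// Resolve the reshape target shape in the same priority order the patched
// AttachImpl uses:
//   1. ShapeTensor: a list of tensors, one dimension taken from each;
//   2. Shape: a single tensor holding the whole target shape;
//   3. the "shape" attribute, which must then be non-empty.
std::vector<int> ResolveShape(
    const std::vector<std::vector<int>> &shape_tensor_inputs,
    const std::vector<int> &shape_input,
    const std::vector<int> &shape_attr) {
  if (!shape_tensor_inputs.empty()) {
    std::vector<int> shape;
    for (const auto &t : shape_tensor_inputs) {
      shape.push_back(t.front());  // first element of each ShapeTensor input
    }
    return shape;
  }
  if (!shape_input.empty()) {
    return shape_input;  // the Shape tensor carries the full target shape
  }
  assert(!shape_attr.empty() && "shape must be set by Attr(shape)");
  return shape_attr;
}

int main() {
  // ShapeTensor takes precedence over both Shape and the attribute.
  assert((ResolveShape({{2}, {4}}, {8, 1}, {1, 8}) == std::vector<int>{2, 4}));
  // Without ShapeTensor, the Shape input is used.
  assert((ResolveShape({}, {8, 1}, {1, 8}) == std::vector<int>{8, 1}));
  // Otherwise the attribute is the fallback.
  assert((ResolveShape({}, {}, {1, 8}) == std::vector<int>{1, 8}));
  return 0;
}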