Unverified commit febfd7d6, authored by huzhiqiang, committed by GitHub

modify reshape2 OP test=develop (#1963)

Modify the reshape2 OP to add a ShapeTensor input.
Parent: 4a3a45b6
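The patch gives reshape/reshape2 three possible sources for the target shape, resolved in this order: the new ShapeTensor input (a list of 1-element int tensors, one per output dimension), then the Shape input tensor, then the shape attribute. The helper below is only a simplified sketch of that resolution order, written against the lite::Tensor calls used in the patch; its name and standalone signature are illustrative, and the real logic lives in ReshapeOp::AttachImpl in the diff that follows.

```cpp
// Illustrative sketch only: the precedence the patch implements when picking
// the target shape. ResolveShape is a hypothetical helper, not part of the patch.
#include <vector>
#include "lite/core/tensor.h"

std::vector<int> ResolveShape(
    const std::vector<const paddle::lite::Tensor *> &shape_tensors,  // "ShapeTensor" inputs
    const paddle::lite::Tensor *shape_input,                         // "Shape" input
    const std::vector<int> &shape_attr) {                            // "shape" attribute
  std::vector<int> shape;
  if (!shape_tensors.empty()) {
    // Highest priority: one 1-element int tensor per output dimension.
    for (const auto *t : shape_tensors) {
      shape.push_back(t->data<int>()[0]);
    }
  } else if (shape_input != nullptr) {
    // Next: a single tensor holding the whole target shape.
    const int *d = shape_input->data<int>();
    shape.assign(d, d + shape_input->dims().production());
  } else {
    // Fallback: the compile-time attribute, which must then be non-empty.
    shape = shape_attr;
  }
  return shape;
}
```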
@@ -66,6 +66,9 @@ REGISTER_LITE_KERNEL(reshape,
     .BindInput("X",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
@@ -86,6 +89,9 @@ REGISTER_LITE_KERNEL(reshape2,
     .BindInput("Shape",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(
                     TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
@@ -14,6 +14,7 @@
 #include "lite/operators/reshape_op.h"
 #include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
 namespace paddle {
 namespace lite {
@@ -43,24 +44,52 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
   param_.output = output_var->GetMutable<lite::Tensor>();
   std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
-  if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") !=
-      input_arg_names.end()) {
-    if (opdesc.Input("Shape").size() > 0) {
-      auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
-      if (actual_shape_var != nullptr) {
-        param_.actual_shape = const_cast<lite::Tensor *>(
-            &(actual_shape_var->Get<lite::Tensor>()));
-      }
-    }
-  }
-  param_.shape = (opdesc.GetAttr<std::vector<int>>("shape"));
   if (opdesc.HasAttr("inplace")) {
     param_.inplace = opdesc.GetAttr<bool>("inplace");
   }
   CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
   CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
-  CHECK(!param_.shape.empty())
-      << "The shape information must be set by Attr(shape).";
+  if (opdesc.HasInput("ShapeTensor") &&
+      opdesc.Input("ShapeTensor").size() > 0) {
+    auto inputs = opdesc.Input("ShapeTensor");
+    for (auto var : inputs) {
+      lite::Tensor *datatensor =
+          scope->FindVar(var)->GetMutable<lite::Tensor>();
+      param_.shape.push_back(datatensor->mutable_data<int>()[0]);
+    }
+    const std::vector<int> shape_vector = param_.shape;
+    lite::Tensor *shape_tensor = new lite::Tensor;
+    shape_tensor->Resize(DDim({shape_vector.size()}));
+    int *data_shape = shape_tensor->mutable_data<int>();
+    for (int i = 0; i < shape_vector.size(); i++) {
+      data_shape[i] = shape_vector[i];
+    }
+    param_.actual_shape = shape_tensor;
+    return true;
+  } else if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
+    auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
+    if (actual_shape_var != nullptr) {
+      param_.actual_shape =
+          const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
+      int length = param_.actual_shape->dims().production();
+      int *shape_list = actual_shape_var->GetMutable<int>();
+      param_.shape.assign(shape_list, shape_list + length);
+    }
+    return true;
+  } else {
+    param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
+    CHECK(!param_.shape.empty())
+        << "The shape information must be set by Attr(shape).";
+    const std::vector<int> shape_vector = param_.shape;
+    lite::Tensor *shape_tensor = new lite::Tensor;
+    shape_tensor->Resize(DDim({shape_vector.size()}));
+    int *data_shape = shape_tensor->mutable_data<int>();
+    for (int i = 0; i < shape_vector.size(); i++) {
+      data_shape[i] = shape_vector[i];
+    }
+    param_.actual_shape = shape_tensor;
+  }
   return true;
 }
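Note on exercising the new branch: every name listed under ShapeTensor must refer to a scope variable holding a single int, because AttachImpl reads mutable_data<int>()[0] from each one. Below is a minimal sketch of preparing such variables, assuming the lite::Scope and lite::Tensor calls used in the patch; the helper and the variable names it receives are hypothetical.

```cpp
#include <string>
#include <vector>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"

// Hypothetical helper: create one 1-element int tensor per target dimension so
// that ReshapeOp::AttachImpl can collect the shape from "ShapeTensor".
void FillShapeTensorVars(paddle::lite::Scope *scope,
                         const std::vector<std::string> &var_names,
                         const std::vector<int> &target_shape) {
  for (size_t i = 0; i < var_names.size() && i < target_shape.size(); ++i) {
    auto *t = scope->Var(var_names[i])->GetMutable<paddle::lite::Tensor>();
    t->Resize(paddle::lite::DDim({1}));        // each entry holds a single int
    t->mutable_data<int>()[0] = target_shape[i];
  }
}
```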