未验证 提交 febfd7d6 编写于 作者: H huzhiqiang 提交者: GitHub

modify reshape2 OP test=dvelop (#1963)

modify reshape2 OP to add shape_tensor input
上级 4a3a45b6
......@@ -66,6 +66,9 @@ REGISTER_LITE_KERNEL(reshape,
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("Shape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
......@@ -86,6 +89,9 @@ REGISTER_LITE_KERNEL(reshape2,
.BindInput("Shape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
......
......@@ -14,6 +14,7 @@
#include "lite/operators/reshape_op.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
......@@ -43,24 +44,52 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") !=
input_arg_names.end()) {
if (opdesc.Input("Shape").size() > 0) {
auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
if (actual_shape_var != nullptr) {
param_.actual_shape = const_cast<lite::Tensor *>(
&(actual_shape_var->Get<lite::Tensor>()));
}
}
}
param_.shape = (opdesc.GetAttr<std::vector<int>>("shape"));
if (opdesc.HasAttr("inplace")) {
param_.inplace = opdesc.GetAttr<bool>("inplace");
}
CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
CHECK(!param_.shape.empty())
<< "The shape information must be set by Attr(shape).";
if (opdesc.HasInput("ShapeTensor") &&
opdesc.Input("ShapeTensor").size() > 0) {
auto inputs = opdesc.Input("ShapeTensor");
for (auto var : inputs) {
lite::Tensor *datatensor =
scope->FindVar(var)->GetMutable<lite::Tensor>();
param_.shape.push_back(datatensor->mutable_data<int>()[0]);
}
const std::vector<int> shape_vector = param_.shape;
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize(DDim({shape_vector.size()}));
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
data_shape[i] = shape_vector[i];
}
param_.actual_shape = shape_tensor;
return true;
} else if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
if (actual_shape_var != nullptr) {
param_.actual_shape =
const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
int length = param_.actual_shape->dims().production();
int *shape_list = actual_shape_var->GetMutable<int>();
param_.shape.assign(shape_list, shape_list + length);
}
return true;
} else {
param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
CHECK(!param_.shape.empty())
<< "The shape information must be set by Attr(shape).";
const std::vector<int> shape_vector = param_.shape;
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize(DDim({shape_vector.size()}));
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
data_shape[i] = shape_vector[i];
}
param_.actual_shape = shape_tensor;
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册