Unverified · Commit 421c6305 · Authored by: J juncaipeng · Committed by: GitHub

fix bug for reshape op, test=develop (#2141)

* fix bug for reshape op, test=develop
Parent commit: 80d35725
......@@ -24,27 +24,9 @@ namespace host {
// ReshapeCompute::Run (host kernel): materializes the output either by
// sharing the input buffer (inplace) or by copying it.
// NOTE(review): this span is a scraped diff hunk with no +/- markers, so
// pre-commit (removed) and post-commit (added) lines are interleaved below
// and the span is NOT compilable as-is. The commit drops the in-kernel
// actual_shape handling (shape resolution moved to ReshapeOp::InferShape).
void ReshapeCompute::Run() {
auto& param = Param<operators::ReshapeParam>();
auto x = param.x;
// [pre-commit, removed] actual_shape re-validation path begins here.
auto actual_shape = param.actual_shape;
auto output = param.output;
bool inplace = param.inplace;
auto x_dims = x->dims();
auto output_dims = output->dims();
if (actual_shape) {
auto actual_shape_dims = actual_shape->dims();
auto* actual_shape_data = actual_shape->data<int>();
#ifdef LITE_WITH_CUDA
// Shape data resident on the GPU must be copied to host memory before the
// CPU can read it.
lite::Tensor cpu_actual_shape;
if (actual_shape->target() == TARGET(kCUDA)) {
cpu_actual_shape.CopyDataFrom(*actual_shape);
actual_shape_data = cpu_actual_shape.data<int>();
}
#endif
auto shape = std::vector<int>(
actual_shape_data, actual_shape_data + actual_shape_dims.production());
output_dims = lite::operators::ValidateShape(shape, x_dims);
output->Resize(output_dims);
}
// [pre-commit] read the cached local; [post-commit] read param.inplace:
if (inplace) {
if (param.inplace) {
// Share the input's buffer — no data copy.
output->ShareDataWith(*x);
} else {
output->CopyDataFrom(*x);
......
......@@ -32,40 +32,57 @@ TEST(reshape_host, compute) {
// Body of TEST(reshape_host, compute): exercises the host reshape kernel with
// (a) a shape_tensor input, (b) a shape_vct attribute, (c) inplace mode.
// NOTE(review): scraped diff hunk — pre-commit lines (x / actual_shape /
// param.shape) and post-commit lines (input / shape_tensor / param.shape_vct)
// are interleaved without +/- markers; not compilable as-is.
ReshapeCompute reshape;
operators::ReshapeParam param;
Tensor x;
Tensor actual_shape;
Tensor input;
Tensor output;
x.Resize(DDim(std::vector<int64_t>({1, 2, 4, 6})));
actual_shape.Resize(DDim(std::vector<int64_t>({2})));
auto* x_data = x.mutable_data<float>();
auto* actual_shape_data = actual_shape.mutable_data<int>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
// Post-commit input tensor: 1x2x4x6 = 48 elements, filled with 0..47.
input.Resize({1, 2, 4, 6});
auto* input_data = input.mutable_data<float>();
for (int i = 0; i < input.numel(); i++) {
input_data[i] = i;
}
actual_shape_data[0] = 6;
actual_shape_data[1] = 8;
// Target shape {6, 8} supplied via a 1-D int tensor of 2 elements.
Tensor shape_tensor;
shape_tensor.Resize({2});
auto* shape_tensor_data = shape_tensor.mutable_data<int>();
shape_tensor_data[0] = 6;
shape_tensor_data[1] = 8;
param.x = &x;
param.shape = {-1, 0, 3, 2, 1};
param.output = &output;
param.actual_shape = &actual_shape;
// set param and run
param.x = &input;
param.shape_tensor = &shape_tensor; // use shape_tensor
param.inplace = false;
param.output = &output;
reshape.SetParam(param);
reshape.Run();
// check output dims
CHECK_EQ(actual_shape.dims().production(), output.dims().size());
// NOTE(review): shape_tensor.numel() is 2 but output.numel() after a reshape
// to {6, 8} is 48 — this looks like it should compare against
// output.dims().size() instead; confirm against the repository version.
CHECK_EQ(shape_tensor.numel(), output.numel());
for (int i = 0; i < output.dims().size(); i++) {
CHECK_EQ(output.dims()[i], actual_shape_data[i]);
CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
}
// check output data
auto* output_data = output.mutable_data<float>();
CHECK_NE(output_data, x_data);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], x_data[i], 1e-6);
// Non-inplace run: output must own a distinct buffer with equal contents.
CHECK_NE(output_data, input_data);
for (int i = 0; i < output.numel(); i++) {
EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
}
// use shape, set param and run
param.shape_tensor = nullptr;
param.shape_vct = {-1, 0, 3, 2, 1};
reshape.SetParam(param);
reshape.Run();
// check output dims
// NOTE(review): this second run uses shape_vct {-1, 0, 3, 2, 1}, which for a
// 1x2x4x6 input resolves to a 5-D shape — yet the checks below still compare
// against shape_tensor's values {6, 8}. Looks copy-pasted from the first
// branch; confirm the intended expected dims.
CHECK_EQ(shape_tensor.numel(), output.numel());
for (int i = 0; i < output.dims().size(); i++) {
CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
}
// check output data
output_data = output.mutable_data<float>();
CHECK_NE(output_data, input_data);
for (int i = 0; i < output.numel(); i++) {
EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
}
// check output data if inplace = true;
// (hunk break here — the lines setting param.inplace = true are not visible)
......@@ -73,7 +90,7 @@ TEST(reshape_host, compute) {
reshape.SetParam(param);
reshape.Run();
// Inplace run: output must alias the input buffer.
output_data = output.mutable_data<float>();
CHECK_EQ(output_data, x_data);
CHECK_EQ(output_data, input_data);
}
TEST(reshape, retrive_op) {
......
......@@ -27,19 +27,9 @@ namespace kernels {
namespace x86 {
// x86 reshape kernels.
// NOTE(review): scraped diff hunk — the pre-commit 3-argument Compute
// signature (with actual_shape) and the post-commit 2-argument signature are
// both present below; not compilable as-is. The commit removes all shape
// resolution from the kernel: ReshapeOp::InferShape now sets out's dims
// before the kernel runs.
template <typename T>
void Compute(const lite::Tensor* in,
const lite::Tensor* actual_shape,
lite::Tensor* out) {
void Compute(const lite::Tensor* in, lite::Tensor* out) {
// Capture the dims InferShape already assigned to `out`.
auto out_dims = out->dims();
auto in_dims = in->dims();
// [pre-commit, removed] in-kernel re-validation of actual_shape:
if (actual_shape) {
auto shape_dims = actual_shape->dims();
const int* shape_data = actual_shape->data<int>();
std::vector<int> shape =
std::vector<int>(shape_data, shape_data + shape_dims.production());
out_dims = lite::operators::ValidateShape(shape, in_dims);
out->Resize(out_dims);
}
out->CopyDataFrom(*in);
// [post-commit, added] restore out's dims after the copy — presumably
// CopyDataFrom also copies `in`'s dims onto `out`; confirm Tensor semantics.
out->Resize(out_dims);
}
......@@ -51,7 +41,7 @@ class ReshapeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// ReshapeCompute::Run — forwards to Compute<T>; old/new call sites below.
void Run() override {
auto& param = *param_.get_mutable<param_t>();
Compute<T>(param.x, param.actual_shape, param.output);
Compute<T>(param.x, param.output);
}
virtual ~ReshapeCompute() = default;
......@@ -67,7 +57,7 @@ class Reshape2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// Reshape2Compute::Run — identical dispatch for the reshape2 variant.
void Run() override {
auto& param = *param_.get_mutable<param_t>();
Compute<T>(param.x, param.actual_shape, param.output);
Compute<T>(param.x, param.output);
}
virtual ~Reshape2Compute() = default;
......
......@@ -44,7 +44,7 @@ TEST(reshape_x86, run_test) {
// Interiors of TEST(reshape_x86, run_test) and TEST(reshape2_x86, run_test).
// NOTE(review): scraped diff hunks — test openings/closings live inside the
// `@@` headers and several line runs are cut out; old/new lines interleaved.
lite::Tensor out;
std::vector<int64_t> x_shape({1, 2, 4, 1});
x.Resize(lite::DDim(x_shape));
// [pre-commit {3}] vs [post-commit {4}] size of the shape tensor:
actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
actual_shape.Resize(lite::DDim(std::vector<int64_t>({4})));
std::vector<int64_t> out_shape({1, 8, 1, 1});
out.Resize(lite::DDim(out_shape));
......@@ -56,8 +56,9 @@ TEST(reshape_x86, run_test) {
x_data[i] = static_cast<float>(i);
}
actual_data[0] = 1;
actual_data[1] = 4;
actual_data[2] = 2;
actual_data[1] = 8;
actual_data[2] = 1;
// NOTE(review): index 1 is written twice and index 3 is never set, even
// though actual_shape was resized to {4} and the target shape is {1,8,1,1}.
// This last write should presumably be actual_data[3] = 1; confirm.
actual_data[1] = 1;
std::vector<int> shape({1, 8, 1, 1});
......@@ -67,12 +68,12 @@ TEST(reshape_x86, run_test) {
param.x = &x;
param.output = &out;
// [pre-commit] param.shape/actual_shape; [post-commit] shape_vct/shape_tensor:
param.shape = shape;
param.actual_shape = &actual_shape;
param.shape_vct = shape;
param.shape_tensor = &actual_shape;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
// Iteration 0 uses the shape tensor; iteration 1 falls back to the attr.
for (int i = 0; i < 2; ++i) {
if (1 == i) param.actual_shape = nullptr;
if (1 == i) param.shape_tensor = nullptr;
// NOTE(review): ctx is std::move'd on every iteration, so the second pass
// installs a moved-from (null) context — verify SetContext tolerates this.
reshape.SetContext(std::move(ctx));
reshape.SetParam(param);
reshape.Run();
......@@ -106,7 +107,7 @@ TEST(reshape2_x86, run_test) {
actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
std::vector<int64_t> out_shape({1, 4, 2});
out.Resize(lite::DDim(out_shape));
// [pre-commit {1,2,4}] vs [post-commit {1,4,2}] expected xshape dims:
std::vector<int64_t> xshape_shape({1, 2, 4});
std::vector<int64_t> xshape_shape({1, 4, 2});
xshape.Resize(lite::DDim(xshape_shape));
auto x_data = x.mutable_data<float>();
......@@ -122,7 +123,7 @@ TEST(reshape2_x86, run_test) {
actual_data[1] = 4;
actual_data[2] = 2;
// [pre-commit {0,-1,2}] vs [post-commit {1,4,2}] shape attribute:
std::vector<int> shape({0, -1, 2});
std::vector<int> shape({1, 4, 2});
// Reshape2Compute reshape2;
Reshape2Compute<float> reshape2;
......@@ -131,12 +132,12 @@ TEST(reshape2_x86, run_test) {
param.x = &x;
param.output = &out;
param.xshape = &xshape;
param.shape = shape;
param.actual_shape = &actual_shape;
param.shape_vct = shape;
param.shape_tensor = &actual_shape;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
for (int i = 0; i < 2; ++i) {
if (1 == i) param.actual_shape = nullptr;
if (1 == i) param.shape_tensor = nullptr;
reshape2.SetContext(std::move(ctx));
reshape2.SetParam(param);
reshape2.Run();
......
......@@ -190,11 +190,12 @@ struct SoftmaxParam {
// For Reshape and Reshape2 Op
// Parameters for the Reshape/Reshape2 operators. The target shape may come
// from three sources, resolved by ReshapeOp::InferShape in this priority
// order:
//   1. shape_tensor_vct — "ShapeTensor" inputs, one 1-element tensor per dim;
//   2. shape_tensor     — a single "Shape" input tensor holding all dims;
//   3. shape_vct        — the "shape" attribute.
// NOTE(review): the scraped diff interleaved pre- and post-commit members
// (duplicate `xshape`, conflicting shape fields); this is the reconstructed
// post-commit struct, with the pre-commit fields retained for source
// compatibility with not-yet-migrated callers.
struct ReshapeParam {
  const lite::Tensor* x{};
  std::vector<const lite::Tensor*> shape_tensor_vct{};
  const lite::Tensor* shape_tensor{};
  std::vector<int> shape_vct{};
  lite::Tensor* output{};
  lite::Tensor* xshape{};  // Reshape2 only: receives [0, x_dims...]
  bool inplace{false};     // share the input buffer instead of copying
  // Legacy (pre-refactor) fields — superseded by shape_tensor / shape_vct;
  // kept so existing callers still compile. Prefer the fields above.
  const lite::Tensor* actual_shape{nullptr};
  std::vector<int> shape{};
};
......
......@@ -23,13 +23,32 @@ namespace operators {
// Sanity-checks that the required input/output tensors are bound.
// The target shape itself is deliberately NOT validated here: it may be
// supplied at runtime via the "ShapeTensor" or "Shape" inputs rather than the
// "shape" attribute, and that choice is resolved in InferShape(). (The
// pre-commit `CHECK_OR_FALSE(!param_.shape.empty())` was removed for exactly
// this reason — a leftover copy of it appeared in the scraped diff.)
bool ReshapeOp::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.output);
  return true;
}
// Resolves the target shape (priority: ShapeTensor inputs > Shape input >
// shape attribute), validates it against x's dims, resizes the output, and
// forwards the input LoD.
// NOTE(review): hunk break before the closing brace — the trailing lines
// (presumably `return true;`) are not visible in this scrape.
bool ReshapeOp::InferShape() const {
// NOTE(review): these copy the containers by value each call; const auto&
// would avoid the copies on this per-inference path.
auto shape_tensor_vct = param_.shape_tensor_vct;
auto *shape_tensor = param_.shape_tensor;
auto shape_vct = param_.shape_vct;
std::vector<int> final_shape;
if (shape_tensor_vct.size() > 0) {
// Each ShapeTensor element is a 1-element int tensor holding one dim.
// (signed `int i` vs size_t size() — compiler warning; consider size_t.)
for (int i = 0; i < shape_tensor_vct.size(); i++) {
final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
}
} else if (shape_tensor != nullptr) {
// A single Shape tensor holds all dims contiguously.
auto *shape_tensor_data = shape_tensor->data<int>();
final_shape = std::vector<int>(shape_tensor_data,
shape_tensor_data + shape_tensor->numel());
} else if (!shape_vct.empty()) {
final_shape = shape_vct;
} else {
LOG(FATAL) << "input shape error";
}
auto x_dims = param_.x->dims();
// [pre-commit] validated param_.shape; [post-commit] validates final_shape:
auto output_dims = ValidateShape(param_.shape, x_dims);
auto output_dims = ValidateShape(final_shape, x_dims);
// NOTE(review): LOG(INFO) here fires on every inference — looks like leftover
// debug output; consider VLOG or removing before merge.
LOG(INFO) << "output_dims:" << output_dims;
param_.output->Resize(output_dims);
// Propagate the input's LoD so downstream ops see consistent sequence info.
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x->lod();
......@@ -37,60 +56,32 @@ bool ReshapeOp::InferShape() const {
}
// Binds the op's scope variables and attributes into param_.
// NOTE(review): scraped diff hunk — the pre-commit and post-commit versions
// of this function are interleaved below without +/- markers; the span is
// NOT compilable as-is. Post-commit behavior: bind X/Out, then record
// ShapeTensor vars, the Shape var, the "shape" attr, and "inplace" — leaving
// the priority decision to InferShape instead of baking a shape here.
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
// [pre-commit] variable lookup via const_cast:
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (opdesc.HasAttr("inplace")) {
param_.inplace = opdesc.GetAttr<bool>("inplace");
}
// [post-commit] straightforward GetMutable lookups plus null checks:
CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
if (opdesc.HasInput("ShapeTensor") &&
opdesc.Input("ShapeTensor").size() > 0) {
// [pre-commit, removed] eagerly read each ShapeTensor value:
auto inputs = opdesc.Input("ShapeTensor");
for (auto var : inputs) {
lite::Tensor *datatensor =
scope->FindVar(var)->GetMutable<lite::Tensor>();
// (mutable_data<int>() used for a read — would allocate if unset.)
param_.shape.push_back(datatensor->mutable_data<int>()[0]);
}
const std::vector<int> shape_vector = param_.shape;
// NOTE(review): pre-commit code leaked this allocation — a raw owning
// `new lite::Tensor` stored in param_.actual_shape and never deleted.
// Removing it is part of what this commit fixes.
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
data_shape[i] = shape_vector[i];
}
param_.actual_shape = shape_tensor;
return true;
} else if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
if (actual_shape_var != nullptr) {
param_.actual_shape =
const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
int length = param_.actual_shape->dims().production();
int *shape_list = actual_shape_var->GetMutable<int>();
param_.shape.assign(shape_list, shape_list + length);
// [post-commit] just record the ShapeTensor variables; values are read
// later in InferShape:
auto args = opdesc.Input("ShapeTensor");
for (auto arg : args) {
auto *var = scope->FindVar(arg);
if (var != nullptr) {
param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
}
}
return true;
} else {
// [pre-commit, removed] attr-only fallback; same leaked `new` pattern:
param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
CHECK(!param_.shape.empty())
<< "The shape information must be set by Attr(shape).";
const std::vector<int> shape_vector = param_.shape;
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
data_shape[i] = shape_vector[i];
}
// [post-commit] record the optional Shape variable:
if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
auto var = scope->FindVar(opdesc.Input("Shape").front());
if (var != nullptr) {
param_.shape_tensor = var->GetMutable<lite::Tensor>();
}
param_.actual_shape = shape_tensor;
}
// [post-commit] record the "shape" and "inplace" attributes if present:
if (opdesc.HasAttr("shape")) {
param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
}
if (opdesc.HasAttr("inplace")) {
param_.inplace = opdesc.GetAttr<bool>("inplace");
}
return true;
}
......@@ -104,20 +95,20 @@ bool Reshape2Op::CheckShape() const {
// Infers Reshape2's extra "XShape" output alongside the base reshape:
// xshape = [0, x_dims...] (the leading 0 is a placeholder slot; real dims
// start at index 1 — see the i + 1 indexing below), and the input LoD is
// forwarded to xshape.
// NOTE(review): the scraped diff contained both the pre-commit fill value (1)
// and the post-commit value (0) as duplicate declarations; this is the
// reconstructed post-commit version.
bool Reshape2Op::InferShape() const {
  // First resolve the ordinary output shape.
  ReshapeOp::InferShape();
  auto x_dims = param_.x->dims();
  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
  for (size_t i = 0; i < x_dims.size(); i++) {
    xshape_dims[i + 1] = x_dims[i];
  }
  param_.xshape->Resize(xshape_dims);
  auto xshape_lod = param_.xshape->mutable_lod();
  *xshape_lod = param_.x->lod();
  return true;
}
// Binds Reshape2's extra "XShape" output on top of the base Reshape bindings.
// NOTE(review): reconstructed post-commit version of an interleaved diff
// hunk — the null check now targets the bound tensor (param_.xshape) rather
// than the scope variable.
bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  // Bind X/Out, shape sources, and attributes first.
  ReshapeOp::AttachImpl(opdesc, scope);
  auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
  param_.xshape = xshape_var->GetMutable<lite::Tensor>();
  CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null.";
  return true;
}
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.