Unverified commit 421c6305, authored by J juncaipeng, committed by GitHub

fix bug for reshape op, test=develop (#2141)

* fix bug for reshape op, test=develop
Parent 80d35725
@@ -24,27 +24,9 @@ namespace host {
 void ReshapeCompute::Run() {
   auto& param = Param<operators::ReshapeParam>();
   auto x = param.x;
-  auto actual_shape = param.actual_shape;
   auto output = param.output;
-  bool inplace = param.inplace;
-  auto x_dims = x->dims();
   auto output_dims = output->dims();
-  if (actual_shape) {
-    auto actual_shape_dims = actual_shape->dims();
-    auto* actual_shape_data = actual_shape->data<int>();
-#ifdef LITE_WITH_CUDA
-    lite::Tensor cpu_actual_shape;
-    if (actual_shape->target() == TARGET(kCUDA)) {
-      cpu_actual_shape.CopyDataFrom(*actual_shape);
-      actual_shape_data = cpu_actual_shape.data<int>();
-    }
-#endif
-    auto shape = std::vector<int>(
-        actual_shape_data, actual_shape_data + actual_shape_dims.production());
-    output_dims = lite::operators::ValidateShape(shape, x_dims);
-    output->Resize(output_dims);
-  }
-  if (inplace) {
+  if (param.inplace) {
     output->ShareDataWith(*x);
   } else {
     output->CopyDataFrom(*x);
...
@@ -32,40 +32,57 @@ TEST(reshape_host, compute) {
   ReshapeCompute reshape;
   operators::ReshapeParam param;
-  Tensor x;
-  Tensor actual_shape;
+  Tensor input;
   Tensor output;
-  x.Resize(DDim(std::vector<int64_t>({1, 2, 4, 6})));
-  actual_shape.Resize(DDim(std::vector<int64_t>({2})));
-
-  auto* x_data = x.mutable_data<float>();
-  auto* actual_shape_data = actual_shape.mutable_data<int>();
-  for (int i = 0; i < x.dims().production(); i++) {
-    x_data[i] = i;
+  input.Resize({1, 2, 4, 6});
+  auto* input_data = input.mutable_data<float>();
+  for (int i = 0; i < input.numel(); i++) {
+    input_data[i] = i;
   }
-  actual_shape_data[0] = 6;
-  actual_shape_data[1] = 8;
+  Tensor shape_tensor;
+  shape_tensor.Resize({2});
+  auto* shape_tensor_data = shape_tensor.mutable_data<int>();
+  shape_tensor_data[0] = 6;
+  shape_tensor_data[1] = 8;

-  param.x = &x;
-  param.shape = {-1, 0, 3, 2, 1};
-  param.output = &output;
-  param.actual_shape = &actual_shape;
+  // set param and run
+  param.x = &input;
+  param.shape_tensor = &shape_tensor;  // use shape_tensor
   param.inplace = false;
+  param.output = &output;
   reshape.SetParam(param);
   reshape.Run();

   // check output dims
-  CHECK_EQ(actual_shape.dims().production(), output.dims().size());
+  CHECK_EQ(shape_tensor.numel(), output.numel());
   for (int i = 0; i < output.dims().size(); i++) {
-    CHECK_EQ(output.dims()[i], actual_shape_data[i]);
+    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
   }

   // check output data
   auto* output_data = output.mutable_data<float>();
-  CHECK_NE(output_data, x_data);
-  for (int i = 0; i < output.dims().production(); i++) {
-    EXPECT_NEAR(output_data[i], x_data[i], 1e-6);
+  CHECK_NE(output_data, input_data);
+  for (int i = 0; i < output.numel(); i++) {
+    EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
+  }
+
+  // use shape, set param and run
+  param.shape_tensor = nullptr;
+  param.shape_vct = {-1, 0, 3, 2, 1};
+  reshape.SetParam(param);
+  reshape.Run();
+
+  // check output dims
+  CHECK_EQ(shape_tensor.numel(), output.numel());
+  for (int i = 0; i < output.dims().size(); i++) {
+    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
+  }
+
+  // check output data
+  output_data = output.mutable_data<float>();
+  CHECK_NE(output_data, input_data);
+  for (int i = 0; i < output.numel(); i++) {
+    EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
   }

   // check output data if inplace = true;
@@ -73,7 +90,7 @@ TEST(reshape_host, compute) {
   reshape.SetParam(param);
   reshape.Run();
   output_data = output.mutable_data<float>();
-  CHECK_EQ(output_data, x_data);
+  CHECK_EQ(output_data, input_data);
 }

 TEST(reshape, retrive_op) {
...
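Note: the second half of the updated test exercises the attribute path (param.shape_vct = {-1, 0, 3, 2, 1}). A worked example of the expected result, assuming ValidateShape follows the usual Paddle reshape convention where 0 copies the corresponding input dim and a single -1 is inferred from the element count:

// input dims {1, 2, 4, 6} -> 48 elements
// shape {-1, 0, 3, 2, 1}: the 0 copies input dim 1 (= 2); known dims give 2 * 3 * 2 * 1 = 12
// the -1 is then 48 / 12 = 4, so the resolved output dims would be {4, 2, 3, 2, 1}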
@@ -27,19 +27,9 @@ namespace kernels {
 namespace x86 {

 template <typename T>
-void Compute(const lite::Tensor* in,
-             const lite::Tensor* actual_shape,
-             lite::Tensor* out) {
+void Compute(const lite::Tensor* in, lite::Tensor* out) {
   auto out_dims = out->dims();
   auto in_dims = in->dims();
-  if (actual_shape) {
-    auto shape_dims = actual_shape->dims();
-    const int* shape_data = actual_shape->data<int>();
-    std::vector<int> shape =
-        std::vector<int>(shape_data, shape_data + shape_dims.production());
-    out_dims = lite::operators::ValidateShape(shape, in_dims);
-    out->Resize(out_dims);
-  }
   out->CopyDataFrom(*in);
   out->Resize(out_dims);
 }
@@ -51,7 +41,7 @@ class ReshapeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
-    Compute<T>(param.x, param.actual_shape, param.output);
+    Compute<T>(param.x, param.output);
   }

   virtual ~ReshapeCompute() = default;
@@ -67,7 +57,7 @@ class Reshape2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
-    Compute<T>(param.x, param.actual_shape, param.output);
+    Compute<T>(param.x, param.output);
   }

   virtual ~Reshape2Compute() = default;
...
@@ -44,7 +44,7 @@ TEST(reshape_x86, run_test) {
   lite::Tensor out;
   std::vector<int64_t> x_shape({1, 2, 4, 1});
   x.Resize(lite::DDim(x_shape));
-  actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
+  actual_shape.Resize(lite::DDim(std::vector<int64_t>({4})));
   std::vector<int64_t> out_shape({1, 8, 1, 1});
   out.Resize(lite::DDim(out_shape));
@@ -56,8 +56,9 @@ TEST(reshape_x86, run_test) {
     x_data[i] = static_cast<float>(i);
   }
   actual_data[0] = 1;
-  actual_data[1] = 4;
-  actual_data[2] = 2;
+  actual_data[1] = 8;
+  actual_data[2] = 1;
+  actual_data[3] = 1;

   std::vector<int> shape({1, 8, 1, 1});
@@ -67,12 +68,12 @@ TEST(reshape_x86, run_test) {
   param.x = &x;
   param.output = &out;
-  param.shape = shape;
-  param.actual_shape = &actual_shape;
+  param.shape_vct = shape;
+  param.shape_tensor = &actual_shape;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   ctx->As<X86Context>();
   for (int i = 0; i < 2; ++i) {
-    if (1 == i) param.actual_shape = nullptr;
+    if (1 == i) param.shape_tensor = nullptr;
     reshape.SetContext(std::move(ctx));
     reshape.SetParam(param);
     reshape.Run();
@@ -106,7 +107,7 @@ TEST(reshape2_x86, run_test) {
   actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
   std::vector<int64_t> out_shape({1, 4, 2});
   out.Resize(lite::DDim(out_shape));
-  std::vector<int64_t> xshape_shape({1, 2, 4});
+  std::vector<int64_t> xshape_shape({1, 4, 2});
   xshape.Resize(lite::DDim(xshape_shape));

   auto x_data = x.mutable_data<float>();
@@ -122,7 +123,7 @@ TEST(reshape2_x86, run_test) {
   actual_data[1] = 4;
   actual_data[2] = 2;
-  std::vector<int> shape({0, -1, 2});
+  std::vector<int> shape({1, 4, 2});

   // Reshape2Compute reshape2;
   Reshape2Compute<float> reshape2;
@@ -131,12 +132,12 @@ TEST(reshape2_x86, run_test) {
   param.x = &x;
   param.output = &out;
   param.xshape = &xshape;
-  param.shape = shape;
-  param.actual_shape = &actual_shape;
+  param.shape_vct = shape;
+  param.shape_tensor = &actual_shape;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   ctx->As<X86Context>();
   for (int i = 0; i < 2; ++i) {
-    if (1 == i) param.actual_shape = nullptr;
+    if (1 == i) param.shape_tensor = nullptr;
     reshape2.SetContext(std::move(ctx));
     reshape2.SetParam(param);
     reshape2.Run();
...
@@ -190,11 +190,12 @@ struct SoftmaxParam {
 // For Reshape and Reshape2 Op
 struct ReshapeParam {
   const lite::Tensor* x{};
-  const lite::Tensor* actual_shape{nullptr};
+  std::vector<const lite::Tensor*> shape_tensor_vct{};
+  const lite::Tensor* shape_tensor{};
+  std::vector<int> shape_vct{};
   lite::Tensor* output{};
   lite::Tensor* xshape{};
-  std::vector<int> shape{};
   bool inplace{false};
 };
...
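Note: after this change ReshapeParam carries three alternative shape sources instead of a single actual_shape tensor. A minimal sketch of how a caller would fill them, with the precedence that the new ReshapeOp::InferShape (below) applies; the tensor names input/output here are illustrative only:

// 1. shape_tensor_vct  <- "ShapeTensor" input: one 1-element int tensor per output dim
// 2. shape_tensor      <- "Shape" input: a single int tensor holding the full target shape
// 3. shape_vct         <- "shape" attribute from the op desc
operators::ReshapeParam param;
param.x = &input;                    // input tensor (illustrative name)
param.output = &output;
param.shape_vct = {-1, 0, 3, 2, 1};  // used only when the two tensor inputs are absent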
@@ -23,13 +23,32 @@ namespace operators {
 bool ReshapeOp::CheckShape() const {
   CHECK_OR_FALSE(param_.x);
   CHECK_OR_FALSE(param_.output);
-  CHECK_OR_FALSE(!param_.shape.empty());
   return true;
 }

 bool ReshapeOp::InferShape() const {
+  auto shape_tensor_vct = param_.shape_tensor_vct;
+  auto *shape_tensor = param_.shape_tensor;
+  auto shape_vct = param_.shape_vct;
+  std::vector<int> final_shape;
+
+  if (shape_tensor_vct.size() > 0) {
+    for (int i = 0; i < shape_tensor_vct.size(); i++) {
+      final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
+    }
+  } else if (shape_tensor != nullptr) {
+    auto *shape_tensor_data = shape_tensor->data<int>();
+    final_shape = std::vector<int>(shape_tensor_data,
+                                   shape_tensor_data + shape_tensor->numel());
+  } else if (!shape_vct.empty()) {
+    final_shape = shape_vct;
+  } else {
+    LOG(FATAL) << "input shape error";
+  }
+
   auto x_dims = param_.x->dims();
-  auto output_dims = ValidateShape(param_.shape, x_dims);
+  auto output_dims = ValidateShape(final_shape, x_dims);
+  LOG(INFO) << "output_dims:" << output_dims;
   param_.output->Resize(output_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x->lod();
@@ -37,60 +56,32 @@ bool ReshapeOp::InferShape() const {
 }
 bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
-  auto x_var = scope->FindVar(opdesc.Input("X").front());
-  auto output_var = scope->FindVar(opdesc.Output("Out").front());
-  CHECK(x_var);
-  CHECK(output_var);
-  param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
-  param_.output = output_var->GetMutable<lite::Tensor>();
-  std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
-  if (opdesc.HasAttr("inplace")) {
-    param_.inplace = opdesc.GetAttr<bool>("inplace");
-  }
-  CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
-  CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
+  param_.x =
+      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.output =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+
   if (opdesc.HasInput("ShapeTensor") &&
       opdesc.Input("ShapeTensor").size() > 0) {
-    auto inputs = opdesc.Input("ShapeTensor");
-    for (auto var : inputs) {
-      lite::Tensor *datatensor =
-          scope->FindVar(var)->GetMutable<lite::Tensor>();
-      param_.shape.push_back(datatensor->mutable_data<int>()[0]);
-    }
-    const std::vector<int> shape_vector = param_.shape;
-    lite::Tensor *shape_tensor = new lite::Tensor;
-    shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
-    int *data_shape = shape_tensor->mutable_data<int>();
-    for (int i = 0; i < shape_vector.size(); i++) {
-      data_shape[i] = shape_vector[i];
-    }
-    param_.actual_shape = shape_tensor;
-    return true;
-  } else if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
-    auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
-    if (actual_shape_var != nullptr) {
-      param_.actual_shape =
-          const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
-      int length = param_.actual_shape->dims().production();
-      int *shape_list = actual_shape_var->GetMutable<int>();
-      param_.shape.assign(shape_list, shape_list + length);
+    auto args = opdesc.Input("ShapeTensor");
+    for (auto arg : args) {
+      auto *var = scope->FindVar(arg);
+      if (var != nullptr) {
+        param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
+      }
     }
-    return true;
-  } else {
-    param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
-    CHECK(!param_.shape.empty())
-        << "The shape information must be set by Attr(shape).";
-    const std::vector<int> shape_vector = param_.shape;
-    lite::Tensor *shape_tensor = new lite::Tensor;
-    shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
-    int *data_shape = shape_tensor->mutable_data<int>();
-    for (int i = 0; i < shape_vector.size(); i++) {
-      data_shape[i] = shape_vector[i];
-    }
-    param_.actual_shape = shape_tensor;
+  }
+  if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) {
+    auto var = scope->FindVar(opdesc.Input("Shape").front());
+    if (var != nullptr) {
+      param_.shape_tensor = var->GetMutable<lite::Tensor>();
+    }
+  }
+  if (opdesc.HasAttr("shape")) {
+    param_.shape_vct = opdesc.GetAttr<std::vector<int>>("shape");
+  }
+  if (opdesc.HasAttr("inplace")) {
+    param_.inplace = opdesc.GetAttr<bool>("inplace");
   }
   return true;
 }
@@ -104,20 +95,20 @@ bool Reshape2Op::CheckShape() const {
 bool Reshape2Op::InferShape() const {
   ReshapeOp::InferShape();
   auto x_dims = param_.x->dims();
-  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
   param_.xshape->Resize(xshape_dims);
+  auto xshape_lod = param_.xshape->mutable_lod();
+  *xshape_lod = param_.x->lod();
   return true;
 }

 bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   ReshapeOp::AttachImpl(opdesc, scope);
   auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
-  CHECK(xshape_var);
   param_.xshape = xshape_var->GetMutable<lite::Tensor>();
+  CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null.";
   return true;
 }
...
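Note: none of the unit tests in this diff feeds the new shape_tensor_vct path. A minimal sketch of what InferShape consumes there, using only calls that appear in this diff; d0 and d1 are hypothetical tensors:

// Each "ShapeTensor" entry is a 1-element int tensor; InferShape reads element 0 of each.
lite::Tensor d0, d1;
d0.Resize({1});
d1.Resize({1});
d0.mutable_data<int>()[0] = 6;
d1.mutable_data<int>()[0] = 8;
operators::ReshapeParam param;
param.shape_tensor_vct = {&d0, &d1};  // InferShape would resolve final_shape to {6, 8}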