Unverified commit 1a675130, authored by zhupengyang, committed by GitHub

[NPU] enhance shape check in subgraph compute (#2924)

Parent c890d4f5
@@ -88,6 +88,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   } else {
     auto shape = op_info->GetAttr<std::vector<int>>("shape");
     auto out_shape = lite::operators::ValidateShape(shape, x_dims);
+    out_shape = CvtShape(out_shape);
     if (out_shape.size() > 4) {
       LOG(WARNING) << "[NPU] HiAI DDK only supports less than 4 dimensions, "
                       "but shape has "
...
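
The added CvtShape(out_shape) call normalizes the validated shape before the rank check above. A minimal sketch of the conversion it is assumed to perform (illustrative only, not the bridge's source; the left-pad-to-4-D rule is an assumption based on the NCHW layout used throughout the NPU bridges):

#include <cstdint>
#include <vector>

// Assumed behavior: left-pad a shape with 1s up to 4-D NCHW. Shapes that
// are already longer than 4-D pass through unchanged, which is why the
// out_shape.size() > 4 warning can still fire after the conversion.
std::vector<int64_t> CvtShapeSketch(const std::vector<int64_t>& in_shape) {
  std::vector<int64_t> out_shape;
  for (size_t i = in_shape.size(); i < 4; ++i) out_shape.push_back(1);
  out_shape.insert(out_shape.end(), in_shape.begin(), in_shape.end());
  return out_shape;  // e.g. {2, 60} -> {1, 1, 2, 60}
}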
@@ -151,6 +151,15 @@ int CvtActMode(std::string act_type) {
   return act_mode;
 }
 
+bool CheckShape(DDim origin_dims, hiai::TensorDimension device_dims) {
+  auto origin_shape = CvtShape(origin_dims);
+  CHECK_EQ(origin_shape.size(), 4);
+  return origin_shape[0] == device_dims.GetNumber() &&
+         origin_shape[1] == device_dims.GetChannel() &&
+         origin_shape[2] == device_dims.GetHeight() &&
+         origin_shape[3] == device_dims.GetWidth();
+}
+
 }  // namespace npu
 }  // namespace subgraph
 }  // namespace lite
...
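
The new CheckShape helper compares shapes dimension-by-dimension in NCHW order, which is stricter than the element-count CHECK_EQ it replaces in the subgraph engine below: a pure production() check would accept an origin shape of {1, 8, 2, 2} against a device shape of {1, 2, 4, 2}. A self-contained sketch of the comparison, using stand-in types for DDim and hiai::TensorDimension:

#include <array>
#include <cstdint>

struct TensorDim {  // stand-in for hiai::TensorDimension
  uint32_t n, c, h, w;
  uint32_t GetNumber() const { return n; }
  uint32_t GetChannel() const { return c; }
  uint32_t GetHeight() const { return h; }
  uint32_t GetWidth() const { return w; }
};

// The origin shape is assumed to be padded to 4-D NCHW first,
// as the CvtShape call inside CheckShape does.
using Dim4 = std::array<int64_t, 4>;

bool CheckShapeSketch(const Dim4& origin, const TensorDim& device) {
  return origin[0] == static_cast<int64_t>(device.GetNumber()) &&
         origin[1] == static_cast<int64_t>(device.GetChannel()) &&
         origin[2] == static_cast<int64_t>(device.GetHeight()) &&
         origin[3] == static_cast<int64_t>(device.GetWidth());
}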
@@ -19,6 +19,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "HiAiModelManagerService.h"
 #include "graph/buffer.h"
 #include "graph/graph.h"
 #include "graph/model.h"
@@ -145,6 +146,8 @@ ge::TensorPtr CvtTensor(const Tensor& in_tensor,
 int CvtActMode(std::string act_type);
 
+bool CheckShape(DDim origin_dims, hiai::TensorDimension device_dims);
+
 }  // namespace npu
 }  // namespace subgraph
 }  // namespace lite
...
@@ -21,6 +21,7 @@
 #include "lite/core/op_registry.h"
 #include "lite/kernels/npu/bridges/graph.h"
 #include "lite/kernels/npu/bridges/paddle_use_bridges.h"
+#include "lite/kernels/npu/bridges/utility.h"
 
 namespace paddle {
 namespace lite {
@@ -123,9 +124,19 @@ int SubgraphEngine::BuildDeviceProgram() {
             << device_idims[i].GetHeight() << "," << device_idims[i].GetWidth()
             << "}";
     // Prepare the device input tensors
-    CHECK_EQ(origin_idims_[i].production(),
-             device_idims[i].GetNumber() * device_idims[i].GetChannel() *
-                 device_idims[i].GetHeight() * device_idims[i].GetWidth());
+    if (!subgraph::npu::CheckShape(origin_idims_[i], device_idims[i])) {
+      LOG(WARNING) << "origin and device input's dims are mismatched.";
+      for (int j = 0; j < origin_idims_[i].size(); j++) {
+        LOG(WARNING) << "origin_idims_[" << i << "][" << j
+                     << "]: " << origin_idims_[i][j];
+      }
+      LOG(WARNING) << "device_idims[" << i << "]: {"
+                   << device_idims[i].GetNumber() << ", "
+                   << device_idims[i].GetChannel() << ", "
+                   << device_idims[i].GetHeight() << ", "
+                   << device_idims[i].GetWidth() << "}";
+      return subgraph::FAILED;
+    }
     device_itensors_[i].reset(new hiai::AiTensor);
     device_itensors_[i]->Init(&(device_idims[i]));
   }
@@ -166,9 +177,19 @@ int SubgraphEngine::BuildDeviceProgram() {
         << PrecisionToStr(precision);
       break;
     }
-    CHECK_EQ(origin_odims_[i].production(),
-             device_odims[i].GetNumber() * device_odims[i].GetChannel() *
-                 device_odims[i].GetHeight() * device_odims[i].GetWidth());
+    if (!subgraph::npu::CheckShape(origin_odims_[i], device_odims[i])) {
+      LOG(WARNING) << "origin and device output's dims are mismatched.";
+      for (int j = 0; j < origin_odims_[i].size(); j++) {
+        LOG(WARNING) << "origin_odims_[" << i << "][" << j
+                     << "]: " << origin_odims_[i][j];
+      }
+      LOG(WARNING) << "device_odims[" << i << "]: {"
+                   << device_odims[i].GetNumber() << ", "
+                   << device_odims[i].GetChannel() << ", "
+                   << device_odims[i].GetHeight() << ", "
+                   << device_odims[i].GetWidth() << "}";
+      return subgraph::FAILED;
+    }
     device_otensors_[i].reset(new hiai::AiTensor);
     device_otensors_[i]->Init(&(device_odims[i]));
   }
...
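
The substantive change in BuildDeviceProgram: a shape mismatch no longer aborts the process through CHECK_EQ. Instead, both shapes are logged at WARNING level and the method returns subgraph::FAILED, presumably so the engine can reject the NPU program and fall back rather than crash. A minimal sketch of this report-and-return pattern (stand-in types and names, not the Paddle-Lite API):

#include <cstdint>
#include <cstdio>
#include <vector>

enum Status { SUCCESS = 0, FAILED = 1 };  // stand-in for subgraph::SUCCESS/FAILED

Status BuildInputsSketch(const std::vector<std::vector<int64_t>>& origin,
                         const std::vector<std::vector<int64_t>>& device) {
  for (size_t i = 0; i < origin.size(); ++i) {
    if (origin[i] != device[i]) {
      // A CHECK_EQ here would kill the whole process on mismatch; logging
      // and returning lets the caller choose a fallback path instead.
      std::fprintf(stderr, "input %zu: origin/device dims mismatched\n", i);
      return FAILED;
    }
    // ... allocate and initialize the device tensor for input i ...
  }
  return SUCCESS;
}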
@@ -155,19 +155,7 @@ class ReshapeComputeTester : public arena::TestCase {
   }
 };
 
-TEST(Reshape, precision) {
-  LOG(INFO) << "test Reshape op";
-  float abs_error = 2e-5;
-  Place place;
-#if defined(LITE_WITH_NPU)
-  place = TARGET(kNPU);
-  abs_error = 1e-2;  // Using fp16 in NPU
-#elif defined(LITE_WITH_XPU)
-  place = TARGET(kXPU);
-#else
-  return;
-#endif
+void TestReshape4D(Place place, float abs_error) {
   DDim dims{{2, 3, 4, 5}};
   std::vector<std::vector<int>> shapes{{5, 4, 3, 2},
                                        {2, 3, 20},
@@ -177,9 +165,6 @@ TEST(Reshape, precision) {
                                        {0, 0, 20},
                                        {0, 0, -1}};
   for (auto shape : shapes) {
-#ifdef LITE_WITH_NPU
-    if (dims.size() > 4 || shape.size() > 4) continue;
-#endif
     std::unique_ptr<arena::TestCase> tester(
         new ReshapeComputeTester(place, "def", dims, shape));
     arena::Arena arena(std::move(tester), place, abs_error);
@@ -187,5 +172,47 @@ TEST(Reshape, precision) {
   }
 }
 
+void TestReshape3D(Place place, float abs_error) {
+  DDim dims{{2, 3, 20}};
+  std::vector<std::vector<int>> shapes{
+      {5, 4, 3, 2}, {2, 3, 20}, {2, 60}, {120}, {2, 3, -1}, {0, 60}, {0, -1}};
+  for (auto shape : shapes) {
+    std::unique_ptr<arena::TestCase> tester(
+        new ReshapeComputeTester(place, "def", dims, shape));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision({"xshape"});
+  }
+}
+
+void TestReshape2D(Place place, float abs_error) {
+  DDim dims{{6, 20}};
+  std::vector<std::vector<int>> shapes{
+      {5, 4, 3, 2}, {2, 3, 20}, {2, 60}, {120}, {-1}};
+  for (auto shape : shapes) {
+    std::unique_ptr<arena::TestCase> tester(
+        new ReshapeComputeTester(place, "def", dims, shape));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision({"xshape"});
+  }
+}
+
+TEST(Reshape, precision) {
+  LOG(INFO) << "test Reshape op";
+  float abs_error = 2e-5;
+  Place place;
+#if defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#else
+  return;
+#endif
+  TestReshape4D(place, abs_error);
+  TestReshape3D(place, abs_error);
+  TestReshape2D(place, abs_error);
+}
+
 }  // namespace lite
 }  // namespace paddle