Unverified commit 0a75feb9, authored by huzhiqiang, committed by GitHub

[cherry-pick] Fix opt, lookup_table op and write_to_array op (#2922)

Parent: fb163a65
@@ -30,6 +30,7 @@
 #include "lite/model_parser/compatible_pb.h"
 #include "lite/model_parser/pb/program_desc.h"
 #include "lite/utils/cp_logging.h"
+#include "lite/utils/io.h"
 #include "lite/utils/string.h"
 #include "supported_kernel_op_info.h"  // NOLINT
@@ -400,6 +401,7 @@ void Main() {
     return;
   }
+  lite::MkDirRecur(FLAGS_optimize_out);
   auto model_dirs = lite::ListDir(FLAGS_model_set_dir, true);
   if (model_dirs.size() == 0) {
     LOG(FATAL) << "[" << FLAGS_model_set_dir << "] does not contain any model";
@@ -454,7 +456,9 @@ int main(int argc, char** argv) {
   }
   google::ParseCommandLineFlags(&argc, &argv, false);
   paddle::lite_api::ParseInputCommand();
-  paddle::lite_api::CheckIfModelSupported();
+  if (FLAGS_model_set_dir == "") {
+    paddle::lite_api::CheckIfModelSupported();
+  }
   paddle::lite_api::Main();
   return 0;
 }
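For context, the two opt hunks above do two things: the directory named by FLAGS_optimize_out is now created recursively before the model set is scanned, and CheckIfModelSupported() is skipped when FLAGS_model_set_dir is given, presumably because that check reads the single-model flags and does not apply to a whole set. A rough, self-contained sketch of the resulting model-set flow follows; std::filesystem stands in for lite::MkDirRecur / lite::ListDir, and ConvertModelSet is a hypothetical helper, not the tool's actual entry point.

// Rough sketch of the model-set flow after this patch (illustration only).
#include <filesystem>
#include <iostream>
#include <string>
#include <vector>

namespace fs = std::filesystem;

int ConvertModelSet(const std::string& model_set_dir,
                    const std::string& optimize_out) {
  fs::create_directories(optimize_out);  // analogue of lite::MkDirRecur
  std::vector<std::string> model_dirs;
  for (const auto& entry : fs::directory_iterator(model_set_dir)) {
    if (entry.is_directory()) {
      model_dirs.push_back(entry.path().filename().string());
    }
  }
  if (model_dirs.empty()) {
    std::cerr << "[" << model_set_dir << "] does not contain any model\n";
    return 1;
  }
  for (const auto& name : model_dirs) {
    // ... load <model_set_dir>/<name>, optimize it, save it under <optimize_out>/<name> ...
    (void)name;
  }
  return 0;
}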
@@ -67,22 +67,22 @@ void LookupTableCompute::Run() {
 REGISTER_LITE_KERNEL(lookup_table,
                      kARM,
-                     kFloat,
+                     kAny,
                      kNCHW,
                      paddle::lite::kernels::arm::LookupTableCompute,
                      def)
     .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();

 REGISTER_LITE_KERNEL(lookup_table_v2,
                      kARM,
-                     kFloat,
+                     kAny,
                      kNCHW,
                      paddle::lite::kernels::arm::LookupTableCompute,
                      def)
     .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
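The two registrations above switch the lookup_table kernels from a declared precision of kFloat, with Ids bound to kInt64, to kAny for both the kernel and the Ids input, so kernel matching no longer depends on the statically inferred precision of the index tensor. As a self-contained illustration of what a kAny binding implies, the sketch below dispatches on the runtime element type of the ids buffer; every name in it (PrecisionTag, lookup_rows, run_lookup) is hypothetical and not a Paddle Lite API.

// Illustrative sketch of a lookup kernel whose index dtype is only known at
// run time, which is what a PRECISION(kAny) binding expresses.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class PrecisionTag { kInt32, kInt64 };

// Gather rows of a [num_rows x width] table into *out for any index type.
template <typename IdT>
void lookup_rows(const std::vector<float>& table, int width, const IdT* ids,
                 int n, std::vector<float>* out) {
  out->resize(static_cast<size_t>(n) * width);
  for (int i = 0; i < n; ++i) {
    const float* src = table.data() + static_cast<int64_t>(ids[i]) * width;
    std::copy(src, src + width, out->begin() + static_cast<size_t>(i) * width);
  }
}

// Dispatch on the runtime precision tag instead of a compile-time fixed dtype.
void run_lookup(const std::vector<float>& table, int width, const void* ids,
                int n, PrecisionTag tag, std::vector<float>* out) {
  if (tag == PrecisionTag::kInt64) {
    lookup_rows(table, width, static_cast<const int64_t*>(ids), n, out);
  } else {
    lookup_rows(table, width, static_cast<const int32_t*>(ids), n, out);
  }
}

int main() {
  std::vector<float> table = {0, 0, 1, 1, 2, 2};  // 3 rows, width 2
  std::vector<int32_t> ids = {2, 0};
  std::vector<float> out;
  run_lookup(table, 2, ids.data(), static_cast<int>(ids.size()),
             PrecisionTag::kInt32, &out);
  for (float v : out) std::cout << v << ' ';  // prints: 2 2 0 0
  std::cout << '\n';
}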
@@ -21,7 +21,7 @@ namespace lite {
 namespace kernels {
 namespace arm {

-class LookupTableCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class LookupTableCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  public:
  using param_t = operators::LookupTableParam;
@@ -53,7 +53,7 @@ void lookup_table_compute_ref(const operators::LookupTableParam &param) {
 TEST(lookup_table_arm, retrieve_op) {
   auto lookup_table =
-      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>(
           "lookup_table");
   ASSERT_FALSE(lookup_table.empty());
   ASSERT_TRUE(lookup_table.front());
@@ -61,7 +61,7 @@ TEST(lookup_table_arm, retrieve_op) {
 TEST(lookup_table_arm, init) {
   LookupTableCompute lookup_table;
-  ASSERT_EQ(lookup_table.precision(), PRECISION(kFloat));
+  ASSERT_EQ(lookup_table.precision(), PRECISION(kAny));
   ASSERT_EQ(lookup_table.target(), TARGET(kARM));
 }
@@ -112,4 +112,4 @@ TEST(lookup_table_arm, compute) {
 }  // namespace lite
 }  // namespace paddle

-USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(lookup_table, kARM, kAny, kNCHW, def);
@@ -65,6 +65,6 @@ REGISTER_LITE_KERNEL(write_to_array,
                      paddle::lite::kernels::arm::WriteToArrayCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
     .Finalize();
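write_to_array's index input I, previously bound without an explicit precision, is likewise rebound to kAny. The sketch below illustrates the op's semantics (copy X into slot I of a growable tensor list) with an index that may arrive as either int32 or int64; the types and helpers here (Tensor, TensorArray, ReadIndex, WriteToArray) are stand-ins, not Paddle Lite code.

// Illustrative sketch of write_to_array with an index tensor of runtime-
// determined precision (not Paddle Lite code).
#include <cstdint>
#include <vector>

using Tensor = std::vector<float>;        // stand-in for lite::Tensor
using TensorArray = std::vector<Tensor>;  // stand-in for the "Out" tensor list

// Normalize the index to int64 whether it was stored as int32 or int64.
int64_t ReadIndex(const void* data, bool index_is_int64) {
  return index_is_int64
             ? *static_cast<const int64_t*>(data)
             : static_cast<int64_t>(*static_cast<const int32_t*>(data));
}

void WriteToArray(const Tensor& x, const void* index_data, bool index_is_int64,
                  TensorArray* out) {
  const int64_t i = ReadIndex(index_data, index_is_int64);
  if (static_cast<size_t>(i) >= out->size()) out->resize(i + 1);
  (*out)[i] = x;  // copy X into slot i, growing the array as needed
}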