提交 0a20a618 编写于 作者: S superjomn

fix ARM compile error

test=develop
上级 9174652b
...@@ -80,9 +80,10 @@ option(WITH_FAST_MATH "Make use of fast math library, might affect the precisi ...@@ -80,9 +80,10 @@ option(WITH_FAST_MATH "Make use of fast math library, might affect the precisi
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
# for lite, both server and mobile framework. # for lite, both server and mobile framework.
option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON) option(WITH_LITE "Enable lite framework" ON)
option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON) option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" ON) option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
......
...@@ -124,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co ...@@ -124,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type) shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type data_feed_proto)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......
...@@ -2,11 +2,34 @@ if (NOT WITH_LITE) ...@@ -2,11 +2,34 @@ if (NOT WITH_LITE)
return() return()
endif() endif()
message(WARNING "Enable Lite") message(WARNING "Lite enabled!")
message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}")
message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}") message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}")
message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}") message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME})
set(EXTERNAL_PROJECT_NAME "extern_lite_download_${FILENAME_EX}")
set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}")
ExternalProject_Add(
${EXTERNAL_PROJECT_NAME}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${INSTALL_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} &&
${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME}
DOWNLOAD_DIR ${INSTALL_DIR}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND ""
)
endfunction()
add_subdirectory(core) add_subdirectory(core)
add_subdirectory(x86) add_subdirectory(x86)
add_subdirectory(host) add_subdirectory(host)
......
...@@ -18,8 +18,22 @@ cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps} ${ops_lite} $ ...@@ -18,8 +18,22 @@ cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps} ${ops_lite} $
message(STATUS "get ops ${ops_lite}") message(STATUS "get ops ${ops_lite}")
message(STATUS "get kernels ${host_kernels}") message(STATUS "get kernels ${host_kernels}")
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host
${ops_lite} ${host_kernels}) include(ExternalProject)
lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite) set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host
${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
if(WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host ${ops_lite} ${host_kernels}) cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host ${ops_lite} ${host_kernels})
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", ""); DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -48,7 +49,7 @@ TEST(CXXApi, test) { ...@@ -48,7 +49,7 @@ TEST(CXXApi, test) {
data[i] = i; data[i] = i;
} }
LOG(INFO) << "input " << *input_tensor; // LOG(INFO) << "input " << *input_tensor;
predictor.Run(); predictor.Run();
...@@ -57,7 +58,7 @@ TEST(CXXApi, test) { ...@@ -57,7 +58,7 @@ TEST(CXXApi, test) {
LOG(INFO) << "out " << out->data<float>()[0]; LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1]; LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims(); LOG(INFO) << "dims " << out->dims();
LOG(INFO) << "out " << *out; // LOG(INFO) << "out " << *out;
} }
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
...@@ -67,7 +68,7 @@ TEST(CXXApi, save_model) { ...@@ -67,7 +68,7 @@ TEST(CXXApi, save_model) {
predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
valid_places); valid_places);
predictor.SaveModel("./optimized_model"); predictor.SaveModel(FLAGS_optimized_model);
} }
#endif #endif
......
...@@ -41,21 +41,20 @@ class LightPredictor { ...@@ -41,21 +41,20 @@ class LightPredictor {
void Run() { program_->Run(); } void Run() { program_->Run(); }
// Get offset-th col of feed. // Get offset-th col of feed.
TensorBase* GetInput(size_t offset) { Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed"); auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope"; CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<TensorBase>>(); auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
if (offset >= feed_list->size()) { if (offset >= feed_list->size()) {
feed_list->resize(offset + 1); feed_list->resize(offset + 1);
} }
return &feed_list->at(offset); return &feed_list->at(offset);
} }
const TensorBase* GetOutput(size_t offset) { const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope"; CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto& fetch_list = auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
*_fetch_list->GetMutable<std::vector<lite::TensorBase>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset); return &fetch_list.at(offset);
} }
......
...@@ -13,21 +13,20 @@ ...@@ -13,21 +13,20 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/light_api.h" #include "paddle/fluid/lite/api/light_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
DEFINE_string(optimized_model, "", "");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
const std::string model_dir =
"/home/chunwei/project/Paddle/cmake-build-relwithdebinfo/paddle/fluid/lite/"
"api/optimized_model";
TEST(LightAPI, load) { TEST(LightAPI, load) {
LightPredictor predictor; LightPredictor predictor;
predictor.Build(model_dir); predictor.Build(FLAGS_optimized_model);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100}); input_tensor->Resize(DDimLite(std::vector<int64_t>({100, 100})));
auto* data = input_tensor->mutable_data<float>(); auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) { for (int i = 0; i < 100 * 100; i++) {
data[i] = i; data[i] = i;
...@@ -40,13 +39,13 @@ TEST(LightAPI, load) { ...@@ -40,13 +39,13 @@ TEST(LightAPI, load) {
} // namespace paddle } // namespace paddle
USE_LITE_OP(mul); USE_LITE_OP(mul);
USE_LITE_OP(fc); // USE_LITE_OP(fc);
USE_LITE_OP(scale); // USE_LITE_OP(scale);
USE_LITE_OP(feed); USE_LITE_OP(feed);
USE_LITE_OP(fetch); USE_LITE_OP(fetch);
USE_LITE_OP(io_copy); USE_LITE_OP(io_copy);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); // USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); // USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); // USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
cc_library(lite_gtest_main SRCS lite_gtest_main.cc) cc_library(lite_gtest_main SRCS lite_gtest_main.cc)
cc_library(memory_lite SRCS memory.cc) cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite target_wrapper_host)
cc_library(target_wrapper_lite SRCS target_wrapper.cc) cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite) cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
...@@ -24,7 +24,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_p ...@@ -24,7 +24,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_p
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite) #cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite) cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
...@@ -46,5 +46,5 @@ lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrappe ...@@ -46,5 +46,5 @@ lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrappe
lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor) lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor)
lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite) lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite)
lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes) #lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite)
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
...@@ -22,9 +22,11 @@ static void* TargetMalloc(TargetType target, size_t size) { ...@@ -22,9 +22,11 @@ static void* TargetMalloc(TargetType target, size_t size) {
void* data{nullptr}; void* data{nullptr};
switch (target) { switch (target) {
case TargetType::kHost: case TargetType::kHost:
#ifdef LITE_WITH_X86
case TargetType::kX86: case TargetType::kX86:
data = TargetWrapper<TARGET(kHost)>::Malloc(size); data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break; break;
#endif
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
case TargetType::kCUDA: case TargetType::kCUDA:
data = data =
......
...@@ -22,14 +22,17 @@ endif() ...@@ -22,14 +22,17 @@ endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes) cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite mir_ssa_graph scope_lite op_lite
${ops_lite} fc_op_lite
${host_kernels} ${host_kernels}
mir_passes mir_passes
mir_pass_manager mir_pass_manager
program_fake_utils program_fake_utils
) )
set(test_variable_place_infrence_pass_DEPS set(test_variable_place_infrence_pass_DEPS
${ops_lite} mul_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
${host_kernels} ${host_kernels}
mir_passes mir_passes
mir_pass_manager mir_pass_manager
......
...@@ -35,10 +35,12 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x, ...@@ -35,10 +35,12 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
} }
TEST(SSAGraph, test) { TEST(SSAGraph, test) {
auto program = ProgramFaker(); auto program_faker = ProgramFaker();
SSAGraph graph; SSAGraph graph;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<lite::Scope>();
lite::Program program(*program_faker.program()->Proto(), scope, places);
graph.Build(program, places); graph.Build(program, places);
Visualize(&graph); Visualize(&graph);
...@@ -49,4 +51,4 @@ TEST(SSAGraph, test) { ...@@ -49,4 +51,4 @@ TEST(SSAGraph, test) {
} // namespace paddle } // namespace paddle
USE_LITE_OP(fc); USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat); USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
...@@ -38,12 +38,6 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -38,12 +38,6 @@ class VariablePlaceInferencePass : public DebugPass {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type; LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
continue; continue;
} }
// auto& arg = v->AsArgument();
// LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place description can't be determined yet, until their first
// usage by some kernel.
} }
} }
......
...@@ -44,7 +44,7 @@ TEST(variable_place_inference_pass, test) { ...@@ -44,7 +44,7 @@ TEST(variable_place_inference_pass, test) {
}, },
}); });
Program program(*desc, scope, places); Program program(*desc->Proto(), scope, places);
core::KernelPickFactor factor; core::KernelPickFactor factor;
factor.ConsiderTarget(); factor.ConsiderTarget();
......
...@@ -93,7 +93,8 @@ class KernelRegistry final { ...@@ -93,7 +93,8 @@ class KernelRegistry final {
std::move(creator)); std::move(creator));
} }
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout> template <TargetType Target, PrecisionType Precision = PRECISION(kFloat),
DataLayoutType Layout = DATALAYOUT(kNCHW)>
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) { std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
......
...@@ -25,12 +25,18 @@ namespace lite { ...@@ -25,12 +25,18 @@ namespace lite {
TEST(Optimizer, test) { TEST(Optimizer, test) {
Optimizer optimizer; Optimizer optimizer;
auto program = ProgramFaker(); auto program_faker = ProgramFaker();
program_faker.AddFeed("X", 0);
program_faker.AddFetch("X", 0);
std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}}); std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});
core::KernelPickFactor factor; core::KernelPickFactor factor;
factor.ConsiderTarget(); factor.ConsiderTarget();
auto scope = std::make_shared<lite::Scope>();
auto program_proto = *program_faker.program()->Proto();
Program program(program_proto, scope, places);
optimizer.Run(std::move(program), places, factor); optimizer.Run(std::move(program), places, factor);
auto runtime_program = optimizer.GenRuntimeProgram(); auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num statements " << runtime_program->num_instructions(); LOG(INFO) << "num statements " << runtime_program->num_instructions();
......
...@@ -70,7 +70,7 @@ struct Program { ...@@ -70,7 +70,7 @@ struct Program {
VLOG(4) << "create Op [" << op_type << "]"; VLOG(4) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type); auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type; CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(op); ops.emplace_back(std::move(op));
ops.back()->Attach(op_desc, exec_scope); ops.back()->Attach(op_desc, exec_scope);
} }
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -28,9 +29,9 @@ Program FakeProgram() { ...@@ -28,9 +29,9 @@ Program FakeProgram() {
std::string w1 = "w" + std::to_string(id); std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id); std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id); std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<TensorBase>(); auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<TensorBase>(); auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<TensorBase>(); auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
lite::OpDesc desc; lite::OpDesc desc;
desc.SetInput("Input", {x}); desc.SetInput("Input", {x});
...@@ -38,7 +39,7 @@ Program FakeProgram() { ...@@ -38,7 +39,7 @@ Program FakeProgram() {
desc.SetInput("Bias", {b1}); desc.SetInput("Bias", {b1});
desc.SetOutput("Out", {out1}); desc.SetOutput("Out", {out1});
desc.SetType("fc"); desc.SetType("fc");
desc.SetAttr<int>("in_num_col_dims", 1); desc.SetAttr("in_num_col_dims", 1);
// add to input // add to input
program.tmp_vars.push_back(w1); program.tmp_vars.push_back(w1);
...@@ -48,9 +49,9 @@ Program FakeProgram() { ...@@ -48,9 +49,9 @@ Program FakeProgram() {
fc_op->Attach(desc, program.scope.get()); fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op)); program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100}); w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
b1v->Resize({100, 1}); b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));
out1v->Resize({100, 100}); out1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
return out1; return out1;
}; };
...@@ -60,8 +61,8 @@ Program FakeProgram() { ...@@ -60,8 +61,8 @@ Program FakeProgram() {
std::string x = "x"; std::string x = "x";
program.tmp_vars.push_back(x); program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<TensorBase>(); auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
xv->Resize({100, 100}); xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
x = add_fc(i, x); x = add_fc(i, x);
...@@ -81,7 +82,7 @@ class ProgramFaker { ...@@ -81,7 +82,7 @@ class ProgramFaker {
void CreateVars(lite::Scope* scope) { void CreateVars(lite::Scope* scope) {
for (auto& var : tmp_vars_) { for (auto& var : tmp_vars_) {
auto* x = scope->Var(var); auto* x = scope->Var(var);
x->GetMutable<lite::TensorBase>(); x->GetMutable<lite::Tensor>();
} }
for (auto& x : tmp_vars_) { for (auto& x : tmp_vars_) {
......
message(STATUS "compile with lite host kernels") message(STATUS "compile with lite host kernels")
cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps}) cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps}) cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps}) cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
...@@ -16,5 +17,3 @@ set(host_kernels ...@@ -16,5 +17,3 @@ set(host_kernels
) )
set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels") set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels")
lite_cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
...@@ -23,7 +23,7 @@ namespace kernels { ...@@ -23,7 +23,7 @@ namespace kernels {
namespace host { namespace host {
TEST(fc_compute_naive, test) { TEST(fc_compute_naive, test) {
TensorBase x, w, b, out, out1; lite::Tensor x, w, b, out, out1;
const int batch_size = 2; const int batch_size = 2;
x.Resize({batch_size, 3}); x.Resize({batch_size, 3});
w.Resize({4, 3}); w.Resize({4, 3});
...@@ -79,15 +79,15 @@ TEST(fc_host, compute) { ...@@ -79,15 +79,15 @@ TEST(fc_host, compute) {
FcCompute fc; FcCompute fc;
operators::FcParam param; operators::FcParam param;
TensorBase x; lite::Tensor x;
TensorBase w; lite::Tensor w;
TensorBase bias; lite::Tensor bias;
TensorBase output; lite::Tensor output;
x.Resize({1, 10, 20}); x.Resize(DDim(std::vector<int64_t>({1, 10, 20})));
w.Resize({20, 20}); w.Resize(DDim(std::vector<int64_t>({20, 20})));
bias.Resize({1, 10}); bias.Resize(DDim(std::vector<int64_t>({1, 10})));
output.Resize({10, 20}); output.Resize(DDim(std::vector<int64_t>({10, 20})));
auto* x_data = x.mutable_data<float>(); auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>(); auto* w_data = w.mutable_data<float>();
...@@ -119,7 +119,7 @@ TEST(fc_host, compute) { ...@@ -119,7 +119,7 @@ TEST(fc_host, compute) {
TEST(fc, retrive_op) { TEST(fc, retrive_op) {
auto fc = auto fc =
KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>("fc"); KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>("fc");
ASSERT_TRUE(fc.get()); ASSERT_TRUE(fc);
} }
} // namespace host } // namespace host
...@@ -127,4 +127,4 @@ TEST(fc, retrive_op) { ...@@ -127,4 +127,4 @@ TEST(fc, retrive_op) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat); USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
cc_library(runtime_lite SRCS runtime.cc) #cc_library(runtime_lite SRCS runtime.cc)
lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc
DEPS model_parser_lite framework_proto_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model)
if(WITH_TESTING)
add_dependencies(test_model_parser_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
else() else()
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
* lite::pb::XXDesc. * lite::pb::XXDesc.
*/ */
#include "paddle/fluid/framework/framework.pb.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h" #include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else #else
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/variable.h" #include "paddle/fluid/lite/core/variable.h"
......
...@@ -13,22 +13,25 @@ ...@@ -13,22 +13,25 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/model_parser/model_parser.h" #include "paddle/fluid/lite/model_parser/model_parser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/scope.h"
DEFINE_string(model_dir, "", "");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
TEST(ModelParser, LoadProgram) { TEST(ModelParser, LoadProgram) {
auto program = LoadProgram( CHECK(!FLAGS_model_dir.empty());
"/home/chunwei/project2/models/fc/fluid_checkpoint/__model__"); auto program = LoadProgram(FLAGS_model_dir + "/__model__");
} }
TEST(ModelParser, LoadParam) { TEST(ModelParser, LoadParam) {
Scope scope; Scope scope;
auto* v = scope.Var("xxx"); auto* v = scope.Var("xxx");
LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v); LoadParam(FLAGS_model_dir + "/fc_0.b_0", v);
const auto& t = v->Get<TensorBase>(); const auto& t = v->Get<Tensor>();
LOG(INFO) << "loaded\n"; LOG(INFO) << "loaded\n";
LOG(INFO) << t; LOG(INFO) << t;
} }
...@@ -36,7 +39,7 @@ TEST(ModelParser, LoadParam) { ...@@ -36,7 +39,7 @@ TEST(ModelParser, LoadParam) {
TEST(ModelParser, LoadModel) { TEST(ModelParser, LoadModel) {
Scope scope; Scope scope;
framework::proto::ProgramDesc prog; framework::proto::ProgramDesc prog;
LoadModel("/home/chunwei/project2/models/fc/fluid_checkpoint", &scope, &prog); LoadModel(FLAGS_model_dir, &scope, &prog);
} }
} // namespace lite } // namespace lite
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/cp_logging.h" #include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle { namespace paddle {
......
...@@ -59,9 +59,6 @@ class FcOpLite : public OpLite { ...@@ -59,9 +59,6 @@ class FcOpLite : public OpLite {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims")); param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
return true; return true;
} }
......
...@@ -21,17 +21,16 @@ namespace lite { ...@@ -21,17 +21,16 @@ namespace lite {
namespace operators { namespace operators {
TEST(fc_op_lite, test) { TEST(fc_op_lite, test) {
LOG(INFO) << "\n" << KernelRegistry::Global().DebugString();
// prepare variables // prepare variables
Scope scope; Scope scope;
auto* x = scope.Var("x")->GetMutable<TensorBase>(); auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* w = scope.Var("w")->GetMutable<TensorBase>(); auto* w = scope.Var("w")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<TensorBase>(); auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<TensorBase>(); auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize({1, 10, 20}); x->Resize(DDim(std::vector<int64_t>({1, 10, 20})));
w->Resize({20, 20}); w->Resize(DDim(std::vector<int64_t>{20, 20}));
bias->Resize({1, 10}); bias->Resize(DDim(std::vector<int64_t>{1, 10}));
output->Resize({10, 20}); output->Resize(DDim(std::vector<int64_t>{10, 20}));
// set data // set data
for (int i = 0; i < 10 * 20; i++) { for (int i = 0; i < 10 * 20; i++) {
...@@ -59,18 +58,13 @@ TEST(fc_op_lite, test) { ...@@ -59,18 +58,13 @@ TEST(fc_op_lite, test) {
FcOpLite fc("fc"); FcOpLite fc("fc");
fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); fc.Attach(desc, &scope);
auto kernels = fc.CreateKernels({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.AttachImpl(desc, &scope); ASSERT_FALSE(kernels.empty());
fc.Run();
for (int i = 0; i < 10 * 20; i++) {
LOG(INFO) << output->data<float>()[i];
}
} }
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat); USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(utils_DEPS) set(utils_DEPS)
lite_cc_test(test_logging_lite SRCS logging_test.cc)
else() else()
set(utils_DEPS glog) set(utils_DEPS glog)
endif() endif()
lite_cc_test(test_logging_lite SRCS logging_test.cc)
lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite) lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite)
cc_library(any_lite SRCS any.cc) cc_library(any_lite SRCS any.cc)
cc_library(utils_lite SRCS cp_logging.cc DEPS ${utils_DEPS} any_lite) cc_library(utils_lite SRCS cp_logging.cc DEPS ${utils_DEPS} any_lite)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册