Commit 0a20a618 authored by superjomn

fix ARM compile error

test=develop
Parent 9174652b
......@@ -80,9 +80,10 @@ option(WITH_FAST_MATH "Make use of fast math library, might affect the precisi
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON)
# for lite, both server and mobile framework.
option(LITE_WITH_CUDA "Enable CUDA in lite mode" ON)
option(WITH_LITE "Enable lite framework" ON)
option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF)
option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" ON)
option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
......
......@@ -124,7 +124,7 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type data_feed_proto)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......
......@@ -2,11 +2,34 @@ if (NOT WITH_LITE)
return()
endif()
message(WARNING "Enable Lite")
message(WARNING "Lite enabled!")
message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}")
message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}")
message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME})
set(EXTERNAL_PROJECT_NAME "extern_lite_download_${FILENAME_EX}")
set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}")
ExternalProject_Add(
${EXTERNAL_PROJECT_NAME}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${INSTALL_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} &&
${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME}
DOWNLOAD_DIR ${INSTALL_DIR}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND ""
)
endfunction()
add_subdirectory(core)
add_subdirectory(x86)
add_subdirectory(host)
......
......@@ -18,8 +18,22 @@ cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps} ${ops_lite} $
message(STATUS "get ops ${ops_lite}")
message(STATUS "get kernels ${host_kernels}")
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host
${ops_lite} ${host_kernels})
lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite)
include(ExternalProject)
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host
${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
if(WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host ${ops_lite} ${host_kernels})
......@@ -19,6 +19,7 @@
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
......@@ -48,7 +49,7 @@ TEST(CXXApi, test) {
data[i] = i;
}
LOG(INFO) << "input " << *input_tensor;
// LOG(INFO) << "input " << *input_tensor;
predictor.Run();
......@@ -57,7 +58,7 @@ TEST(CXXApi, test) {
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
LOG(INFO) << "out " << *out;
// LOG(INFO) << "out " << *out;
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
......@@ -67,7 +68,7 @@ TEST(CXXApi, save_model) {
predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
valid_places);
predictor.SaveModel("./optimized_model");
predictor.SaveModel(FLAGS_optimized_model);
}
#endif
......
......@@ -41,21 +41,20 @@ class LightPredictor {
void Run() { program_->Run(); }
// Get offset-th col of feed.
TensorBase* GetInput(size_t offset) {
Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<TensorBase>>();
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
if (offset >= feed_list->size()) {
feed_list->resize(offset + 1);
}
return &feed_list->at(offset);
}
const TensorBase* GetOutput(size_t offset) {
const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto& fetch_list =
*_fetch_list->GetMutable<std::vector<lite::TensorBase>>();
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
......
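With GetInput/GetOutput retyped above from TensorBase to lite::Tensor, a caller drives the light predictor roughly as follows. This is a minimal sketch assembled from the tests in this diff; the function name and shape are illustrative, and DDimLite construction is assumed to match light_api_test.cc below.

```cpp
#include <string>
#include <vector>
#include "paddle/fluid/lite/api/light_api.h"

namespace paddle {
namespace lite {

// Hypothetical caller; model path and tensor shape are illustrative only.
void RunOptimizedModel(const std::string& optimized_model_dir) {
  LightPredictor predictor;
  predictor.Build(optimized_model_dir);

  Tensor* input = predictor.GetInput(0);                      // now lite::Tensor*, not TensorBase*
  input->Resize(DDimLite(std::vector<int64_t>({100, 100})));  // as in light_api_test.cc
  float* data = input->mutable_data<float>();
  for (int i = 0; i < 100 * 100; i++) data[i] = i;

  predictor.Run();

  const Tensor* out = predictor.GetOutput(0);                 // now const lite::Tensor*
  LOG(INFO) << "out[0] = " << out->data<float>()[0];
}

}  // namespace lite
}  // namespace paddle
```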
......@@ -13,21 +13,20 @@
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
const std::string model_dir =
"/home/chunwei/project/Paddle/cmake-build-relwithdebinfo/paddle/fluid/lite/"
"api/optimized_model";
TEST(LightAPI, load) {
LightPredictor predictor;
predictor.Build(model_dir);
predictor.Build(FLAGS_optimized_model);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
input_tensor->Resize(DDimLite(std::vector<int64_t>({100, 100})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
......@@ -40,13 +39,13 @@ TEST(LightAPI, load) {
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(fc);
USE_LITE_OP(scale);
// USE_LITE_OP(fc);
// USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
// USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
// USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
// USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
cc_library(lite_gtest_main SRCS lite_gtest_main.cc)
cc_library(memory_lite SRCS memory.cc)
cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite target_wrapper_host)
cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
......@@ -24,7 +24,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_p
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
#cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
......@@ -46,5 +46,5 @@ lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrappe
lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor)
lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite)
lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
#lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite)
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
......@@ -22,9 +22,11 @@ static void* TargetMalloc(TargetType target, size_t size) {
void* data{nullptr};
switch (target) {
case TargetType::kHost:
#ifdef LITE_WITH_X86
case TargetType::kX86:
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
#endif
#ifdef LITE_WITH_CUDA
case TargetType::kCUDA:
data =
......
......@@ -22,14 +22,17 @@ endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
${ops_lite}
fc_op_lite
${host_kernels}
mir_passes
mir_pass_manager
program_fake_utils
)
set(test_variable_place_infrence_pass_DEPS
${ops_lite}
mul_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
${host_kernels}
mir_passes
mir_pass_manager
......
......@@ -35,10 +35,12 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
}
TEST(SSAGraph, test) {
auto program = ProgramFaker();
auto program_faker = ProgramFaker();
SSAGraph graph;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<lite::Scope>();
lite::Program program(*program_faker.program()->Proto(), scope, places);
graph.Build(program, places);
Visualize(&graph);
......@@ -49,4 +51,4 @@ TEST(SSAGraph, test) {
} // namespace paddle
USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
......@@ -38,12 +38,6 @@ class VariablePlaceInferencePass : public DebugPass {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
continue;
}
// auto& arg = v->AsArgument();
// LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
......
......@@ -44,7 +44,7 @@ TEST(variable_place_inference_pass, test) {
},
});
Program program(*desc, scope, places);
Program program(*desc->Proto(), scope, places);
core::KernelPickFactor factor;
factor.ConsiderTarget();
......
......@@ -93,7 +93,8 @@ class KernelRegistry final {
std::move(creator));
}
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
template <TargetType Target, PrecisionType Precision = PRECISION(kFloat),
DataLayoutType Layout = DATALAYOUT(kNCHW)>
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
......
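With the defaulted Precision and Layout template parameters added above, call sites may omit them; a short usage sketch (inside namespace paddle::lite, using the registry singleton as in fc_compute_test.cc in this diff) showing that the shortened call resolves to the same registry as the explicit form:

```cpp
// Both calls look up KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)>.
std::list<std::unique_ptr<KernelBase>> with_defaults =
    KernelRegistry::Global().Create<TARGET(kHost)>("fc");
std::list<std::unique_ptr<KernelBase>> explicit_form =
    KernelRegistry::Global()
        .Create<TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)>("fc");
```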
......@@ -25,12 +25,18 @@ namespace lite {
TEST(Optimizer, test) {
Optimizer optimizer;
auto program = ProgramFaker();
auto program_faker = ProgramFaker();
program_faker.AddFeed("X", 0);
program_faker.AddFetch("X", 0);
std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});
core::KernelPickFactor factor;
factor.ConsiderTarget();
auto scope = std::make_shared<lite::Scope>();
auto program_proto = *program_faker.program()->Proto();
Program program(program_proto, scope, places);
optimizer.Run(std::move(program), places, factor);
auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num statements " << runtime_program->num_instructions();
......
......@@ -70,7 +70,7 @@ struct Program {
VLOG(4) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(op);
ops.emplace_back(std::move(op));
ops.back()->Attach(op_desc, exec_scope);
}
}
......
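The std::move added in the hunk above matters because the op handle is a move-only smart pointer; a standalone illustration follows (std::unique_ptr is assumed here to stand in for the handle type returned by LiteOpRegistry::Create):

```cpp
#include <memory>
#include <utility>
#include <vector>

struct OpLite {};  // stand-in for the real op type

int main() {
  std::vector<std::unique_ptr<OpLite>> ops;
  std::unique_ptr<OpLite> op(new OpLite);
  // ops.emplace_back(op);          // ill-formed: unique_ptr cannot be copied
  ops.emplace_back(std::move(op));  // OK: ownership is transferred into the vector
  return 0;
}
```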
......@@ -14,6 +14,7 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -28,9 +29,9 @@ Program FakeProgram() {
std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<TensorBase>();
auto b1v = program.scope->Var(b1)->GetMutable<TensorBase>();
auto out1v = program.scope->Var(out1)->GetMutable<TensorBase>();
auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
lite::OpDesc desc;
desc.SetInput("Input", {x});
......@@ -38,7 +39,7 @@ Program FakeProgram() {
desc.SetInput("Bias", {b1});
desc.SetOutput("Out", {out1});
desc.SetType("fc");
desc.SetAttr<int>("in_num_col_dims", 1);
desc.SetAttr("in_num_col_dims", 1);
// add to input
program.tmp_vars.push_back(w1);
......@@ -48,9 +49,9 @@ Program FakeProgram() {
fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100});
b1v->Resize({100, 1});
out1v->Resize({100, 100});
w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));
out1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
return out1;
};
......@@ -60,8 +61,8 @@ Program FakeProgram() {
std::string x = "x";
program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<TensorBase>();
xv->Resize({100, 100});
auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
for (int i = 0; i < 3; i++) {
x = add_fc(i, x);
......@@ -81,7 +82,7 @@ class ProgramFaker {
void CreateVars(lite::Scope* scope) {
for (auto& var : tmp_vars_) {
auto* x = scope->Var(var);
x->GetMutable<lite::TensorBase>();
x->GetMutable<lite::Tensor>();
}
for (auto& x : tmp_vars_) {
......
message(STATUS "compile with lite host kernels")
cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps})
cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps})
cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
......@@ -16,5 +17,3 @@ set(host_kernels
)
set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels")
lite_cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
......@@ -23,7 +23,7 @@ namespace kernels {
namespace host {
TEST(fc_compute_naive, test) {
TensorBase x, w, b, out, out1;
lite::Tensor x, w, b, out, out1;
const int batch_size = 2;
x.Resize({batch_size, 3});
w.Resize({4, 3});
......@@ -79,15 +79,15 @@ TEST(fc_host, compute) {
FcCompute fc;
operators::FcParam param;
TensorBase x;
TensorBase w;
TensorBase bias;
TensorBase output;
lite::Tensor x;
lite::Tensor w;
lite::Tensor bias;
lite::Tensor output;
x.Resize({1, 10, 20});
w.Resize({20, 20});
bias.Resize({1, 10});
output.Resize({10, 20});
x.Resize(DDim(std::vector<int64_t>({1, 10, 20})));
w.Resize(DDim(std::vector<int64_t>({20, 20})));
bias.Resize(DDim(std::vector<int64_t>({1, 10})));
output.Resize(DDim(std::vector<int64_t>({10, 20})));
auto* x_data = x.mutable_data<float>();
auto* w_data = w.mutable_data<float>();
......@@ -119,7 +119,7 @@ TEST(fc_host, compute) {
TEST(fc, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>("fc");
ASSERT_TRUE(fc.get());
ASSERT_TRUE(fc);
}
} // namespace host
......@@ -127,4 +127,4 @@ TEST(fc, retrive_op) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
cc_library(runtime_lite SRCS runtime.cc)
lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
#cc_library(runtime_lite SRCS runtime.cc)
lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc
DEPS model_parser_lite framework_proto_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model)
if(WITH_TESTING)
add_dependencies(test_model_parser_lite extern_lite_download_lite_naive_model_tar_gz)
endif(WITH_TESTING)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
else()
......
......@@ -20,8 +20,8 @@
* lite::pb::XXDesc.
*/
#include "paddle/fluid/framework/framework.pb.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
......
......@@ -18,7 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/variable.h"
......
......@@ -13,22 +13,25 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/scope.h"
DEFINE_string(model_dir, "", "");
namespace paddle {
namespace lite {
TEST(ModelParser, LoadProgram) {
auto program = LoadProgram(
"/home/chunwei/project2/models/fc/fluid_checkpoint/__model__");
CHECK(!FLAGS_model_dir.empty());
auto program = LoadProgram(FLAGS_model_dir + "/__model__");
}
TEST(ModelParser, LoadParam) {
Scope scope;
auto* v = scope.Var("xxx");
LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v);
const auto& t = v->Get<TensorBase>();
LoadParam(FLAGS_model_dir + "/fc_0.b_0", v);
const auto& t = v->Get<Tensor>();
LOG(INFO) << "loaded\n";
LOG(INFO) << t;
}
......@@ -36,7 +39,7 @@ TEST(ModelParser, LoadParam) {
TEST(ModelParser, LoadModel) {
Scope scope;
framework::proto::ProgramDesc prog;
LoadModel("/home/chunwei/project2/models/fc/fluid_checkpoint", &scope, &prog);
LoadModel(FLAGS_model_dir, &scope, &prog);
}
} // namespace lite
......
......@@ -24,7 +24,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
......
......@@ -17,7 +17,7 @@
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/core/framework.pb.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
......
......@@ -59,9 +59,6 @@ class FcOpLite : public OpLite {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
return true;
}
......
......@@ -21,17 +21,16 @@ namespace lite {
namespace operators {
TEST(fc_op_lite, test) {
LOG(INFO) << "\n" << KernelRegistry::Global().DebugString();
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<TensorBase>();
auto* w = scope.Var("w")->GetMutable<TensorBase>();
auto* bias = scope.Var("bias")->GetMutable<TensorBase>();
auto* output = scope.Var("output")->GetMutable<TensorBase>();
x->Resize({1, 10, 20});
w->Resize({20, 20});
bias->Resize({1, 10});
output->Resize({10, 20});
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* w = scope.Var("w")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({1, 10, 20})));
w->Resize(DDim(std::vector<int64_t>{20, 20}));
bias->Resize(DDim(std::vector<int64_t>{1, 10}));
output->Resize(DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
......@@ -59,18 +58,13 @@ TEST(fc_op_lite, test) {
FcOpLite fc("fc");
fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.AttachImpl(desc, &scope);
fc.Run();
for (int i = 0; i < 10 * 20; i++) {
LOG(INFO) << output->data<float>()[i];
}
fc.Attach(desc, &scope);
auto kernels = fc.CreateKernels({Place{TARGET(kHost), PRECISION(kFloat)}});
ASSERT_FALSE(kernels.empty());
}
} // namespace operators
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(utils_DEPS)
lite_cc_test(test_logging_lite SRCS logging_test.cc)
else()
set(utils_DEPS glog)
endif()
lite_cc_test(test_logging_lite SRCS logging_test.cc)
lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite)
cc_library(any_lite SRCS any.cc)
cc_library(utils_lite SRCS cp_logging.cc DEPS ${utils_DEPS} any_lite)