提交 c353397d 编写于 作者: T tensor-tang

fix and merge from github b8572aa3

上级 d1904d11
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......
......@@ -34,7 +34,7 @@ endfunction()
function (lite_deps TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS ARGS)
set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS ARGS)
cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps ${lite_deps_DEPS})
......@@ -63,14 +63,39 @@ function (lite_deps TARGET)
endforeach(var)
endif()
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
foreach(var ${lite_deps_LIGHT_DEPS})
set(deps ${deps} ${var})
endforeach(var)
endif()
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
foreach(var ${lite_deps_HVY_DEPS})
set(deps ${deps} ${var})
endforeach(var)
endif()
set(${TARGET} ${deps} PARENT_SCOPE)
endfunction()
# Add names for lite libraries for latter compile. We use this name list to avoid compiling
# the whole fluid project to accelerate the compile speed.
set(offline_lib_registry_file "${CMAKE_BINARY_DIR}/lite_libs.txt")
file(WRITE ${offline_lib_registry_file} "") # clean
# cc_library with branch support.
# The branches:
# X86_DEPS: works only when LITE_WITH_X86 is ON.
# CUDA_DEPS: LITE_WITH_CUDA
# ARM_DEPS: LITE_WITH_ARM
# PROFILE_DEPS: LITE_WITH_PROFILE
# LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
# HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
function(lite_cc_library TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS ARGS)
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS
HVY_DEPS ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps "")
......@@ -79,15 +104,22 @@ function(lite_cc_library TARGET)
X86_DEPS ${args_X86_DEPS}
CUDA_DEPS ${args_CUDA_DEPS}
ARM_DEPS ${args_ARM_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS})
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
)
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
# register a library name.
file(APPEND ${offline_lib_registry_file} "${TARGET}\n")
endfunction()
function(lite_cc_binary TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS ARGS)
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps "")
......@@ -97,6 +129,8 @@ function(lite_cc_binary TARGET)
CUDA_DEPS ${args_CUDA_DEPS}
ARM_DEPS ${args_ARM_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
)
cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
endfunction()
......@@ -104,15 +138,13 @@ endfunction()
# Add a unit-test name to file for latter offline manual test.
set(offline_test_registry_file "${CMAKE_BINARY_DIR}/lite_tests.txt")
file(WRITE ${offline_test_registry_file} "") # clean
function (register_test_offline TARGET)
file(APPEND ${offline_test_registry_file} "${TARGET}\n")
endfunction()
# Test lite modules.
function(lite_cc_test TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS ARGS)
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS
LIGHT_DEPS HVY_DEPS
ARGS)
cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(deps "")
......@@ -122,9 +154,11 @@ function(lite_cc_test TARGET)
CUDA_DEPS ${args_CUDA_DEPS}
ARM_DEPS ${args_ARM_DEPS}
PROFILE_DEPS ${args_PROFILE_DEPS}
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
)
_lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS})
register_test_offline("${TARGET}")
file(APPEND ${offline_test_registry_file} "${TARGET}\n")
endfunction()
add_subdirectory(core)
......@@ -137,4 +171,4 @@ add_subdirectory(kernels)
add_subdirectory(model_parser)
add_subdirectory(utils)
add_subdirectory(api)
add_subdirectory(gen_code)
......@@ -76,6 +76,7 @@ TEST(CXXApi, save_model) {
predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
......@@ -130,6 +131,9 @@ USE_LITE_OP(square)
USE_LITE_OP(softmax)
USE_LITE_OP(dropout)
USE_LITE_OP(concat)
USE_LITE_OP(conv2d)
USE_LITE_OP(depthwise_conv2d)
USE_LITE_OP(pool2d)
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
......@@ -144,6 +148,9 @@ USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
......
......@@ -64,7 +64,7 @@ class LightPredictor {
private:
void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) {
std::vector<Instruct> insts;
std::vector<Instruction> insts;
// 1. Create op first
Program program(prog, scope_, {});
......@@ -72,7 +72,7 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific
// kernel with the target alias.
for (auto& op : program.ops) {
for (auto& op : program.ops_) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias;
......@@ -89,8 +89,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope);
program_->set_exec_scope(program.exec_scope);
CHECK(program.exec_scope_);
program_->set_exec_scope(program.exec_scope_);
}
private:
......
if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return()
endif()
if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return()
......
......@@ -8,7 +8,7 @@ lite_cc_library(target_wrapper_lite SRCS target_wrapper.cc
lite_cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite)
lite_cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor)
lite_cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor HVY_DEPS framework_proto)
endif()
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......@@ -19,19 +19,18 @@ endif()
proto_library(framework_proto_lite SRCS framework.proto)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_lite op_params_lite framework_proto_lite)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_lite op_params_lite framework_proto_lite ${tensor_lite})
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite)
cc_library(scope_lite SRCS scope.cc DEPS ${tensor_lite})
cc_library(cpu_info_lite SRCS cpu_info.cc)
cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite
cpp_op_desc_lite
${tensor_lite})
cpp_op_desc_lite ${tensor_lite})
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
add_subdirectory(mir)
......@@ -57,4 +56,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
......@@ -173,6 +173,11 @@ class Context<TargetType::kX86> {
new ::paddle::framework::ExecutionContext(*x86_device_context_));
}
Context(Context&& ctx) {
x86_device_context_ = std::move(ctx.x86_device_context_);
x86_execution_context_ = std::move(ctx.x86_execution_context_);
}
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {}
......
......@@ -21,6 +21,7 @@
#pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/tensor.h"
namespace paddle {
......@@ -65,6 +66,14 @@ class TensorHvy : public TensorBase<TensorHvy> {
using DDimT = DDimHvy;
using LoDT = framework::LoD;
template <typename DType, typename DimT, TargetType Target>
void Assign(DType* data, const DimT& dim) {
Resize(dim);
auto* dst = mutable_data<DType>(Target);
CopySync<Target>(dst, data, dim.production() * sizeof(DType),
IoDirection::HtoD);
}
TargetType target() const {
if (platform::is_gpu_place(data_.place())) {
return TARGET(kCUDA);
......@@ -95,13 +104,15 @@ class TensorHvy : public TensorBase<TensorHvy> {
const void* raw_data() const { return data_.raw_data(); }
void Resize(const DDimHvy& dims) {
LOG(INFO) << "dims.size " << dims.size();
data_.Resize(framework::make_ddim(dims.Vectorize()));
}
void ShareDataWith(const TensorHvy& other) {
data_.ShareDataWith(other.data_);
}
void ShareDataWith(const framework::Tensor& other) {
data_.ShareDataWith(other);
}
void CopyDataFrom(const TensorHvy& other) {
data_.mutable_data(other.data_.place(), other.data_.type());
TensorCopySync(other.data_, data_.place(), &data_);
......
......@@ -150,7 +150,7 @@ class KernelBase {
void Torch() {}
protected:
std::unique_ptr<KernelContext> ctx_;
std::unique_ptr<KernelContext> ctx_{nullptr};
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_{};
......
......@@ -61,6 +61,14 @@ class TensorLite : public TensorBase<TensorLite> {
TensorLite() : buffer_(std::make_shared<Buffer>()) {}
template <typename DType, typename DimT, TargetType Target>
void Assign(DType *data, const DimT &dim) {
Resize(dim);
auto *dst = mutable_data<DType>(Target);
CopySync<Target>(dst, data, dim.product() * sizeof(DType),
IoDirection::HtoD);
}
template <typename T>
const T *data() const {
return static_cast<const T *>(buffer_->data());
......
......@@ -28,28 +28,34 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_pass_manager
program_fake_utils
)
set(test_variable_place_infrence_pass_DEPS
mul_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
${host_kernels}
mir_passes
mir_pass_manager
optimizer_lite
program_fake_utils
target_wrapper_host
)
if (LITE_WITH_CUDA)
set(test_variable_place_infrence_pass_DEPS
${test_variable_place_infrence_pass_DEPS} target_wrapper_cuda
kernels_cuda
)
endif()
cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS
${test_variable_place_infrence_pass_DEPS})
# lite_cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc
# DEPS
# mul_op_lite
# feed_op_lite
# fetch_op_lite
# io_copy_op_lite
# ${host_kernels}
# mir_passes
# mir_pass_manager
# optimizer_lite
# program_fake_utils
# target_wrapper_host
# PROFILE_DEPS basic_profiler_lite
# CUDA_DEPS target_wrapper_cuda kernels_cuda
# ARM_DEPS mul_compute_arm
# X86_DEPS mul_compute_x86
# )
cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite)
cc_test(test_pattern_matcher_lite SRCS pattern_matcher_tester.cc DEPS pattern_matcher_lite)
lite_cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite)
lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern_matcher_lite)
lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
# TODO(wz) replace framework/proto to lite proto.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite})
endif()
......@@ -41,7 +41,7 @@ class GenerateProgramPass : public ProgramPass {
}
private:
std::vector<Instruct> insts_;
std::vector<Instruction> insts_;
};
} // namespace mir
......
......@@ -16,10 +16,6 @@
namespace paddle {
namespace lite {
namespace mir {
PassManager::PassManager() {}
} // namespace mir
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -30,7 +30,7 @@ class PassManager {
return x;
}
PassManager();
PassManager() {}
void Run(const std::unique_ptr<SSAGraph>& graph) {
for (auto& pass : passes_) {
......
......@@ -27,6 +27,30 @@ namespace mir {
size_t PMPattern::id_ = 0UL;
PMNode &PMNode::operator>>(PMNode &right) {
pattern_->AddEdge(this, &right);
// automatically add out op link relation.
if (right.IsOp()) {
CHECK(!right.op_type_.empty());
this->assert_is_op_input(right.op_type_);
}
return right;
}
PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
for (auto *node : nodes) {
*this >> *node;
}
return *this;
}
void operator>>(std::vector<PMNode *> &others, PMNode &me) {
for (auto *o : others) {
*o >> me;
}
}
PMNode *PMPattern::NewNode(const std::string &name) {
if (!name.empty()) {
CHECK_EQ(node_map_.count(name), 0UL)
......@@ -122,9 +146,7 @@ void PatternMatcher::ValidateByNodeRole(
// Collect the inlinks and outlinks.
std::unordered_set<Node *> ios;
for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
ios.insert(item.second);
}
for (auto &item : subgraph) {
if (item.first->IsIntermediate()) {
......@@ -400,6 +422,30 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return this;
}
void GraphSafeRemoveNodes(SSAGraph *graph,
const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) {
graph->RemoveNode(node);
}
for (auto &node : graph->mutable_nodes()) {
for (auto it = node.inlinks.begin(); it != node.inlinks.end();) {
if (nodes.count(*it)) {
it = node.inlinks.erase(it);
} else {
it++;
}
}
for (auto it = node.outlinks.begin(); it != node.outlinks.end();) {
if (nodes.count(*it)) {
it = node.outlinks.erase(it);
} else {
it++;
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -58,6 +58,15 @@ struct PMNode {
PMNode& LinksTo(const std::vector<PMNode*>& others);
PMNode& LinksFrom(const std::vector<PMNode*>& others);
// Link this to another node.
PMNode& operator>>(PMNode& right);
// Link many nodes to this node.
friend void operator>>(std::vector<PMNode*>& others, PMNode& me);
// Link this to many other nodes.
PMNode& operator>>(std::vector<PMNode*>& nodes);
bool Tell(const Node* node) const {
if (teller_) return teller_(node);
......@@ -92,6 +101,20 @@ struct PMNode {
return this;
}
PMNode* AsVar() {
type_ = Type::kVar;
assert_is_var();
return this;
}
PMNode* AsOp(const std::string& op_type) {
type_ = Type::kOp;
assert_is_op(op_type);
return this;
}
void set_op_type(const std::string& op_type) { op_type_ = op_type; }
bool IsIntermediate() const { return role_ == Role::kIntermediate; }
bool IsInput() const { return role_ == Role::kInput; }
bool IsOutput() const { return role_ == Role::kOutput; }
......@@ -141,6 +164,7 @@ struct PMNode {
std::vector<teller_t> asserts_;
PMPattern* pattern_;
std::string name_;
std::string op_type_;
Type type_;
Role role_{Role::kUnknown};
};
......@@ -273,6 +297,10 @@ class PatternMatcher {
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
};
// Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(SSAGraph* graph,
const std::unordered_set<const Node*>& nodes);
// Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better re-usage
// across different fusion.
......
......@@ -94,7 +94,7 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
}
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars) {
for (const auto &name : program.tmp_vars()) {
CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
......@@ -107,7 +107,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights) {
for (const auto &name : program.weights()) {
CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
......@@ -119,8 +119,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
}
Node *SSAGraph::GraphCreateInstructNode(
const Program &program, const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places) {
const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
......@@ -140,8 +139,8 @@ void SSAGraph::Build(const Program &program,
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
for (auto &op : program.ops) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places);
for (auto &op : program.ops()) {
auto *op_node = GraphCreateInstructNode(op, valid_places);
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
CHECK(arg->IsRoleSet());
......@@ -162,6 +161,13 @@ void SSAGraph::Build(const Program &program,
CheckValid();
}
void SSAGraph::RemoveNode(const mir::Node *node) {
auto pos = std::find_if(node_storage_.begin(), node_storage_.end(),
[&node](mir::Node &n) { return &n == node; });
CHECK(pos != node_storage_.end());
node_storage_.erase(pos);
}
mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
......
......@@ -38,6 +38,7 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places);
void RemoveNode(const mir::Node *node);
mir::Node *Argument(const std::string &name);
......@@ -63,12 +64,12 @@ class SSAGraph : GraphBase {
CHECK(CheckLinksRoleSet());
}
Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
private:
void GraphCreateTmpVarNodes(const Program &program);
void GraphCreateWeightVarNodes(const Program &program);
Node *GraphCreateInstructNode(const Program &program,
const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Check the bidirectional connection.
bool CheckBidirectionalConnection();
......@@ -77,7 +78,7 @@ class SSAGraph : GraphBase {
bool CheckLinksRoleSet();
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
for (const auto &name : program.weights()) {
arguments_[name]->AsArg().is_weight = true;
}
}
......
......@@ -37,6 +37,8 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.valid_kernels.empty()) << "No kernels found for "
<< instruct.op_type;
for (auto&& kernel : instruct.valid_kernels) {
size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel));
......
......@@ -42,6 +42,12 @@ TEST(variable_place_inference_pass, test) {
Place{
TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kX86), PRECISION(kAny), DATALAYOUT(kAny),
},
});
Program program(*desc->Proto(), scope, places);
......@@ -58,7 +64,15 @@ TEST(variable_place_inference_pass, test) {
});
Place prefered_place{
#ifdef PADDLE_WITH_CUDA
TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
#else
#ifdef PADDLE_WITH_ARM
TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW),
#else // X86
TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW),
#endif // ARM
#endif
};
optimizer.KernelPickPreferPlace(prefered_place);
optimizer.Run(std::move(program), places, factor, passes);
......@@ -72,3 +86,16 @@ USE_LITE_OP(mul);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
......@@ -28,15 +28,23 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
CHECK(!op_type_.empty()) << "op_type_ should be set first";
auto pick_kernel = [&](const Place &place) {
auto ks = KernelRegistry::Global().Create(
(kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision, place.layout);
auto ks = KernelRegistry::Global().Create(op_type_, place.target,
place.precision, place.layout);
for (auto &&it : ks) {
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
}
};
if (!kernel_type.empty()) {
Place place;
std::string op_type, alias;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
pick_kernel(place);
CHECK(!kernels.empty()) << "no kernel for kernel type " << kernel_type;
return kernels;
}
std::set<Place> place_set;
for (auto place : places) {
place_set.insert(place);
......@@ -53,7 +61,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets.insert(place.target);
}
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels;
}
......
......@@ -147,7 +147,7 @@ class OpLite : public Registry {
class OpInfo : public cpp::OpDesc {
public:
OpInfo(const OpInfo &) = default;
OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
explicit OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
// Collect all the input variable's name.
std::vector<std::string> input_names() const {
......
......@@ -64,7 +64,7 @@ class Optimizer {
RunPasses(passes);
}
#endif
exec_scope_ = program.exec_scope;
exec_scope_ = program.exec_scope();
}
void KernelPickPreferPlace(const Place& place) {
......
......@@ -4,4 +4,3 @@ endif()
lite_cc_library(basic_profiler_lite SRCS basic_profiler.cc)
lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler_lite)
......@@ -62,5 +62,45 @@ void RuntimeProgram::SaveParams(const std::string &dir,
}
}
void Program::Build(const framework::proto::ProgramDesc &program) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
for (const auto &proto_op_desc : program.blocks(0).ops()) {
lite::OpDesc op_desc_dummy(proto_op_desc);
cpp::OpDesc op_desc;
TransformOpDescPbToCpp(op_desc_dummy, &op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops_.emplace_back(std::move(op));
ops_.back()->Attach(op_desc, exec_scope_);
}
}
void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
// Create Feed and Fetch var.
scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars_.push_back("feed");
tmp_vars_.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
tmp_vars_.push_back(var_desc.Name());
exec_scope_->Var(var_desc.Name());
} else {
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
weights_.push_back(var_desc.Name());
}
}
}
} // namespace lite
} // namespace paddle
......@@ -37,79 +37,54 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
// the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope;
std::vector<Place> valid_places;
// Runtime scope.
lite::Scope* exec_scope{};
const framework::proto::ProgramDesc desc;
explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
public:
explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; }
Program(const framework::proto::ProgramDesc& desc,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places)
: scope(root), valid_places(valid_places), desc(desc) {
CHECK(scope) << "scope should be init first";
: scope_(root), valid_places_(valid_places), desc_(desc) {
CHECK(scope_) << "scope should be init first";
PrepareWorkspace(desc);
Build(desc);
}
std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(desc, scope, valid_places));
std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_));
return res;
}
const std::list<std::string>& weights() const { return weights_; }
const std::list<std::string>& tmp_vars() const { return tmp_vars_; }
std::list<std::string>* mutable_weights() { return &weights_; }
std::list<std::string>* mutable_tmp_vars() { return &tmp_vars_; }
const std::list<std::shared_ptr<OpLite>>& ops() const { return ops_; }
std::list<std::shared_ptr<OpLite>>* mutable_ops() { return &ops_; }
lite::Scope* exec_scope() { return exec_scope_; }
lite::Scope* scope() { return scope_.get(); }
private:
// Build from a program and scope.
void Build(const framework::proto::ProgramDesc& program) {
CHECK(ops.empty()) << "Executor duplicate Build found";
// Create operators.
for (const auto& proto_op_desc : program.blocks(0).ops()) {
pb::OpDesc op_desc(proto_op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(std::move(op));
cpp::OpDesc cpp_op_desc;
TransformOpDescPbToCpp(op_desc, &cpp_op_desc);
ops.back()->Attach(cpp_op_desc, exec_scope);
}
}
void Build(const framework::proto::ProgramDesc& program);
// Create temporary variables.
void PrepareWorkspace(const framework::proto::ProgramDesc& program) {
CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
exec_scope = &scope->NewScope();
// Create Feed and Fetch var.
scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars.push_back("feed");
tmp_vars.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
tmp_vars.push_back(var_desc.Name());
exec_scope->Var(var_desc.Name());
} else {
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
weights.push_back(var_desc.Name());
}
}
}
void PrepareWorkspace(const framework::proto::ProgramDesc& program);
private:
std::list<std::string> tmp_vars_;
std::list<std::string> weights_;
std::list<std::shared_ptr<OpLite>> ops_;
// the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope_;
std::vector<Place> valid_places_;
// Runtime scope.
lite::Scope* exec_scope_{};
const framework::proto::ProgramDesc desc_;
};
struct Instruct {
Instruct(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
struct Instruction {
Instruction(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
: op_(op), kernel_(std::move(kernel)) {
#ifdef LITE_WITH_PROFILE
profile_id_ = profile::BasicProfiler<profile::BasicTimer>::Global()
......@@ -132,7 +107,7 @@ struct Instruct {
kernel_->Launch();
}
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
return os;
}
......@@ -156,7 +131,7 @@ struct Instruct {
*/
class RuntimeProgram {
public:
explicit RuntimeProgram(std::vector<Instruct>&& insts)
explicit RuntimeProgram(std::vector<Instruction>&& insts)
: instructions_(std::move(insts)) {
if (instructions_.empty()) {
LOG(FATAL) << "no instructions";
......@@ -186,7 +161,7 @@ class RuntimeProgram {
private:
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruct> instructions_;
std::vector<Instruction> instructions_;
lite::Scope* exec_scope_{};
};
......
......@@ -33,11 +33,11 @@ Program FakeProgram() {
std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
auto w1v = program.scope()->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope()->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope()->Var(out1)->GetMutable<lite::Tensor>();
lite::OpDesc desc;
cpp::OpDesc desc;
desc.SetInput("Input", {x});
desc.SetInput("W", {w1});
desc.SetInput("Bias", {b1});
......@@ -46,12 +46,12 @@ Program FakeProgram() {
desc.SetAttr("in_num_col_dims", 1);
// add to input
program.tmp_vars.push_back(w1);
program.tmp_vars.push_back(b1);
program.mutable_tmp_vars()->push_back(w1);
program.mutable_tmp_vars()->push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op));
fc_op->Attach(desc, program.scope());
program.mutable_ops()->emplace_back(std::move(fc_op));
w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));
......@@ -64,8 +64,8 @@ Program FakeProgram() {
// out1, w2, b2 -fc-> out2
std::string x = "x";
program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
program.mutable_tmp_vars()->push_back(x);
auto* xv = program.scope()->Var(x)->GetMutable<lite::Tensor>();
xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
for (int i = 0; i < 3; i++) {
......
......@@ -17,7 +17,13 @@
namespace paddle {
namespace lite {
Scope::~Scope() {}
Scope::~Scope() {
for (auto *x : kids_) {
if (x) {
delete x;
}
}
}
Scope &Scope::NewScope() const {
kids_.push_back(new Scope);
......
......@@ -63,7 +63,8 @@ static const std::string& TargetToStr(TargetType target) {
}
static const std::string& PrecisionToStr(PrecisionType precision) {
static const std::string precision2string[] = {"unk", "float", "int8", "any"};
static const std::string precision2string[] = {"unk", "float", "int8_t",
"any"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
......@@ -76,6 +77,29 @@ static const std::string& DataLayoutToStr(DataLayoutType layout) {
return datalayout2string[x];
}
static const std::string& TargetRepr(TargetType target) {
static const std::string target2string[] = {"kUnk", "kHost", "kX86", "kCUDA",
"kAny"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
static const std::string& PrecisionRepr(PrecisionType precision) {
static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8",
"kAny"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
static const std::string& DataLayoutRepr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"kUnk", "kNCHW", "kAny"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
/*
* Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate.
......@@ -228,5 +252,20 @@ class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
};
#endif // LITE_WITH_CUDA
template <TargetType Target>
void CopySync(void* dst, void* src, size_t size, IoDirection dir) {
switch (Target) {
case TARGET(kX86):
case TARGET(kHost):
case TARGET(kARM):
TargetWrapperX86::MemcpySync(dst, src, size, IoDirection::HtoH);
break;
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
TargetWrapperCuda::MemcpySync(dst, src, size, dir);
#endif
}
}
} // namespace lite
} // namespace paddle
......@@ -21,6 +21,7 @@
* looks the same.
*/
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h"
......@@ -47,7 +48,8 @@ class DDimBase {
DDimBase() = default;
explicit DDimBase(const std::vector<int64_t> &x) { self()->ConstructFrom(x); }
value_type operator[](int offset) const { return (*self())[offset]; }
value_type operator[](int offset) const { return (*const_self())[offset]; }
value_type &operator[](int offset) { return (*self())[offset]; }
std::vector<int64_t> Vectorize() const { return self()->Vectorize(); }
size_t size() const { return const_self()->size(); }
bool empty() const { return const_self()->empty(); }
......@@ -73,18 +75,19 @@ class DDimBase {
{Slice(0, col).production(), Slice(col, size()).production()}));
}
friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) {
if (dims.empty()) {
os << "[]";
return os;
std::string repr() const {
std::stringstream ss;
ss << "{";
for (size_t i = 0; i < this->size() - 1; i++) {
ss << (*this)[i] << ",";
}
if (!this->empty()) ss << (*this)[size() - 1];
ss << "}";
return ss.str();
}
os << "[";
for (size_t i = 0; i < dims.size() - 1; i++) {
os << dims[i] << " ";
}
if (!dims.empty()) os << dims[dims.size() - 1];
os << "]";
friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) {
os << dims.repr();
return os;
}
......@@ -102,6 +105,12 @@ template <typename TensorT>
class TensorBase {
public:
TensorBase() = default;
template <typename T, typename DimT>
void Assign(T *data, const DimT &dim) {
self()->Assign(data, dim);
}
TargetType target() const { return self()->target(); }
template <typename T>
......
......@@ -24,7 +24,7 @@ namespace lite {
class Variable {
public:
template <typename T>
const T& Get() {
const T& Get() const {
return blob_.get<T>();
}
......
......@@ -4,4 +4,3 @@ endif()
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas_lite SRCS blas.cc)
cc_library(target_wrapper_host SRCS target_wrapper.cc)
......@@ -5,4 +5,3 @@ add_subdirectory(arm)
add_subdirectory(cuda)
add_subdirectory(x86)
......@@ -46,7 +46,6 @@ void ConvCompute::PrepareForRun() {
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
// TODO(xxx): create should be somewhere better!
bool kps_equal = (param.paddings[0] == param.paddings[1]) &&
(param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
......@@ -60,26 +59,26 @@ void ConvCompute::PrepareForRun() {
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
// dw conv impl
impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
// LOG(INFO) << "invoking dw conv";
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
// winograd conv impl
impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
// LOG(INFO) << "invoking winograd conv";
VLOG(3) << "invoking winograd conv";
} else {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
// LOG(INFO) << "invoking direct conv";
VLOG(3) << "invoking direct conv";
}
} else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
no_dilation) {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
// LOG(INFO) << "invoking direct conv";
VLOG(3) << "invoking direct conv";
} else {
impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
// LOG(INFO) << "invoking gemm like conv";
VLOG(3) << "invoking gemm like conv";
}
CHECK(this->impl_->create(param, &ctx));
}
......
......@@ -56,7 +56,7 @@ void FcCompute::Run() {
} else {
// use sgemmv
// sgemv((const float*)weights, (const float*)din, (float*)dout,
// false, n, x_w, param_->_flag_bias, (float*)bias, false);
// false, n, x_w, _param->_flag_bias, (float*)bias, false);
}
}
......
......@@ -9,4 +9,3 @@ cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite})
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite)
......@@ -12,5 +12,4 @@ set(host_kernels
reshape_compute_host
)
set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels")
set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels")
......@@ -15,6 +15,8 @@ cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kerne
cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
set(x86_kernels
activation_compute_x86
......@@ -25,10 +27,11 @@ set(x86_kernels
relu_compute_x86
fc_compute_x86
scale_compute_x86
softmax_compute_x86
softmax_compute_x86
dropout_compute_x86
concat_compute_x86
conv_compute_x86
pool_compute_x86
)
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
......@@ -13,7 +13,8 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
#include <string>
#include <vector>
namespace paddle {
namespace lite {
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/model_parser/cpp/op_desc.h"
#include <set>
#include <utility>
namespace paddle {
namespace lite {
......@@ -44,12 +45,13 @@ FindAttr(const cpp::OpDesc& desc, const std::string& name) {
return std::make_pair(it, attr_it);
}
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__); \
return pair.first->second.get<T>(); \
#define GET_IMPL_ONE(T, repr__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
auto pair = FindAttr(*this, name); \
CHECK(pair.second->second == AttrType::repr__) \
<< "required type is " << #repr__ << " not match the true type"; \
return pair.first->second.get<T>(); \
}
GET_IMPL_ONE(int32_t, INT);
......
......@@ -44,7 +44,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) {
}
SET_IMPL_ONE(int, INT, i);
SET_IMPL_ONE(float, FLOAT, f);
SET_IMPL_ONE(bool, FLOAT, f);
SET_IMPL_ONE(bool, BOOLEAN, b);
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
......
set(op_DEPS ${tensor_lite} op_lite op_params_lite)
cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
......@@ -18,10 +19,10 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
set(ops_lite
conv_op_lite
pool_op_lite
fc_op_lite
relu_op_lite
mul_op_lite
......@@ -42,11 +43,11 @@ set(ops_lite
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite
X86_DEPS fc_compute_x86
ARM_DEPS fc_compute_arm)
ARM_DEPS fc_compute_arm)
lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
DEPS pool_op_lite memory_lite
ARM_DEPS pool_compute_arm)
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
DEPS pool_op_lite memory_lite
ARM_DEPS pool_compute_arm)
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/operators/conv_op.h"
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
......@@ -74,4 +75,4 @@ bool ConvOpLite::InferShape() const {
} // namespace paddle
REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite);
REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite);
\ No newline at end of file
REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite);
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
......@@ -60,12 +59,13 @@ class ConvOpLite : public OpLite {
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(), "ResidualData") !=
input_arg_names.end()) {
auto residual_data_var = scope->FindVar(op_desc.Input("ResidualData").front());
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData =
const_cast<lite::Tensor *>(&(residual_data_var->Get<lite::Tensor>()));
param_.residualData = const_cast<lite::Tensor *>(
&(residual_data_var->Get<lite::Tensor>()));
}
}
return true;
......
......@@ -38,8 +38,8 @@ class FeedOp : public OpLite {
auto feed_var_name = opdesc.Input("X").front();
auto* feed_var = scope->FindVar(feed_var_name);
CHECK(feed_var);
auto& feed_tensor_list = feed_var->Get<std::vector<lite::Tensor>>();
param_.feed_list = &feed_tensor_list;
auto* feed_tensor_list = feed_var->GetMutable<std::vector<lite::Tensor>>();
param_.feed_list = feed_tensor_list;
auto out_name = opdesc.Output("Out").front();
auto* out_var = scope->FindVar(out_name);
......
......@@ -45,10 +45,11 @@ class MulOpLite : public OpLite {
CHECK(var);
param_.x = var->GetMutable<Tensor>();
var = scope->FindVar(W);
CHECK(var);
CHECK(var) << "no var called " << W;
param_.y = var->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
var = scope->FindVar(out);
CHECK(var) << "no var called " << out;
param_.output = var->GetMutable<Tensor>();
param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
......
......@@ -85,4 +85,4 @@ bool PoolOpLite::InferShape() const {
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(pool, paddle::lite::operators::PoolOpLite);
REGISTER_LITE_OP(pool2d, paddle::lite::operators::PoolOpLite);
......@@ -37,14 +37,6 @@ class PoolOpLite : public OpLite {
bool InferShape() const override;
/*
bool Run() override {
CHECK(kernel_);
kernel_->Run();
return true;
}
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -88,4 +88,4 @@ RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple wheel
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pre-commit
RUN apt-get autoremove -y && apt-get clean
RUN rm -rf /sdk-tools-linux-4333796.zip /tmp/android-ndk-r17c-linux-x86_64.zip /cmake-3.10.3-Linux-x86_64.tar.gz
\ No newline at end of file
......@@ -2,17 +2,29 @@
set -ex
TESTS_FILE="./lite_tests.txt"
LIBS_FILE="./lite_libs.txt"
readonly common_flags="-DWITH_LITE=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_PYTHON=OFF -DWITH_TESTING=ON -DLITE_WITH_ARM=OFF"
# for code gen, a source file is generated after a test, but is dependended by some targets in cmake.
# here we fake an empty file to make cmake works.
function prepare_for_codegen {
# in build directory
mkdir -p ./paddle/fluid/lite/gen_code
touch ./paddle/fluid/lite/gen_code/__generated_code__.cc
}
function cmake_x86 {
prepare_for_codegen
cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags}
}
function cmake_x86_for_CI {
prepare_for_codegen
cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON
}
function cmake_gpu {
prepare_for_codegen
cmake .. " -DWITH_GPU=ON {common_flags} -DLITE_WITH_GPU=ON"
}
......@@ -34,7 +46,7 @@ function cmake_arm {
function build {
file=$1
for _test in $(cat $file); do
make $_test -j$(expr $(nproc))
make $_test -j$(expr $(nproc) - 2)
done
}
......@@ -42,7 +54,11 @@ function build {
function test_lite {
local file=$1
echo "file: ${file}"
for _test in $(cat $file); do
# We move the build phase here to make the 'gen_code' test compiles after the
# corresponding test is executed and the C++ code generates.
make $_test -j$(expr $(nproc) - 2)
ctest -R $_test -V
done
}
......@@ -86,8 +102,10 @@ function build_test_server {
cd ./build
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/paddle/build/third_party/install/mklml/lib"
cmake_x86_for_CI
build $TESTS_FILE
# compile the tests and execute them.
test_lite $TESTS_FILE
# build the remaining libraries to check compiling error.
build $LIBS_FILE
}
# Build the code and run lite server tests. This is executed in the CI system.
......@@ -117,7 +135,6 @@ function build_test_arm {
build_dir=build.lite.${os}.${abi}
mkdir -p $build_dir
cd $build_dir
cmake_arm ${os} ${abi}
build $TESTS_FILE
......@@ -167,6 +184,7 @@ function main {
;;
build)
build $TESTS_FILE
build $LIBS_FILE
shift
;;
cmake_x86)
......
......@@ -8,5 +8,4 @@ set(utils_DEPS glog)
lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite)
cc_library(any_lite SRCS any.cc)
cc_library(utils_lite SRCS cp_logging.cc DEPS ${utils_DEPS} any_lite)
cc_library(utils_lite SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any_lite)
......@@ -4,4 +4,3 @@ endif()
cc_library(target_wrapper_x86 SRCS target_wrapper.cc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册