Commit 72b734e4 authored by: S superjomn

make lite without cuda works

Parent commit: 70540d1b
if(LITE_WITH_CUDA)
# CUDA build of the C++ API library; op_executor_lite dependency dropped
# (scrape duplicated the pre-commit line — only the post-commit deps apply).
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
else()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host )
......
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include <list>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
......@@ -55,8 +58,8 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArgument().type);
if (!TypeCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name
if (!TargetCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArgument().name
<< " for kernel " << inst.op->DebugString() << " "
<< *in->AsArgument().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
......
......@@ -38,9 +38,11 @@ class RuntimeContextAssignPass : public InstructionPass {
case TARGET(kX86):
inst.picked_kernel().SetContext(NewHostContext());
break;
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
inst.picked_kernel().SetContext(NewCudaContext());
break;
#endif
default:
LOG(FATAL) << "unsupported target "
<< TargetToStr(inst.picked_kernel().target());
......@@ -55,6 +57,7 @@ class RuntimeContextAssignPass : public InstructionPass {
return ctx;
}
#ifdef LITE_WITH_CUDA
std::unique_ptr<KernelContext> NewCudaContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda = ctx->AsCudaContext();
......@@ -63,6 +66,7 @@ class RuntimeContextAssignPass : public InstructionPass {
cuda.blas_fp32 = cublas_fp32_;
return ctx;
}
#endif
void InitCudaBlas() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
......
......@@ -183,9 +183,18 @@ class Type : public DataTypeBase {
// -------------------------------- compatible check ---------------------------
// Returns true when argument type `a` is target-compatible with declared
// kernel type `b`. kHost and kX86 are treated as the same (host) target.
// NOTE: the scraped diff had left the pre-commit body (an unconditional
// early return) above the new logic, making it dead code; only the
// post-commit implementation is kept here.
static bool TargetCompatibleTo(const Type& a, const Type& b) {
  // kHost and kX86 tensors live in the same memory space, so either
  // matches the other.
  auto is_host = [](TargetType x) {
    return x == TARGET(kHost) || x == TARGET(kX86);
  };
  // A void type is a wildcard: compatible with anything.
  if (a.IsVoid() || b.IsVoid()) return true;
  if (a.IsTensor() || b.IsTensor()) {
    // Both tensors: compare targets, collapsing host-family targets.
    if (a.IsTensor() && b.IsTensor()) {
      return is_host(a.target()) ? is_host(b.target())
                                 : a.target() == b.target();
    }
    // Tensor vs. non-tensor never matches.
    return false;
  }
  // Neither side is a tensor: target does not constrain compatibility.
  return true;
}
static bool DataLayoutCompatibleTo(const Type& a, const Type& b) {
......@@ -224,32 +233,32 @@ class UnsupportedTy : public Type {
};
// A tensor with any precision and any data layout, bound to a target.
// (Scrape had duplicated the constructor's first line — pre-commit
// non-explicit and post-commit explicit; only the explicit one is kept.)
class TensorAnyTy : public Type {
 public:
  explicit TensorAnyTy(TargetType target)
      : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
             DATALAYOUT(kAny)) {}
};
// A list of tensors, with no assumption on precision or data layout.
// (Duplicate pre-commit constructor line from the scraped diff removed;
// the explicit single-argument constructor is the intended version.)
class TensorListAnyTy : public Type {
 public:
  explicit TensorListAnyTy(TargetType target)
      : Type(ID::TensorList_Any, "TensorList_Any", false, target,
             PRECISION(kAny), DATALAYOUT(kAny)) {}
};
// A float32 tensor in NCHW layout, bound to a target.
// (Duplicate pre-commit constructor line from the scraped diff removed.)
class TensorFp32NCHWTy : public Type {
 public:
  explicit TensorFp32NCHWTy(TargetType target)
      : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
             PrecisionType::kFloat, DataLayoutType::kNCHW) {}
};
// An int8 tensor in NCHW layout, bound to a target.
// (Duplicate pre-commit constructor line from the scraped diff removed.)
class TensorInt8NCHWTy : public Type {
 public:
  explicit TensorInt8NCHWTy(TargetType target)
      : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
             PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
// An int64 tensor in NCHW layout, bound to a target.
// (Duplicate pre-commit constructor line from the scraped diff removed.)
class TensorInt64NCHWTy : public Type {
 public:
  explicit TensorInt64NCHWTy(TargetType target)
      // NOTE(review): precision is kInt8 despite the Int64 name — this looks
      // like an upstream copy-paste bug; confirm whether kInt64 was intended.
      : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
             target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
......@@ -270,12 +279,14 @@ struct ParamType {
// Describes one kernel parameter: the hash of its element type, the place
// (target/precision/layout) of the tensor, and an optional full Type*.
Place tensor_place{};
// Zero-initialized so a default-constructed ParamType holds a null pointer
// rather than an indeterminate one.
const Type* type{};

// Scraped diff had kept both pre- and post-commit constructor lines
// (explicit default ctor, non-explicit Type* ctor); only the post-commit
// forms are kept below.
ParamType() = default;
explicit ParamType(size_t element_type_hash)
    : element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
    : element_type_hash(element_type_hash), tensor_place(place) {}
// Derives the tensor place from the full type descriptor.
explicit ParamType(const Type* type) : type(type) {
  tensor_place = type->place();
}

std::string DebugString() const { return tensor_place.DebugString(); }
};
......
# fc op library; proto_desc dependency dropped by this commit
# (scrape duplicated the pre-commit line — only the post-commit deps apply).
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register