diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index 56262c92129f318ba94cb9b2d7b7889d8e69c5b5..04587cce2fc742fd64f8b9236f9175a7e5f11e89 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -1,5 +1,5 @@
 if(LITE_WITH_CUDA)
-    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
+    cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
     nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
 else()
     cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host )
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index 4adba67673e0d8697d8979e86e1d1c53bcc6aa44..d01d08faefa431cd90bb22b1e8452a8018d7ed59 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -13,6 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/lite/core/mir/io_complement_pass.h"
+#include
+#include
+#include
 #include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
 #include "paddle/fluid/lite/core/mir/pass_registry.h"
 
@@ -55,8 +58,8 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
   CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
   auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
   CHECK(in->AsArgument().type);
-  if (!TypeCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
-    LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name
+  if (!TargetCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
+    LOG(INFO) << "found Target unmatched tensor: " << in->AsArgument().name
               << " for kernel " << inst.op->DebugString() << " "
               << *in->AsArgument().type << " -> " << *decl_arg_type;
     // Add an IoCopy instruction to make the input compatible with other dist.
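Note on the io_complement_pass.cc hunk above: the pass now asks only whether the argument's target matches the kernel's declared input target (`TargetCompatibleTo`) before splicing in an IoCopy instruction, instead of the full type-compatibility check. The sketch below is a self-contained illustration of that decision flow; the `Target` enum, `Node` struct, and the printed message are simplified stand-ins, not the lite framework's real classes. It treats kHost and kX86 as one host family, mirroring the new predicate in type_system.h.

```cpp
#include <iostream>
#include <list>
#include <string>

// Simplified stand-ins for lite's TargetType and graph argument nodes.
enum class Target { kHost, kX86, kCUDA };

struct Node {
  std::string name;
  Target target;
};

bool IsHost(Target t) { return t == Target::kHost || t == Target::kX86; }

// Targets are compatible when both sit in the host family or are literally equal.
bool TargetCompatible(Target arg, Target decl) {
  return IsHost(arg) ? IsHost(decl) : arg == decl;
}

int main() {
  // A kernel that declares its input on CUDA, fed by arguments on CUDA and x86.
  const Target kernel_decl = Target::kCUDA;
  std::list<Node> inputs = {{"w", Target::kCUDA}, {"x", Target::kX86}};

  for (const auto& in : inputs) {
    if (!TargetCompatible(in.target, kernel_decl)) {
      // This is the point where the real pass would insert an IoCopy op
      // between the argument node and the instruction node.
      std::cout << "found Target unmatched tensor: " << in.name
                << " -> insert IoCopy (host to device)\n";
    } else {
      std::cout << in.name << " already matches the declared target\n";
    }
  }
  return 0;
}
```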
diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
index 46275ed4d99a0d55fd363061ab082e7d60e449d4..e3157f1f6ebdd1bdde80d3d91f756226e2072ae7 100644
--- a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
+++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
@@ -38,9 +38,11 @@ class RuntimeContextAssignPass : public InstructionPass {
       case TARGET(kX86):
         inst.picked_kernel().SetContext(NewHostContext());
         break;
+#ifdef LITE_WITH_CUDA
       case TARGET(kCUDA):
         inst.picked_kernel().SetContext(NewCudaContext());
         break;
+#endif
       default:
         LOG(FATAL) << "unsupported target "
                    << TargetToStr(inst.picked_kernel().target());
@@ -55,6 +57,7 @@ class RuntimeContextAssignPass : public InstructionPass {
     return ctx;
   }
 
+#ifdef LITE_WITH_CUDA
   std::unique_ptr<KernelContext> NewCudaContext() {
     std::unique_ptr<KernelContext> ctx(new KernelContext);
     auto& cuda = ctx->AsCudaContext();
@@ -63,6 +66,7 @@ class RuntimeContextAssignPass : public InstructionPass {
     cuda.blas_fp32 = cublas_fp32_;
     return ctx;
   }
+#endif
 
   void InitCudaBlas() {
     cublas_fp32_ = std::make_shared>();
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index 597712665a3c4f836480d4e7c59ebfb0b18540de..a141f0afbbba899b54e719662c4f990e0a0477a8 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -183,9 +183,18 @@ class Type : public DataTypeBase {
 
 // -------------------------------- compatible check ---------------------------
 static bool TargetCompatibleTo(const Type& a, const Type& b) {
-  return a.IsVoid() ||                                                    //
-         (a.IsTensor() && b.IsTensor() && (a.target() == b.target() ||   //
-                                           b.target() == TARGET(kAny)));
+  auto is_host = [](TargetType x) {
+    return x == TARGET(kHost) || x == TARGET(kX86);
+  };
+  if (a.IsVoid() || b.IsVoid()) return true;
+  if (a.IsTensor() || b.IsTensor()) {
+    if (a.IsTensor() && b.IsTensor()) {
+      return is_host(a.target()) ? is_host(b.target())
+                                 : a.target() == b.target();
+    }
+    return false;
+  }
+  return true;
 }
 
 static bool DataLayoutCompatibleTo(const Type& a, const Type& b) {
@@ -224,32 +233,32 @@ class UnsupportedTy : public Type {
 };
 class TensorAnyTy : public Type {
  public:
-  TensorAnyTy(TargetType target)
+  explicit TensorAnyTy(TargetType target)
       : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
              DATALAYOUT(kAny)) {}
 };
 // A list of tensor, and no assumption on the data layout or data type.
 class TensorListAnyTy : public Type {
  public:
-  TensorListAnyTy(TargetType target)
+  explicit TensorListAnyTy(TargetType target)
       : Type(ID::TensorList_Any, "TensorList_Any", false, target,
              PRECISION(kAny), DATALAYOUT(kAny)) {}
 };
 class TensorFp32NCHWTy : public Type {
  public:
-  TensorFp32NCHWTy(TargetType target)
+  explicit TensorFp32NCHWTy(TargetType target)
       : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/,
              target, PrecisionType::kFloat, DataLayoutType::kNCHW) {}
 };
 class TensorInt8NCHWTy : public Type {
  public:
-  TensorInt8NCHWTy(TargetType target)
+  explicit TensorInt8NCHWTy(TargetType target)
       : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/,
              target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
 class TensorInt64NCHWTy : public Type {
  public:
-  TensorInt64NCHWTy(TargetType target)
+  explicit TensorInt64NCHWTy(TargetType target)
       : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
              target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
@@ -270,12 +279,14 @@ struct ParamType {
   Place tensor_place{};
   const Type* type;
 
-  explicit ParamType() = default;
+  ParamType() = default;
   explicit ParamType(size_t element_type_hash)
       : element_type_hash(element_type_hash) {}
   ParamType(size_t element_type_hash, const Place& place)
      : element_type_hash(element_type_hash), tensor_place(place) {}
-  ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
+  explicit ParamType(const Type* type) : type(type) {
+    tensor_place = type->place();
+  }
 
   std::string DebugString() const { return tensor_place.DebugString(); }
 };
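The rewritten `TargetCompatibleTo` above changes two behaviours: kHost and kX86 now count as one host family, and non-tensor (and void) types no longer fail the check outright; it also drops the old `b.target() == TARGET(kAny)` wildcard for tensors. Below is a minimal standalone mock of the new predicate with the resulting truth table spelled out as asserts. The tiny `Ty` struct is a stand-in for lite's `Type`, not the real class.

```cpp
#include <cassert>

enum class TargetType { kHost, kX86, kCUDA };

// Stand-in for lite::Type: just enough state to exercise the predicate.
struct Ty {
  bool is_void = false;
  bool is_tensor = false;
  TargetType target = TargetType::kHost;
};

bool TargetCompatibleTo(const Ty& a, const Ty& b) {
  auto is_host = [](TargetType x) {
    return x == TargetType::kHost || x == TargetType::kX86;
  };
  if (a.is_void || b.is_void) return true;
  if (a.is_tensor || b.is_tensor) {
    if (a.is_tensor && b.is_tensor) {
      return is_host(a.target) ? is_host(b.target) : a.target == b.target;
    }
    return false;  // tensor vs. non-tensor never matches
  }
  return true;  // two non-tensor types do not constrain each other's target
}

int main() {
  Ty host{false, true, TargetType::kHost};
  Ty x86{false, true, TargetType::kX86};
  Ty cuda{false, true, TargetType::kCUDA};
  Ty any_void{true, false, TargetType::kHost};
  Ty scalar{false, false, TargetType::kHost};

  assert(TargetCompatibleTo(host, x86));     // host-family tensors interchange
  assert(!TargetCompatibleTo(x86, cuda));    // host tensor into a CUDA kernel -> IoCopy needed
  assert(TargetCompatibleTo(cuda, cuda));
  assert(TargetCompatibleTo(any_void, cuda));  // void acts as a wildcard
  assert(!TargetCompatibleTo(scalar, cuda));   // tensor vs. non-tensor mismatch
  assert(TargetCompatibleTo(scalar, scalar));
  return 0;
}
```

One unrelated observation from the context lines above: `TensorInt64NCHWTy` still passes `PrecisionType::kInt8` to the base constructor, which looks like a pre-existing copy-paste slip rather than something this patch touches.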
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 004178f2f6a497087164fb9fdd55c239ecef9aff..8b0c1b236f2bedd317ba043e190e9ade88ef96d1 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -1,4 +1,4 @@
-cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite proto_desc)
+cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
 cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
 cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
 cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
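Regarding the `explicit` qualifiers added to the single-argument constructors in type_system.h (and removed from the zero-argument `ParamType()`): this follows the usual cpplint rule that a one-argument constructor without `explicit` silently enables implicit conversions. A toy illustration of the difference is below; `ImplicitTy`, `ExplicitTy`, and the two `Consume*` helpers are made-up names for illustration only.

```cpp
enum class TargetType { kHost, kX86, kCUDA };

struct ImplicitTy {
  ImplicitTy(TargetType target) : target(target) {}  // implicit conversion allowed
  TargetType target;
};

struct ExplicitTy {
  explicit ExplicitTy(TargetType target) : target(target) {}  // conversion must be spelled out
  TargetType target;
};

void ConsumeImplicit(const ImplicitTy&) {}
void ConsumeExplicit(const ExplicitTy&) {}

int main() {
  ConsumeImplicit(TargetType::kX86);               // compiles: TargetType silently becomes ImplicitTy
  // ConsumeExplicit(TargetType::kX86);            // rejected once the constructor is explicit
  ConsumeExplicit(ExplicitTy(TargetType::kX86));   // the conversion now has to be written out
  return 0;
}
```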