From e88d6418af8d59717c988aafa35c81d3b3c0831b Mon Sep 17 00:00:00 2001 From: superjomn Date: Mon, 6 May 2019 14:07:28 +0800 Subject: [PATCH] add compatible for server x86 and GPU tensor. --- paddle/fluid/framework/op_desc.h | 2 +- paddle/fluid/lite/CMakeLists.txt | 8 +- paddle/fluid/lite/api/cxx_api_test.cc | 3 +- paddle/fluid/lite/core/CMakeLists.txt | 7 +- paddle/fluid/lite/core/compatible_tensor.cc | 15 +++ paddle/fluid/lite/core/compatible_tensor.h | 96 +++++++++++++++++++ .../lite/core/{tensor.cc => lite_tensor.cc} | 2 +- .../lite/core/{tensor.h => lite_tensor.h} | 56 ++++------- paddle/fluid/lite/core/tensor_test.cc | 2 +- paddle/fluid/lite/core/type_system.h | 2 +- paddle/fluid/lite/core/variable.h | 2 +- paddle/fluid/lite/cuda/CMakeLists.txt | 4 + paddle/fluid/lite/kernels/CMakeLists.txt | 4 +- paddle/fluid/lite/kernels/cuda/CMakeLists.txt | 4 + .../lite/kernels/cuda/io_copy_compute.cc | 14 +-- paddle/fluid/lite/kernels/cuda/mul_compute.h | 3 +- paddle/fluid/lite/kernels/host/fc_compute.cc | 21 ++-- .../fluid/lite/kernels/host/feed_compute.cc | 2 +- paddle/fluid/lite/kernels/host/mul_compute.cc | 17 ++-- paddle/fluid/lite/kernels/host/relu_compute.h | 9 +- .../fluid/lite/kernels/host/scale_compute.cc | 4 +- paddle/fluid/lite/model_parser/CMakeLists.txt | 2 +- .../fluid/lite/model_parser/compatible_pb.h | 13 +++ .../fluid/lite/model_parser/model_parser.cc | 19 ++-- paddle/fluid/lite/model_parser/model_parser.h | 1 - paddle/fluid/lite/operators/fc_op.cc | 2 +- paddle/fluid/lite/operators/fc_op.h | 4 +- paddle/fluid/lite/operators/feed_op.cc | 2 +- paddle/fluid/lite/operators/fetch_op.cc | 2 +- paddle/fluid/lite/operators/mul_op.cc | 2 +- paddle/fluid/lite/operators/mul_op.h | 5 +- paddle/fluid/lite/operators/op_params.h | 2 +- paddle/fluid/lite/operators/relu_op.h | 1 - paddle/fluid/lite/operators/scale_op.cc | 8 +- paddle/fluid/lite/x86/CMakeLists.txt | 4 + 35 files changed, 237 insertions(+), 107 deletions(-) create mode 100644 
paddle/fluid/lite/core/compatible_tensor.cc create mode 100644 paddle/fluid/lite/core/compatible_tensor.h rename paddle/fluid/lite/core/{tensor.cc => lite_tensor.cc} (97%) rename paddle/fluid/lite/core/{tensor.h => lite_tensor.h} (55%) diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 03ebc9ac0ac..0b7162af24e 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -33,7 +33,7 @@ class OpDesc { OpDesc(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs); - OpDesc(const proto::OpDesc &desc, BlockDesc *block); + OpDesc(const proto::OpDesc &desc, BlockDesc *block = nullptr); explicit OpDesc(BlockDesc *block) : block_(block) {} diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index 75595cc17e1..5c09261e4dc 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -1,9 +1,11 @@ +if (NOT WITH_LITE) + return() +endif() + add_subdirectory(core) add_subdirectory(x86) add_subdirectory(host) -if(LITE_WITH_CUDA) - add_subdirectory(cuda) -endif() +add_subdirectory(cuda) add_subdirectory(operators) add_subdirectory(kernels) add_subdirectory(model_parser) diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index cf78a3fe56d..1380393c07b 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -41,7 +41,8 @@ TEST(CXXApi, test) { auto* input_tensor = predictor.GetInput(0); input_tensor->Resize({100, 100}); - auto* data = input_tensor->mutable_data(); + auto* data = TensorMutableData(input_tensor, TARGET(kHost), + product(input_tensor->dims())); for (int i = 0; i < 100 * 100; i++) { data[i] = i; } diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 6eb988385b4..15d98bd757c 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ 
-1,7 +1,12 @@ cc_library(lite_gtest_main SRCS lite_gtest_main.cc) cc_library(memory_lite SRCS memory.cc) cc_library(target_wrapper_lite SRCS target_wrapper.cc) -cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite target_wrapper_lite) + +if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + cc_library(tensor_lite SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite) +else() + cc_library(tensor_lite DEPS lod_tensor) +endif() cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(variable_lite SRCS variable.cc) cc_library(op_registry_lite SRCS op_registry.cc) diff --git a/paddle/fluid/lite/core/compatible_tensor.cc b/paddle/fluid/lite/core/compatible_tensor.cc new file mode 100644 index 00000000000..c5b839397c3 --- /dev/null +++ b/paddle/fluid/lite/core/compatible_tensor.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/compatible_tensor.h" diff --git a/paddle/fluid/lite/core/compatible_tensor.h b/paddle/fluid/lite/core/compatible_tensor.h new file mode 100644 index 00000000000..490b67e923e --- /dev/null +++ b/paddle/fluid/lite/core/compatible_tensor.h @@ -0,0 +1,96 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/lite/core/target_wrapper.h" +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#include "paddle/fluid/lite/core/lite_tensor.h" +#else +#include "paddle/fluid/framework/lod_tensor.h" +#endif + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +using Tensor = details::Tensor; +using DDim = details::DDim; +#else +using Tensor = framework::LoDTensor; +using DDim = framework::DDim; + +static TargetType TensorGetTarget(const Tensor &x) { + if (platform::is_gpu_place(x.place())) { + return TARGET(kCUDA); + } else if (platform::is_cpu_place(x.place())) { + return TARGET(kX86); + } + return TARGET(kUnk); +} + +template +T *TensorMutableData(Tensor *x, TargetType target, size_t size) { + if (target == TARGET(kX86) || target == TARGET(kHost)) { + return x->mutable_data(platform::CPUPlace(), memory::Allocator::kDefault, + size); + } else if (target == TARGET(kCUDA)) { + return x->mutable_data(platform::CUDAPlace(), + memory::Allocator::kDefault, size); + } + LOG(FATAL) << "not valid target " << TargetToStr(target); + return nullptr; +} +#endif + +static int product(const DDim &dims, int start, int end) { + int res = 1; + for (int i = start; i < end; i++) { + res *= dims[i]; + } + return res; +} + +static DDim SliceDims(const DDim &dims, int begin, int end) { +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + return DDim(dims.begin() + begin, dims.begin() + end); /* half-open [begin, end), same range the non-lite branch below returns */ +#else + auto vec = framework::vectorize(dims); + return DDim(&vec[0] + begin, end - begin); +#endif +} + +static std::vector 
DDimVectorize(const DDim &x) { +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + return x; +#else + return framework::vectorize(x); +#endif +} + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +static int product(const DDim &dims) { + return std::accumulate(dims.begin(), dims.end(), 1, + [](int a, int b) { return a * b; }); +} +#endif + +static DDim flatten_to_2d(const DDim &dims, int col) { + return DDim({product(SliceDims(dims, 0, col)), + product(SliceDims(dims, col, dims.size()))}); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/tensor.cc b/paddle/fluid/lite/core/lite_tensor.cc similarity index 97% rename from paddle/fluid/lite/core/tensor.cc rename to paddle/fluid/lite/core/lite_tensor.cc index 65a47ed05f7..c2dc501c32c 100644 --- a/paddle/fluid/lite/core/tensor.cc +++ b/paddle/fluid/lite/core/lite_tensor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/core/lite_tensor.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/lite_tensor.h similarity index 55% rename from paddle/fluid/lite/core/tensor.h rename to paddle/fluid/lite/core/lite_tensor.h index 246bc5b214c..918a675b350 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/lite_tensor.h @@ -17,30 +17,15 @@ #include #include #include + #include "paddle/fluid/lite/core/memory.h" #include "paddle/fluid/lite/core/target_wrapper.h" namespace paddle { namespace lite { +namespace details { using DDim = std::vector; -static DDim SliceDims(const DDim& dims, int begin, int end) { - return DDim(dims.begin() + begin, dims.begin() + end - 1); -} - -static int product(const DDim& dims) { - return std::accumulate(dims.begin(), dims.end(), 1, - [](int a, int b) { return a * b; }); -} - -static int product(DDim::const_iterator begin, DDim::const_iterator end) { - return 
std::accumulate(begin, end, 1, [](int a, int b) { return a * b; }); -} - -static DDim flatten_to_2d(const DDim& dims, int col) { - return DDim({product(SliceDims(dims, 0, col)), - product(SliceDims(dims, col, dims.size()))}); -} using LoD = std::vector>; @@ -50,32 +35,32 @@ class Tensor { Tensor() : buffer_(std::make_shared()) {} template - const T* data() const { - return static_cast(buffer_->data()); + const T *data() const { + return static_cast(buffer_->data()); } - void Resize(const DDim& ddim) { dims_ = ddim; } + void Resize(const DDim &ddim) { dims_ = ddim; } - const DDim& dims() const { return dims_; } + const DDim &dims() const { return dims_; } - const LoD& lod() const { return lod_; } - LoD* mutable_lod() { return &lod_; } + const LoD &lod() const { return lod_; } + LoD *mutable_lod() { return &lod_; } template - T* mutable_data(); + T *mutable_data(); template - T* mutable_data(TargetType target); - void* mutable_data(size_t memory_size); - void* mutable_data(TargetType target, size_t memory_size); + T *mutable_data(TargetType target); + void *mutable_data(size_t memory_size); + void *mutable_data(TargetType target, size_t memory_size); size_t memory_size() const { return memory_size_; } bool IsInitialized() const { return buffer_->data(); } // Other share data to this. 
- void ShareDataWith(const Tensor& other); + void ShareDataWith(const Tensor &other); - void CopyDataFrom(const Tensor& other); + void CopyDataFrom(const Tensor &other); TargetType target() const { return target_; } @@ -88,22 +73,23 @@ class Tensor { }; template -T* Tensor::mutable_data() { +T *Tensor::mutable_data() { memory_size_ = product(dims_) * sizeof(T); buffer_->ResetLazy(target_, memory_size_); - return static_cast(buffer_->data()); + return static_cast(buffer_->data()); } template -T* Tensor::mutable_data(TargetType target) { +T *Tensor::mutable_data(TargetType target) { target_ = target; memory_size_ = product(dims_) * sizeof(T); buffer_->ResetLazy(target, memory_size()); - return static_cast(buffer_->data()); + return static_cast(buffer_->data()); } -std::ostream& operator<<(std::ostream& os, const DDim& dims); -std::ostream& operator<<(std::ostream& os, const Tensor& tensor); +std::ostream &operator<<(std::ostream &os, const DDim &dims); +std::ostream &operator<<(std::ostream &os, const Tensor &tensor); +} // namespace details } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/tensor_test.cc b/paddle/fluid/lite/core/tensor_test.cc index 8b264562028..247f2d73bf0 100644 --- a/paddle/fluid/lite/core/tensor_test.cc +++ b/paddle/fluid/lite/core/tensor_test.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/lite/core/tensor.h" #include +#include "paddle/fluid/lite/core/lite_tensor.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index 888a7bf8cc0..5ddb5c8ab6c 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -27,7 +27,7 @@ #include #include #include -#include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { diff --git a/paddle/fluid/lite/core/variable.h b/paddle/fluid/lite/core/variable.h index 63c2505ab90..a0d0636066b 100644 --- a/paddle/fluid/lite/core/variable.h +++ b/paddle/fluid/lite/core/variable.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { diff --git a/paddle/fluid/lite/cuda/CMakeLists.txt b/paddle/fluid/lite/cuda/CMakeLists.txt index 1e3a9a5c8df..505759c7d4a 100644 --- a/paddle/fluid/lite/cuda/CMakeLists.txt +++ b/paddle/fluid/lite/cuda/CMakeLists.txt @@ -1,2 +1,6 @@ +if(NOT LITE_WITH_CUDA) + return() +endif() + nv_library(target_wrapper_cuda SRCS target_wrapper.cc) nv_library(cuda_blas_lite SRCS blas.cc) diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt index 5c2883d5379..ebbfb2139e5 100644 --- a/paddle/fluid/lite/kernels/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/CMakeLists.txt @@ -1,6 +1,4 @@ set(lite_kernel_deps type_system kernel_lite op_registry_lite) add_subdirectory(host) add_subdirectory(arm) -if(LITE_WITH_CUDA) - add_subdirectory(cuda) -endif() +add_subdirectory(cuda) diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt index f2b2006600b..3d58e9911bd 100644 --- a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt +++ 
b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt @@ -1,3 +1,7 @@ +if(NOT LITE_WITH_CUDA) + return() +endif() + nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite) diff --git a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc index 493c8a9181e..897cd67fc47 100644 --- a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc +++ b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc @@ -46,11 +46,12 @@ class IoCopyHostToCudaCompute public: void Run() override { auto& param = Param(); - CHECK(param.x->target() == TARGET(kHost) || - param.x->target() == TARGET(kX86)); + CHECK(TensorGetTarget(*param.x) == TARGET(kHost) || + TensorGetTarget(*param.x) == TARGET(kX86)); LOG(INFO) << "copy size " << param.x->memory_size(); - auto* data = param.y->mutable_data(TARGET(kCUDA), param.x->memory_size()); - CopyFromHostSync(data, param.x->data(), param.x->memory_size()); + auto* data = TensorMutableData(param.y, TARGET(kCUDA), + param.x->memory_size()); + CopyFromHostSync(data, param.x->data(), param.x->memory_size()); } std::unique_ptr GetTypeInferHandler() override { @@ -81,8 +82,9 @@ class IoCopyCudaToHostCompute public: void Run() override { auto& param = Param(); - CHECK(param.x->target() == TARGET(kCUDA)); - auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size()); + CHECK(TensorGetTarget(*param.x) == TARGET(kCUDA)); + auto* data = TensorMutableData(param.y, TARGET(kHost), + param.x->memory_size()); LOG(INFO) << "copy size " << param.x->memory_size(); CopyToHostSync(data, param.x->data(), param.x->memory_size()); } diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index 4b5998fd2fb..90cbe0e3fe2 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -51,7 +51,8 @@ class MulCompute : public KernelLite { */ const auto& 
param = Param(); - param.output->mutable_data(TARGET(kCUDA)); + TensorMutableData(param.output, TARGET(kCUDA), + product(param.output->dims())); LOG(INFO) << "mul output memory size " << param.output->memory_size(); // mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out); diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc index f63ee3958c9..aad74377c37 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ b/paddle/fluid/lite/kernels/host/fc_compute.cc @@ -29,17 +29,16 @@ void FcCompute::Run() { CHECK_GE(param.input->dims().size(), 2UL); CHECK_EQ(param.output->dims().size(), 2UL); - fc_compute_eigen( - param.input->data(), // x - product(param.input->dims().begin() + param.in_num_col_dims, - param.input->dims().end()), // x_w - product(param.input->dims().begin(), - param.input->dims().begin() + param.in_num_col_dims), // x_h - param.w->data(), // w - param.w->dims()[1], // w_w - param.w->dims()[0], // w_h - param.bias->data(), // b - param.output->mutable_data()); + fc_compute_eigen(param.input->data(), // x + product(param.input->dims(), 0, param.in_num_col_dims), + product(param.input->dims(), param.in_num_col_dims, + param.input->dims().size()), + param.w->data(), // w + param.w->dims()[1], // w_w + param.w->dims()[0], // w_h + param.bias->data(), // b + TensorMutableData(param.output, TARGET(kHost), + product(param.output->dims()))); } // TargetType FcCompute::target() const { return TARGET(kHost); } diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 0df9e8f429d..38fca30998c 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -28,7 +28,7 @@ class FeedCompute void Run() override { auto ¶m = Param(); const Tensor &feed_item = param.feed_list->at(param.col); - param.out->CopyDataFrom(feed_item); + param.out->ShareDataWith(feed_item); LOG(INFO) << "FEED input " << feed_item << " col 
" << param.col; LOG(INFO) << "FEED output " << *param.out; } diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index 9a61ffaa1c5..7715e588e6f 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -41,19 +41,18 @@ class MulCompute : public KernelLite { void Run() override { auto& param = Param(); - core::dim2 x_shape({product(param.x->dims().begin(), - param.x->dims().begin() + param.x_num_col_dims), - product(param.x->dims().begin() + param.x_num_col_dims, - param.x->dims().end())}); + core::dim2 x_shape({product(param.x->dims(), 0, param.x_num_col_dims), + product(param.x->dims(), param.x_num_col_dims, + param.x->dims().size())}); - core::dim2 y_shape({product(param.y->dims().begin(), - param.y->dims().begin() + param.x_num_col_dims), - product(param.y->dims().begin() + param.x_num_col_dims, - param.y->dims().end())}); + core::dim2 y_shape({product(param.y->dims(), 0, param.y_num_col_dims), + product(param.y->dims(), param.y_num_col_dims, + param.y->dims().size())}); mul_compute_eigen(param.x->data(), x_shape.x, x_shape.y, // param.y->data(), y_shape.x, y_shape.y, // - param.output->mutable_data()); + TensorMutableData(param.output, TARGET(kHost), + product(param.output->dims()))); LOG(INFO) << "MUL x " << *param.x; LOG(INFO) << "MUL W " << *param.y; LOG(INFO) << "MUL out " << *param.output; diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/host/relu_compute.h index eec420e74bf..276535120d7 100644 --- a/paddle/fluid/lite/kernels/host/relu_compute.h +++ b/paddle/fluid/lite/kernels/host/relu_compute.h @@ -24,10 +24,11 @@ namespace host { class ReluCompute : public KernelLite { public: void Run() override { - auto& theparam = Param(); - auto n = product(theparam.input->dims()); - const float* input = theparam.input->data(); - float* output = theparam.output->mutable_data(); + auto& param = Param(); + auto n = 
product(param.input->dims()); + const float* input = param.input->data(); + float* output = TensorMutableData(param.output, TARGET(kHost), + product(param.output->dims())); for (int i = 0; i < n; i++) { output[i] = std::max(0.f, input[i]); } diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc index b17498a3612..de1b59e7e09 100644 --- a/paddle/fluid/lite/kernels/host/scale_compute.cc +++ b/paddle/fluid/lite/kernels/host/scale_compute.cc @@ -37,7 +37,9 @@ class ScaleCompute : public KernelLite { void Run() override { auto& param = Param(); - scale_compute(param.x->data(), param.output->mutable_data(), + scale_compute(param.x->data(), + TensorMutableData(param.output, TARGET(kHost), + product(param.output->dims())), product(param.x->dims()), param.scale, param.bias, param.bias_after_scale); } diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index 33487809be0..7b8f1534cfd 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -3,7 +3,7 @@ cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite) else() - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto) + cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite diff --git a/paddle/fluid/lite/model_parser/compatible_pb.h b/paddle/fluid/lite/model_parser/compatible_pb.h index c77d180031d..49db1e4d897 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.h +++ b/paddle/fluid/lite/model_parser/compatible_pb.h @@ -33,6 +33,7 @@ namespace paddle { namespace lite { #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +using Attribute = lite::pb::Attribute; 
using OpDesc = lite::pb::OpDesc; using VarDesc = lite::pb::VarDesc; #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK @@ -41,5 +42,17 @@ using OpDesc = framework::OpDesc; using VarDesc = framework::VarDesc; #endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +template +T GetAttr(const Attribute& x) { + return x.get(); +} +#else +template +T GetAttr(const Attribute& x) { + return boost::get(x); +} +#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc index eecb570b5b6..59aec582749 100644 --- a/paddle/fluid/lite/model_parser/model_parser.cc +++ b/paddle/fluid/lite/model_parser/model_parser.cc @@ -14,8 +14,8 @@ #include "paddle/fluid/lite/model_parser/model_parser.h" #include +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/core/variable.h" namespace paddle { @@ -59,16 +59,16 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) { // read tensor std::vector dims; - dims.reserve(static_cast(desc.dims().size())); std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); - tensor->Resize(dims); + tensor->Resize(lite::DDim(&dims[0], dims.size())); void *buf; size_t size = product(tensor->dims()) * SizeOfType(desc.data_type()); // alllocate memory switch (static_cast(desc.data_type())) { -#define DO(desc, type) \ - case Type::VarType_Type_##desc: \ - buf = tensor->mutable_data(); \ +#define DO(desc, type) \ + case Type::VarType_Type_##desc: \ + buf = TensorMutableData(tensor, TensorGetTarget(*tensor), \ + product(tensor->dims())); \ break; DO(BOOL, bool); DO(FP32, float); @@ -198,7 +198,8 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) { auto dims = tensor.dims(); auto *pb_dims = desc.mutable_dims(); pb_dims->Resize(static_cast(dims.size()), 
0); - std::copy(dims.begin(), dims.end(), pb_dims->begin()); + auto dims_vec = DDimVectorize(dims); + std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin()); int32_t size = desc.ByteSize(); os.write(reinterpret_cast(&size), sizeof(size)); auto out = desc.SerializeAsString(); @@ -210,9 +211,9 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) { << "Index overflow when writing tensor"; #ifdef LITE_WITH_CUDA - if (tensor.target() == TARGET(kCUDA)) { + if (TensorGetTarget(tensor) == TARGET(kCUDA)) { std::unique_ptr tmp_buffer(new char[size]); - TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data(), + TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data(), tensor.memory_size(), IoDirection::DtoH); os.write(static_cast(tmp_buffer.get()), static_cast(size)); diff --git a/paddle/fluid/lite/model_parser/model_parser.h b/paddle/fluid/lite/model_parser/model_parser.h index 1919550348b..003d335c1b3 100644 --- a/paddle/fluid/lite/model_parser/model_parser.h +++ b/paddle/fluid/lite/model_parser/model_parser.h @@ -20,7 +20,6 @@ #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/core/variable.h" namespace paddle { diff --git a/paddle/fluid/lite/operators/fc_op.cc b/paddle/fluid/lite/operators/fc_op.cc index e4f6d336307..03c91d1c36c 100644 --- a/paddle/fluid/lite/operators/fc_op.cc +++ b/paddle/fluid/lite/operators/fc_op.cc @@ -58,7 +58,7 @@ bool FcOpLite::InferShape() const { output_dims[i] = input_dims[i]; } output_dims.back() = w_dims[1]; - param_.output->Resize(output_dims); + param_.output->Resize(DDim(&output_dims[0], output_dims.size())); // share LoD // param_.output->set_lod(param_.input->lod()); diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index 4c322e41b8f..d5379b8344a 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -16,10 
+16,10 @@ #include #include +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" @@ -57,7 +57,7 @@ class FcOpLite : public OpLite { param_.bias = scope->FindVar(bias)->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims").get(); + param_.in_num_col_dims = GetAttr(op_desc.GetAttr("in_num_col_dims")); CHECK(kernel_); kernel_->SetParam(param_); diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc index 0b5ffcfd63a..03ea820f49a 100644 --- a/paddle/fluid/lite/operators/feed_op.cc +++ b/paddle/fluid/lite/operators/feed_op.cc @@ -49,7 +49,7 @@ class FeedOp : public OpLite { // NOTE need boost here // TODO(Superjomn) drop the need of framework::op_desc - param_.col = opdesc.GetAttr("col").get(); + param_.col = GetAttr(opdesc.GetAttr("col")); return true; } diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc index b34d57645ef..ea86d6a2f75 100644 --- a/paddle/fluid/lite/operators/fetch_op.cc +++ b/paddle/fluid/lite/operators/fetch_op.cc @@ -43,7 +43,7 @@ class FetchOp : public OpLite { auto* out = scope->FindVar(_out); param_.fetch_list = out->GetMutable>(); - param_.col = opdesc.GetAttr("col").get(); + param_.col = GetAttr(opdesc.GetAttr("col")); return true; } diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc index e0fe5837153..c79f16dbff0 100644 --- a/paddle/fluid/lite/operators/mul_op.cc +++ b/paddle/fluid/lite/operators/mul_op.cc @@ -45,7 +45,7 @@ bool MulOpLite::InferShape() const { } out_dims.back() = y_dims[1]; - param_.output->Resize(out_dims); + 
param_.output->Resize(DDim(&out_dims[0], out_dims.size())); // share LoD // param_.output->set_lod(param_.input->lod()); diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 334321c457c..613e8c9f0c1 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -18,7 +18,6 @@ #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" @@ -47,8 +46,8 @@ class MulOpLite : public OpLite { param_.y = scope->FindVar(W)->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims").get(); - param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims").get(); + param_.x_num_col_dims = GetAttr(op_desc.GetAttr("x_num_col_dims")); + param_.y_num_col_dims = GetAttr(op_desc.GetAttr("y_num_col_dims")); return true; } diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 8f1d007d85e..c3f716906a1 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -13,7 +13,7 @@ // limitations under the License. 
#pragma once -#include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/utils/all.h" /* diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h index 088f1314dac..66929cef199 100644 --- a/paddle/fluid/lite/operators/relu_op.h +++ b/paddle/fluid/lite/operators/relu_op.h @@ -16,7 +16,6 @@ #include #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc index 33d925bae73..87cbe2a2e03 100644 --- a/paddle/fluid/lite/operators/scale_op.cc +++ b/paddle/fluid/lite/operators/scale_op.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" @@ -53,9 +52,10 @@ class ScaleOp : public OpLite { param_.x = scope->FindVar(x)->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.scale = op_desc.GetAttr("scale").get(); - param_.bias = op_desc.GetAttr("bias").get(); - param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get(); + param_.scale = GetAttr(op_desc.GetAttr("scale")); + param_.bias = GetAttr(op_desc.GetAttr("bias")); + param_.bias_after_scale = + GetAttr(op_desc.GetAttr("bias_after_scale")); return true; } diff --git a/paddle/fluid/lite/x86/CMakeLists.txt b/paddle/fluid/lite/x86/CMakeLists.txt index 7cbf432c165..be772b921b4 100644 --- a/paddle/fluid/lite/x86/CMakeLists.txt +++ b/paddle/fluid/lite/x86/CMakeLists.txt @@ -1 +1,5 @@ +if (NOT LITE_WITH_X86) + return() +endif() + cc_library(target_wrapper_x86 SRCS target_wrapper.cc) -- GitLab