From 706e83af3af3916a5753b6ce644a1bd822d908a1 Mon Sep 17 00:00:00 2001 From: superjomn Date: Mon, 6 May 2019 19:05:57 +0800 Subject: [PATCH] make an adapter for TensorLite and framework::LoDTensor and DDim --- paddle/fluid/lite/api/cxx_api.h | 6 +- paddle/fluid/lite/api/cxx_api_test.cc | 19 +- paddle/fluid/lite/api/light_api.h | 9 +- paddle/fluid/lite/core/CMakeLists.txt | 12 +- paddle/fluid/lite/core/compatible_tensor.h | 73 +------- paddle/fluid/lite/core/hvy_tensor.cc | 15 ++ paddle/fluid/lite/core/hvy_tensor.h | 109 ++++++++++++ paddle/fluid/lite/core/kernel.h | 2 +- paddle/fluid/lite/core/lite_gtest_main.cc | 2 + paddle/fluid/lite/core/lite_tensor.cc | 32 +--- paddle/fluid/lite/core/lite_tensor.h | 51 ++++-- paddle/fluid/lite/core/op_executor_test.cc | 6 +- paddle/fluid/lite/core/program.h | 4 +- paddle/fluid/lite/core/program_fake_utils.h | 10 +- paddle/fluid/lite/core/tensor.h | 168 ++++++++++++++++++ paddle/fluid/lite/core/tensor_test.cc | 2 +- paddle/fluid/lite/core/type_system_test.cc | 2 +- paddle/fluid/lite/core/variable.h | 4 +- paddle/fluid/lite/kernels/CMakeLists.txt | 2 +- paddle/fluid/lite/kernels/cuda/CMakeLists.txt | 4 +- .../lite/kernels/cuda/io_copy_compute.cc | 20 +-- paddle/fluid/lite/kernels/cuda/mul_compute.h | 5 +- paddle/fluid/lite/kernels/host/fc_compute.cc | 21 +-- .../lite/kernels/host/fc_compute_test.cc | 10 +- .../fluid/lite/kernels/host/feed_compute.cc | 4 +- paddle/fluid/lite/kernels/host/mul_compute.cc | 24 ++- paddle/fluid/lite/kernels/host/relu_compute.h | 5 +- .../fluid/lite/kernels/host/scale_compute.cc | 6 +- paddle/fluid/lite/model_parser/CMakeLists.txt | 2 +- .../fluid/lite/model_parser/model_parser.cc | 27 +-- .../lite/model_parser/model_parser_test.cc | 2 +- paddle/fluid/lite/operators/CMakeLists.txt | 17 +- paddle/fluid/lite/operators/fc_op.cc | 4 +- paddle/fluid/lite/operators/fc_op.h | 8 +- paddle/fluid/lite/operators/fc_op_test.cc | 8 +- paddle/fluid/lite/operators/feed_op.cc | 4 +- paddle/fluid/lite/operators/fetch_op.cc | 2 +- paddle/fluid/lite/operators/mul_op.cc | 2 +- paddle/fluid/lite/operators/op_params.h | 36 ++-- paddle/fluid/lite/operators/relu_op.cc | 6 +- paddle/fluid/lite/utils/varient.h | 15 +- 41 files changed, 505 insertions(+), 255 deletions(-) create mode 100644 paddle/fluid/lite/core/hvy_tensor.cc create mode 100644 paddle/fluid/lite/core/hvy_tensor.h create mode 100644 paddle/fluid/lite/core/tensor.h diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index 81f5694fbae..a3a66e99000 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -45,17 +45,17 @@ class LightPredictor { void SaveModel(const std::string& dir); // Get offset-th col of feed. 
-  Tensor* GetInput(size_t offset) {
+  lite::Tensor* GetInput(size_t offset) {
     auto* _feed_list = program_->exec_scope()->FindVar("feed");
     CHECK(_feed_list) << "no feed variable in exec_scope";
-    auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
+    auto* feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
     if (offset >= feed_list->size()) {
       feed_list->resize(offset + 1);
     }
     return &feed_list->at(offset);
   }
 
-  const Tensor* GetOutput(size_t offset) {
+  const lite::Tensor* GetOutput(size_t offset) {
     auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
     CHECK(_fetch_list) << "no fatch variable in exec_scope";
     auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc
index 1380393c07b..25eaa3d9e5d 100644
--- a/paddle/fluid/lite/api/cxx_api_test.cc
+++ b/paddle/fluid/lite/api/cxx_api_test.cc
@@ -13,11 +13,14 @@
 // limitations under the License.
 
 #include "paddle/fluid/lite/api/cxx_api.h"
+#include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include "paddle/fluid/lite/core/mir/passes.h"
 #include "paddle/fluid/lite/core/op_executor.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
+DEFINE_string(model_dir, "", "");
+
 namespace paddle {
 namespace lite {
 
@@ -36,24 +39,22 @@ TEST(CXXApi, test) {
   });
 #endif
 
-  predictor.Build("/home/chunwei/project/models/model2",
-                  Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
+                  valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize({100, 100});
-  auto* data = TensorMutableData<float>(input_tensor, TARGET(kHost),
-                                        product(input_tensor->dims()));
+  input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));
+  auto* data = input_tensor->mutable_data<float>();
   for (int i = 0; i < 100 * 100; i++) {
     data[i] = i;
   }
 
-  LOG(INFO) << "input " << input_tensor;
   LOG(INFO) << "input " << *input_tensor;
 
   predictor.Run();
 
   auto* out = predictor.GetOutput(0);
-  LOG(INFO) << out << " memory size " << out->memory_size();
+  LOG(INFO) << out << " memory size " << out->data_size();
   LOG(INFO) << "out " << out->data<float>()[0];
   LOG(INFO) << "out " << out->data<float>()[1];
   LOG(INFO) << "dims " << out->dims();
@@ -63,8 +64,8 @@ TEST(CXXApi, test) {
 TEST(CXXApi, save_model) {
   lite::LightPredictor predictor;
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
-  predictor.Build("/home/chunwei/project/models/model2",
-                  Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
+                  valid_places);
 
   predictor.SaveModel("./optimized_model");
 }
diff --git a/paddle/fluid/lite/api/light_api.h b/paddle/fluid/lite/api/light_api.h
index 484af4d339b..ec07b8c979d 100644
--- a/paddle/fluid/lite/api/light_api.h
+++ b/paddle/fluid/lite/api/light_api.h
@@ -41,20 +41,21 @@ class LightPredictor {
   void Run() { program_->Run(); }
 
   // Get offset-th col of feed.
-  Tensor* GetInput(size_t offset) {
+  TensorBase* GetInput(size_t offset) {
     auto* _feed_list = program_->exec_scope()->FindVar("feed");
     CHECK(_feed_list) << "no feed variable in exec_scope";
-    auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
+    auto* feed_list = _feed_list->GetMutable<std::vector<TensorBase>>();
     if (offset >= feed_list->size()) {
       feed_list->resize(offset + 1);
     }
     return &feed_list->at(offset);
   }
 
-  const Tensor* GetOutput(size_t offset) {
+  const TensorBase* GetOutput(size_t offset) {
     auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
     CHECK(_fetch_list) << "no fatch variable in exec_scope";
-    auto& fetch_list = *_fetch_list->GetMutable<std::vector<Tensor>>();
+    auto& fetch_list =
+        *_fetch_list->GetMutable<std::vector<TensorBase>>();
     CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
     return &fetch_list.at(offset);
   }
diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 15d98bd757c..5e760cf1631 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -2,22 +2,26 @@ cc_library(lite_gtest_main SRCS lite_gtest_main.cc)
 cc_library(memory_lite SRCS memory.cc)
 cc_library(target_wrapper_lite SRCS target_wrapper.cc)
 
+cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
+cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor)
+
 if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-  cc_library(tensor_lite SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
+  set(tensor_lite lite_tensor)
 else()
-  cc_library(tensor_lite DEPS lod_tensor)
+  set(tensor_lite hvy_tensor)
 endif()
+
 cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
 cc_library(variable_lite SRCS variable.cc)
 cc_library(op_registry_lite SRCS op_registry.cc)
 cc_library(scope_lite SRCS scope.cc)
 cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite)
-cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
+cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite ${tensor_lite} op_lite op_registry_lite
   #TODO(Superjomn) remove these dependencies from original framework
   )
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
 cc_library(types_lite SRCS types.cc)
-cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
+cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite})
 
 cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
   scope_lite op_registry_lite proto_desc op_lite
   ops_lite
diff --git a/paddle/fluid/lite/core/compatible_tensor.h b/paddle/fluid/lite/core/compatible_tensor.h
index 490b67e923e..812ead795b3 100644
--- a/paddle/fluid/lite/core/compatible_tensor.h
+++ b/paddle/fluid/lite/core/compatible_tensor.h
@@ -14,83 +14,24 @@
 
 #pragma once
 
-#include
-#include "paddle/fluid/lite/core/target_wrapper.h"
+#include "paddle/fluid/lite/core/tensor.h"
+
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 #include "paddle/fluid/lite/core/lite_tensor.h"
 #else
-#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/lite/core/hvy_tensor.h"
 #endif
 
 namespace paddle {
 namespace lite {
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-using Tensor = details::Tensor;
-using DDim = details::DDim;
-#else
-using Tensor = framework::LoDTensor;
-using DDim = framework::DDim;
-
-static TargetType TensorGetTarget(const Tensor &x) {
-  if (platform::is_gpu_place(x.place())) {
-    return TARGET(kCUDA);
-  } else if (platform::is_cpu_place(x.place())) {
-    return TARGET(kX86);
-  }
-  return TARGET(kUnk);
-}
-
-template <typename T>
-T *TensorMutableData(Tensor *x, TargetType target, size_t size) {
-  if (target == TARGET(kX86) || target == TARGET(kHost)) {
-    return x->mutable_data<T>(platform::CPUPlace(), memory::Allocator::kDefault,
-                              size);
-  } else if (target == TARGET(kCUDA)) {
-    return x->mutable_data<T>(platform::CUDAPlace(),
-                              memory::Allocator::kDefault, size);
-  }
-  LOG(FATAL) << "not valid target " << TargetToStr(target);
-  return nullptr;
-}
-#endif
-
-static int product(const DDim &dims, int start, int end) {
-  int res = 1;
-  for (int i = start; i < end; i++) {
-    res *= dims[i];
-  }
-  return res;
-}
-
-static DDim SliceDims(const DDim &dims, int begin, int end) {
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-  return DDim(dims[0] + begin, dims.begin() + end - 1);
+using DDim = lite::DDimLite;
+using Tensor = lite::TensorLite;
 #else
-  auto vec = framework::vectorize(dims);
-  return DDim(&vec[0] + begin, end - begin);
+using DDim = lite::DDimHvy;
+using Tensor = lite::TensorHvy;
 #endif
-}
-
-static std::vector<int64_t> DDimVectorize(const DDim &x) {
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-  return x;
-#else
-  return framework::vectorize(x);
-#endif
-}
-
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-static int product(const DDim &dims) {
-  return std::accumulate(dims.begin(), dims.end(), 1,
-                         [](int a, int b) { return a * b; });
-}
-#endif
-
-static DDim flatten_to_2d(const DDim &dims, int col) {
-  return DDim({product(SliceDims(dims, 0, col)),
-               product(SliceDims(dims, col, dims.size()))});
-}
 
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/hvy_tensor.cc b/paddle/fluid/lite/core/hvy_tensor.cc
new file mode 100644
index 00000000000..4ae429bdd55
--- /dev/null
+++ b/paddle/fluid/lite/core/hvy_tensor.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/hvy_tensor.h"
diff --git a/paddle/fluid/lite/core/hvy_tensor.h b/paddle/fluid/lite/core/hvy_tensor.h
new file mode 100644
index 00000000000..9fb2aeea3d6
--- /dev/null
+++ b/paddle/fluid/lite/core/hvy_tensor.h
@@ -0,0 +1,109 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file defines the heavy tensor (an alias of the LoDTensor in the server
+ * framework). It implements the shared lite tensor interface, so the lite
+ * framework can share much code between the server side and the mobile side.
+ */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+class DDimHvy : public DDimBase<DDimHvy> {
+ public:
+  DDimHvy() = default;
+  explicit DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() {
+    ConstructFrom(x);
+  }
+  explicit DDimHvy(const framework::DDim& x) : data_(x) {}
+
+  void ConstructFrom(const std::vector<value_type>& xs) {
+    data_ = framework::DDim(xs.data(), xs.size());
+  }
+
+  value_type operator[](int offset) const { return data_[offset]; }
+
+  std::vector<value_type> Vectorize() const {
+    return framework::vectorize(data_);
+  }
+
+  const framework::DDim& data() const { return data_; }
+
+  size_t size() const { return data_.size(); }
+  bool empty() const { return data_.size() == 0; }
+
+ private:
+  framework::DDim data_;
+};
+
+class TensorHvy : public TensorBase<TensorHvy> {
+ public:
+  using DDimT = DDimHvy;
+  using LoDT = framework::LoD;
+
+  TargetType target() const {
+    if (platform::is_gpu_place(data_.place())) {
+      return TARGET(kCUDA);
+    } else if (platform::is_cpu_place(data_.place())) {
+      return TARGET(kX86);
+    }
+    LOG(FATAL) << "Unknown place";
+    return TARGET(kUnk);
+  }
+
+  template <typename T>
+  T* mutable_data() {
+    return data_.mutable_data<T>(data_.dims(), platform::CPUPlace());
+  }
+  template <typename T>
+  T* mutable_data(TargetType target) {
+    if (target == TARGET(kCUDA)) {
+      return data_.mutable_data<T>(data_.dims(), platform::CUDAPlace());
+    }
+    return data_.mutable_data<T>(data_.dims(), platform::CPUPlace());
+  }
+
+  template <typename T>
+  const T* data() const {
+    return data_.data<T>();
+  }
+
+  template <typename DimT>
+  void Resize(const DimT& dims) {
+    LOG(INFO) << "dims.size " << dims.size();
+    data_.Resize(framework::make_ddim(dims.Vectorize()));
+  }
+
+  void ShareDataWith(const TensorHvy& other) {
+    data_.ShareDataWith(other.data_);
+  }
+  void CopyDataFrom(const TensorHvy& other) {
+    data_.ShareDataWith(other.data_);
+  }
+
+  DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); }
+
+  const framework::LoD& lod() const { return data_.lod(); }
+  framework::LoD* mutable_lod() { return data_.mutable_lod(); }
+
+ private:
+  framework::LoDTensor data_;
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index bfabd87baff..4695a87a42c 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -52,7 +52,7 @@ class KernelBase {
   }
 
   template <typename P>
   P& Param() const {
-    return param_.get<P>();
+    return *param_.get_mutable<P>();
   }
 
   // This is used in the kernels that take 'kAny' places and infer the
diff --git a/paddle/fluid/lite/core/lite_gtest_main.cc b/paddle/fluid/lite/core/lite_gtest_main.cc
index 9f9bd7ba467..9784fc79945 100644
--- a/paddle/fluid/lite/core/lite_gtest_main.cc
+++ b/paddle/fluid/lite/core/lite_gtest_main.cc
@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <gflags/gflags.h>
 #include <gtest/gtest.h>
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
+  google::ParseCommandLineFlags(&argc, &argv, false);
 
   return RUN_ALL_TESTS();
 }
diff --git a/paddle/fluid/lite/core/lite_tensor.cc b/paddle/fluid/lite/core/lite_tensor.cc
index c2dc501c32c..0dce115cec6 100644
--- a/paddle/fluid/lite/core/lite_tensor.cc
+++ b/paddle/fluid/lite/core/lite_tensor.cc
@@ -17,31 +17,7 @@
 namespace paddle {
 namespace lite {
 
-std::ostream &operator<<(std::ostream &os, const DDim &dims) {
-  if (dims.empty()) {
-    os << "[]";
-    return os;
-  }
-
-  os << "[";
-  for (size_t i = 0; i < dims.size() - 1; i++) {
-    os << dims[i] << " ";
-  }
-  os << dims.back() << "]";
-  return os;
-}
-
-std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
-  os << "Tensor:" << '\n';
-  os << "dim: " << tensor.dims() << '\n';
-  for (int i = 0; i < product(tensor.dims()); i++) {
-    os << tensor.data<float>()[i] << " ";
-  }
-  os << "\n";
-  return os;
-}
-
-void Tensor::ShareDataWith(const Tensor &other) {
+void TensorLite::ShareDataWith(const TensorLite &other) {
   buffer_ = other.buffer_;
   dims_ = other.dims_;
   target_ = other.target_;
@@ -49,17 +25,17 @@ void Tensor::ShareDataWith(const Tensor &other) {
   memory_size_ = other.memory_size_;
 }
 
-void *Tensor::mutable_data(size_t memory_size) {
+void *TensorLite::mutable_data(size_t memory_size) {
   buffer_->ResetLazy(target_, memory_size);
   return buffer_->data();
 }
 
-void *Tensor::mutable_data(TargetType target, size_t memory_size) {
+void *TensorLite::mutable_data(TargetType target, size_t memory_size) {
   target_ = target;
   return mutable_data(memory_size);
 }
 
-void Tensor::CopyDataFrom(const Tensor &other) {
+void TensorLite::CopyDataFrom(const TensorLite &other) {
   dims_ = other.dims_;
   target_ = other.target_;
   lod_ = other.lod_;
diff --git a/paddle/fluid/lite/core/lite_tensor.h b/paddle/fluid/lite/core/lite_tensor.h
index 918a675b350..ea31b3f9aa5 100644
--- a/paddle/fluid/lite/core/lite_tensor.h
+++ b/paddle/fluid/lite/core/lite_tensor.h
@@ -20,28 +20,49 @@
 
 #include "paddle/fluid/lite/core/memory.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
+#include "paddle/fluid/lite/core/tensor.h"
 
 namespace paddle {
 namespace lite {
 
-namespace details {
 
-using DDim = std::vector<int64_t>;
+class DDimLite : public DDimBase<DDimLite> {
+ public:
+  DDimLite() = default;
+
+  DDimLite(const std::vector<value_type> &x) : DDimBase<DDimLite>() {
+    ConstructFrom(x);
+  }
+
+  void ConstructFrom(const std::vector<value_type> &x) { data_ = x; }
+
+  value_type operator[](int offset) const { return data_[offset]; }
+  std::vector<value_type> Vectorize() { return data_; }
+
+  size_t size() const { return data_.size(); }
+  bool empty() const { return data_.empty(); }
+  const std::vector<value_type> &data() const { return data_; }
+
+ private:
+  std::vector<value_type> data_;
+};
 
 using LoD = std::vector<std::vector<size_t>>;
 
 // A light-weight tensor implementation.
-class Tensor {
+class TensorLite : public TensorBase<TensorLite> {
  public:
-  Tensor() : buffer_(std::make_shared<Buffer>()) {}
+  using DDimT = DDimLite;
+
+  TensorLite() : buffer_(std::make_shared<Buffer>()) {}
 
   template <typename T>
   const T *data() const {
     return static_cast<const T *>(buffer_->data());
   }
 
-  void Resize(const DDim &ddim) { dims_ = ddim; }
+  void Resize(const DDimLite &ddim) { dims_ = ddim; }
 
-  const DDim &dims() const { return dims_; }
+  const DDimLite &dims() const { return dims_; }
 
   const LoD &lod() const { return lod_; }
   LoD *mutable_lod() { return &lod_; }
@@ -58,38 +79,34 @@
   bool IsInitialized() const { return buffer_->data(); }
 
   // Other share data to this.
-  void ShareDataWith(const Tensor &other);
+  void ShareDataWith(const TensorLite &other);
 
-  void CopyDataFrom(const Tensor &other);
+  void CopyDataFrom(const TensorLite &other);
 
   TargetType target() const { return target_; }
 
  private:
   TargetType target_{TargetType::kHost};
-  DDim dims_;
+  DDimLite dims_;
   std::shared_ptr<Buffer> buffer_;
   LoD lod_;
   size_t memory_size_{};
 };
 
 template <typename T>
-T *Tensor::mutable_data() {
-  memory_size_ = product(dims_) * sizeof(T);
+T *TensorLite::mutable_data() {
+  memory_size_ = dims_.production() * sizeof(T);
   buffer_->ResetLazy(target_, memory_size_);
   return static_cast<T *>(buffer_->data());
 }
 
 template <typename T>
-T *Tensor::mutable_data(TargetType target) {
+T *TensorLite::mutable_data(TargetType target) {
   target_ = target;
-  memory_size_ = product(dims_) * sizeof(T);
+  memory_size_ = dims_.production() * sizeof(T);
   buffer_->ResetLazy(target, memory_size());
   return static_cast<T *>(buffer_->data());
 }
 
-std::ostream &operator<<(std::ostream &os, const DDim &dims);
-std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
-
-}  // namespace details
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/op_executor_test.cc b/paddle/fluid/lite/core/op_executor_test.cc
index 51912b363a8..1fb81ee1d1c 100644
--- a/paddle/fluid/lite/core/op_executor_test.cc
+++ b/paddle/fluid/lite/core/op_executor_test.cc
@@ -39,11 +39,11 @@ TEST(executor, test) {
   op_desc.SetAttr("in_num_col_dims", static_cast<int>(1));
   program.Flush();
 
-  auto* w = scope->Var("w")->GetMutable<Tensor>();
+  auto* w = scope->Var("w")->GetMutable<lite::Tensor>();
   w->Resize({20, 20});
-  auto* x = scope->Var("x")->GetMutable<Tensor>();
+  auto* x = scope->Var("x")->GetMutable<lite::Tensor>();
   x->Resize({1, 10, 20});
-  auto* bias = scope->Var("bias")->GetMutable<Tensor>();
+  auto* bias = scope->Var("bias")->GetMutable<lite::Tensor>();
   bias->Resize({1, 20});
   bias->mutable_data<float>();
 
diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h
index 1800e3f64f3..f57b8d923db 100644
--- a/paddle/fluid/lite/core/program.h
+++ b/paddle/fluid/lite/core/program.h
@@ -81,8 +81,8 @@ struct Program {
     CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
     exec_scope = &scope->NewScope();
     // Create Feed and Fetch var.
-    scope->Var("feed")->GetMutable<std::vector<Tensor>>();
-    scope->Var("fetch")->GetMutable<std::vector<Tensor>>();
+    scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
+    scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
     tmp_vars.push_back("feed");
     tmp_vars.push_back("fetch");
diff --git a/paddle/fluid/lite/core/program_fake_utils.h b/paddle/fluid/lite/core/program_fake_utils.h
index e1dafc8ac52..55e3d4663b8 100644
--- a/paddle/fluid/lite/core/program_fake_utils.h
+++ b/paddle/fluid/lite/core/program_fake_utils.h
@@ -28,9 +28,9 @@ Program FakeProgram() {
     std::string w1 = "w" + std::to_string(id);
     std::string b1 = "b" + std::to_string(id);
     std::string out1 = "out" + std::to_string(id);
-    auto w1v = program.scope->Var(w1)->GetMutable<Tensor>();
-    auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
-    auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();
+    auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
+    auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
+    auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
 
     lite::OpDesc desc;
     desc.SetInput("Input", {x});
@@ -60,7 +60,7 @@ Program FakeProgram() {
 
   std::string x = "x";
   program.tmp_vars.push_back(x);
-  auto* xv = program.scope->Var(x)->GetMutable<Tensor>();
+  auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
   xv->Resize({100, 100});
 
   for (int i = 0; i < 3; i++) {
@@ -81,7 +81,7 @@ class ProgramFaker {
   void CreateVars(lite::Scope* scope) {
     for (auto& var : tmp_vars_) {
       auto* x = scope->Var(var);
-      x->GetMutable<Tensor>();
+      x->GetMutable<lite::Tensor>();
     }
 
     for (auto& x : tmp_vars_) {
diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h
new file mode 100644
index 00000000000..59c35498e59
--- /dev/null
+++ b/paddle/fluid/lite/core/tensor.h
@@ -0,0 +1,168 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+/*
+ * This file defines the general interface for DDim and Tensor, which is used
+ * in both the server and mobile frameworks. To make the two frameworks share
+ * the same code, we clean up the methods here and make the different
+ * implementations look the same.
+ */
+
+#include <vector>
+#include "paddle/fluid/lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+
+/*
+ * This class defines the basic interface of the DDims for server and mobile.
+ * Because a full DDim implementation is too heavy for mobile, we add a simple
+ * implementation there and use this interface to share the framework code
+ * between mobile and server.
+ *
+ * A derived class should implement the following interfaces:
+ *   ConstructFrom
+ *   operator[]
+ *   Vectorize
+ *   size
+ */
+template <typename DDimT>
+class DDimBase {
+ public:
+  using value_type = int64_t;
+
+  DDimBase() = default;
+
+  explicit DDimBase(const std::vector<value_type> &x) {
+    self()->ConstructFrom(x);
+  }
+  value_type operator[](int offset) const { return (*self())[offset]; }
+  std::vector<value_type> Vectorize() { return self()->Vectorize(); }
+  size_t size() const { return const_self()->size(); }
+  bool empty() const { return const_self()->empty(); }
+
+  value_type production() const {
+    value_type res = 1;
+    for (int i = 0; i < const_self()->size(); i++) {
+      res *= (*const_self())[i];
+    }
+    return res;
+  }
+
+  DDimT Slice(int start, int end) const {
+    std::vector<value_type> vec;
+    for (int i = start; i < end; i++) {
+      vec.push_back((*const_self())[i]);
+    }
+    return DDimT(vec);
+  }
+
+  DDimT Flattern2D(int col) const {
+    return DDimT(std::vector<value_type>(
+        {Slice(0, col).production(), Slice(col, size()).production()}));
+  }
+
+  friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) {
+    if (dims.empty()) {
+      os << "[]";
+      return os;
+    }
+
+    os << "[";
+    for (size_t i = 0; i < dims.size() - 1; i++) {
+      os << dims[i] << " ";
+    }
+    if (!dims.empty()) os << dims[dims.size() - 1];
+    os << "]";
+    return os;
+  }
+
+ private:
+  DDimT *self() { return static_cast<DDimT *>(this); }
+  const DDimT *const_self() const { return static_cast<const DDimT *>(this); }
+};
+
+/*
+ * This class defines the basic interface of the tensors implemented for
+ * server and mobile. It uses the CRTP idiom to avoid virtual-call overhead
+ * at runtime.
+ */
+template <typename TensorT>
+class TensorBase {
+ public:
+  TensorBase() = default;
+  TargetType target() const { return self()->target(); }
+
+  template <typename T>
+  T *mutable_data() {
+    return self()->template mutable_data<T>();
+  }
+
+  template <typename T>
+  T *mutable_data(TargetType target) {
+    return self()->template mutable_data<T>(target);
+  }
+
+  template <typename T>
+  const T *data() {
+    return self()->template data<T>();
+  }
+
+  template <typename DimT>
+  void Resize(const DimT &dims) {
+    self()->Resize(dims);
+  }
+
+  template <typename DDimT>
+  DDimT dims() {
+    return self()->dims();
+  }
+
+  template <typename LoDT>
+  const LoDT &lod() const {
+    return const_self()->lod();
+  }
+  template <typename LoDT>
+  LoDT *mutable_lod() {
+    return self()->mutable_lod();
+  }
+  template <typename T>
+  const T &data() const {
+    return const_self()->data();
+  }
+
+  size_t data_size() const { return const_self()->dims().production(); }
+
+  void ShareDataWith(const TensorBase &other) { self()->ShareDataWith(other); }
+  void CopyDataFrom(const TensorBase &other) { self()->CopyDataFrom(other); }
+
+  friend std::ostream &operator<<(std::ostream &os, const TensorT &tensor) {
+    os << "Tensor:" << '\n';
+    os << "dim: " << tensor.dims() << '\n';
+    for (int i = 0; i < tensor.dims().production(); i++) {
+      os << tensor.template data<float>()[i] << " ";
+    }
+    os << "\n";
+    return os;
+  }
+
+ private:
+  TensorT *self() { return static_cast<TensorT *>(this); }
+  const TensorT *const_self() const {
+    return static_cast<const TensorT *>(this);
+  }
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/tensor_test.cc b/paddle/fluid/lite/core/tensor_test.cc
index 247f2d73bf0..b9046822149 100644
--- a/paddle/fluid/lite/core/tensor_test.cc
+++ b/paddle/fluid/lite/core/tensor_test.cc
@@ -19,7 +19,7 @@ namespace paddle {
 namespace lite {
 
 TEST(tensor, test) {
-  Tensor tensor;
+  TensorBase tensor;
   tensor.Resize({1, 8});
 
   for (int i = 0; i < 8; i++) {
diff --git a/paddle/fluid/lite/core/type_system_test.cc b/paddle/fluid/lite/core/type_system_test.cc
index 407fe96e49a..b26234b7c8b 100644
--- a/paddle/fluid/lite/core/type_system_test.cc
+++ b/paddle/fluid/lite/core/type_system_test.cc
@@ -19,7 +19,7 @@ namespace paddle {
 namespace lite {
 
 TEST(TypeSystem, test) {
-  ASSERT_TRUE(TypeSystem::Global().Contains<Tensor>());
+  ASSERT_TRUE(TypeSystem::Global().Contains<lite::Tensor>());
 }
 
 TEST(TypeSystem, register_new) {
diff --git a/paddle/fluid/lite/core/variable.h b/paddle/fluid/lite/core/variable.h
index a0d0636066b..dc0a211d60b 100644
--- a/paddle/fluid/lite/core/variable.h
+++ b/paddle/fluid/lite/core/variable.h
@@ -29,7 +29,7 @@ class Variable {
   template <typename T>
   T* GetMutable() {
     if (!blob_.is<T>()) blob_.set<T>();
-    return &blob_.get<T>();
+    return blob_.get_mutable<T>();
   }
 
   template <typename T>
@@ -38,7 +38,7 @@
   }
 
  private:
-  variant<int, float, std::string, Tensor> blob_;
+  variant<int, float, std::string, lite::Tensor> blob_;
 };
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt
index ebbfb2139e5..047b3820122 100644
--- a/paddle/fluid/lite/kernels/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(lite_kernel_deps type_system kernel_lite op_registry_lite)
+set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite})
 add_subdirectory(host)
 add_subdirectory(arm)
 add_subdirectory(cuda)
diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt
index 3d58e9911bd..bc51b35528f 100644
--- a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt
@@ -2,7 +2,7 @@ if(NOT LITE_WITH_CUDA)
   return()
 endif()
 
-nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
-cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite)
+nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${tensor_lite})
+cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite})
 
 nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite)
diff --git a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
index 897cd67fc47..0a81afaf7b5 100644
--- a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
+++ b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc
@@ -46,12 +46,11 @@ class IoCopyHostToCudaCompute
  public:
   void Run() override {
     auto& param = Param<operators::IoCopyParam>();
-    CHECK(TensorGetTarget(*param.x) == TARGET(kHost) ||
-          TensorGetTarget(*param.x) == TARGET(kX86));
-    LOG(INFO) << "copy size " << param.x->memory_size();
-    auto* data = TensorMutableData<int8_t>(param.y, TARGET(kCUDA),
-                                           param.x->memory_size());
-    CopyFromHostSync(data, param.x->data<int8_t>(), param.x->memory_size());
+    CHECK(param.x->target() == TARGET(kHost) ||
+          param.x->target() == TARGET(kX86));
+    LOG(INFO) << "copy size " << param.x->data_size();
+    auto* data = param.y->mutable_data<int8_t>(TARGET(kCUDA));
+    CopyFromHostSync(data, param.x->data<int8_t>(), param.x->data_size());
   }
 
   std::unique_ptr GetTypeInferHandler() override {
@@ -82,11 +81,10 @@ class IoCopyCudaToHostCompute
  public:
   void Run() override {
     auto& param = Param<operators::IoCopyParam>();
-    CHECK(TensorGetTarget(*param.x) == TARGET(kCUDA));
-    auto* data = TensorMutableData<int8_t>(param.y, TARGET(kHost),
-                                           param.x->memory_size());
-    LOG(INFO) << "copy size " << param.x->memory_size();
-    CopyToHostSync(data, param.x->data<int8_t>(), param.x->memory_size());
+    CHECK(param.x->target() == TARGET(kCUDA));
+    auto* data = param.y->mutable_data<int8_t>();
+    LOG(INFO) << "copy size " << param.x->data_size();
+    CopyToHostSync(data, param.x->data<int8_t>(), param.x->data_size());
   }
 
   std::string doc() const override { return "Copy IO from CUDA to HOST"; }
diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h
index 90cbe0e3fe2..c8f323e0537 100644
--- a/paddle/fluid/lite/kernels/cuda/mul_compute.h
+++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h
@@ -51,9 +51,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
      */
     const auto& param = Param<operators::MulParam>();
 
-    TensorMutableData<float>(param.output, TARGET(kCUDA),
-                             product(param.output->dims()));
-    LOG(INFO) << "mul output memory size " << param.output->memory_size();
+    param.output->mutable_data<float>(TARGET(kCUDA));
+    LOG(INFO) << "mul output memory size " << param.output->data_size();
 
     //  mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out);
   }
diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc
index aad74377c37..7b84720c803 100644
--- a/paddle/fluid/lite/kernels/host/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/host/fc_compute.cc
@@ -29,16 +29,17 @@ void FcCompute::Run() {
   CHECK_GE(param.input->dims().size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);
 
-  fc_compute_eigen(param.input->data<float>(),  // x
-                   product(param.input->dims(), 0, param.in_num_col_dims),
-                   product(param.input->dims(), param.in_num_col_dims,
-                           param.input->dims().size()),
-                   param.w->data<float>(),     // w
-                   param.w->dims()[1],         // w_w
-                   param.w->dims()[0],         // w_h
-                   param.bias->data<float>(),  // b
-                   TensorMutableData<float>(param.output, TARGET(kHost),
-                                            product(param.output->dims())));
+  fc_compute_eigen(
+      param.input->data<float>(),  // x
+      param.input->dims().Slice(0, param.in_num_col_dims).production(),
+      param.input->dims()
+          .Slice(param.in_num_col_dims, param.input->dims().size())
+          .production(),
+      param.w->data<float>(),     // w
+      param.w->dims()[1],         // w_w
+      param.w->dims()[0],         // w_h
+      param.bias->data<float>(),  // b
+      param.output->mutable_data<float>());
 }
 
 //  TargetType FcCompute::target() const { return TARGET(kHost); }
diff --git a/paddle/fluid/lite/kernels/host/fc_compute_test.cc b/paddle/fluid/lite/kernels/host/fc_compute_test.cc
index 474965e2777..56871fb30db 100644
--- a/paddle/fluid/lite/kernels/host/fc_compute_test.cc
+++ b/paddle/fluid/lite/kernels/host/fc_compute_test.cc
@@ -23,7 +23,7 @@ namespace kernels {
 namespace host {
 
 TEST(fc_compute_naive, test) {
-  Tensor x, w, b, out, out1;
+  TensorBase x, w, b, out, out1;
   const int batch_size = 2;
   x.Resize({batch_size, 3});
   w.Resize({4, 3});
@@ -79,10 +79,10 @@ TEST(fc_host, compute) {
   FcCompute fc;
   operators::FcParam param;
 
-  Tensor x;
-  Tensor w;
-  Tensor bias;
-  Tensor output;
+  TensorBase x;
+  TensorBase w;
+  TensorBase bias;
+  TensorBase output;
 
   x.Resize({1, 10, 20});
   w.Resize({20, 20});
diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc
index 38fca30998c..02da6b2672f 100644
--- a/paddle/fluid/lite/kernels/host/feed_compute.cc
+++ b/paddle/fluid/lite/kernels/host/feed_compute.cc
@@ -27,7 +27,9 @@ class FeedCompute
   void Run() override {
     auto &param = Param<operators::FeedParam>();
-    const Tensor &feed_item = param.feed_list->at(param.col);
+    LOG(INFO) << "feed_list.size: " << param.feed_list->size();
+    LOG(INFO) << "col " << param.col;
+    const lite::Tensor &feed_item = (*param.feed_list)[0];
     param.out->ShareDataWith(feed_item);
     LOG(INFO) << "FEED input " << feed_item << " col " << param.col;
     LOG(INFO) << "FEED output " << *param.out;
diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc
index 7715e588e6f..34ec07a1c67 100644
--- a/paddle/fluid/lite/kernels/host/mul_compute.cc
+++ b/paddle/fluid/lite/kernels/host/mul_compute.cc
@@ -41,18 +41,24 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
   void Run() override {
     auto& param = Param<operators::MulParam>();
 
-    core::dim2 x_shape({product(param.x->dims(), 0, param.x_num_col_dims),
-                        product(param.x->dims(), param.x_num_col_dims,
-                                param.x->dims().size())});
-
-    core::dim2 y_shape({product(param.y->dims(), 0, param.y_num_col_dims),
-                        product(param.y->dims(), param.y_num_col_dims,
-                                param.y->dims().size())});
+    core::dim2 x_shape(
+        {static_cast<int>(
+             param.x->dims().Slice(0, param.x_num_col_dims).production()),
+         static_cast<int>(
+             param.x->dims()
+                 .Slice(param.x_num_col_dims, param.x->dims().size())
+                 .production())});
+    core::dim2 y_shape(
+        {static_cast<int>(
+             param.y->dims().Slice(0, param.y_num_col_dims).production()),
+         static_cast<int>(
+             param.y->dims()
+                 .Slice(param.y_num_col_dims, param.y->dims().size())
+                 .production())});
 
     mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y,  //
                       param.y->data<float>(), y_shape.x, y_shape.y,  //
-                      TensorMutableData<float>(param.output, TARGET(kHost),
-                                               product(param.output->dims())));
+                      param.output->mutable_data<float>());
     LOG(INFO) << "MUL x " << *param.x;
     LOG(INFO) << "MUL W " << *param.y;
     LOG(INFO) << "MUL out " << *param.output;
diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/host/relu_compute.h
index 276535120d7..5a1fd41c172 100644
--- a/paddle/fluid/lite/kernels/host/relu_compute.h
+++ b/paddle/fluid/lite/kernels/host/relu_compute.h
@@ -25,10 +25,9 @@ class ReluCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
  public:
   void Run() override {
     auto& param = Param<operators::ReluParam>();
-    auto n = product(param.input->dims());
+    auto n = param.input->dims().production();
     const float* input = param.input->data<float>();
-    float* output = TensorMutableData<float>(param.output, TARGET(kHost),
-                                             product(param.output->dims()));
+    float* output = param.output->mutable_data<float>();
     for (int i = 0; i < n; i++) {
       output[i] = std::max(0.f, input[i]);
     }
diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc
index de1b59e7e09..78281ba8a61 100644
--- a/paddle/fluid/lite/kernels/host/scale_compute.cc
+++ b/paddle/fluid/lite/kernels/host/scale_compute.cc
@@ -37,10 +37,8 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
   void Run() override {
     auto& param = Param<operators::ScaleParam>();
-    scale_compute(param.x->data<float>(),
-                  TensorMutableData<float>(param.output, TARGET(kHost),
-                                           product(param.output->dims())),
-                  product(param.x->dims()), param.scale, param.bias,
+    scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
+                  param.x->dims().production(), param.scale, param.bias,
                   param.bias_after_scale);
   }
 
diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt
index 7b8f1534cfd..0d0014a0599 100644
--- a/paddle/fluid/lite/model_parser/CMakeLists.txt
+++ b/paddle/fluid/lite/model_parser/CMakeLists.txt
@@ -6,7 +6,7 @@ else()
   cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc)
 endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
 
-set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite
+set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite
   target_wrapper_host compatible_pb_lite
   )
 
diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc
index 59aec582749..54e99cfb1a8 100644
--- a/paddle/fluid/lite/model_parser/model_parser.cc
+++ b/paddle/fluid/lite/model_parser/model_parser.cc
@@ -58,19 +58,20 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
   }
 
   // read tensor
-  std::vector<int64_t> dims;
-  std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
-  tensor->Resize(lite::DDim(&dims[0], dims.size()));
+  std::vector<int64_t> dims_vec;
+  std::copy(desc.dims().begin(), desc.dims().end(),
+            std::back_inserter(dims_vec));
+  lite::DDim dims(dims_vec);
+  tensor->Resize(dims);
   void *buf;
-  size_t size = product(tensor->dims()) * SizeOfType(desc.data_type());
+  size_t size = tensor->dims().production() * SizeOfType(desc.data_type());
   // allocate memory
   switch (static_cast<Type>(desc.data_type())) {
-#define DO(desc, type)                                              \
-  case Type::VarType_Type_##desc:                                   \
-    buf = TensorMutableData<type>(tensor, TensorGetTarget(*tensor), \
-                                  product(tensor->dims()));
+#define DO(desc, type)                  \
+  case Type::VarType_Type_##desc:       \
+    buf = tensor->mutable_data<type>(); \
     break;
-    DO(BOOL, bool);
+    // DO(BOOL, bool);
     DO(FP32, float);
     DO(INT8, int8_t);
     DO(INT16, int16_t);
@@ -198,7 +199,7 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
     auto dims = tensor.dims();
     auto *pb_dims = desc.mutable_dims();
     pb_dims->Resize(static_cast<int>(dims.size()), 0);
-    auto dims_vec = DDimVectorize(dims);
+    auto dims_vec = dims.Vectorize();
     std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin());
     int32_t size = desc.ByteSize();
    os.write(reinterpret_cast<char *>(&size), sizeof(size));
@@ -206,15 +207,15 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
     os.write(out.data(), size);
   }
   {  // the 3rd field, tensor data
-    uint64_t size = tensor.memory_size();
+    uint64_t size = tensor.data_size();
     CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
         << "Index overflow when writing tensor";
 #ifdef LITE_WITH_CUDA
-    if (TensorGetTarget(tensor) == TARGET(kCUDA)) {
+    if (tensor.target() == TARGET(kCUDA)) {
       std::unique_ptr<char[]> tmp_buffer(new char[size]);
       TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<char>(),
-                                    tensor.memory_size(), IoDirection::DtoH);
+                                    tensor.data_size(), IoDirection::DtoH);
       os.write(static_cast<char *>(tmp_buffer.get()),
               static_cast<std::streamsize>(size));
     } else
diff --git a/paddle/fluid/lite/model_parser/model_parser_test.cc b/paddle/fluid/lite/model_parser/model_parser_test.cc
index b5d721809a6..bab10dee409 100644
--- a/paddle/fluid/lite/model_parser/model_parser_test.cc
+++ b/paddle/fluid/lite/model_parser/model_parser_test.cc
@@ -28,7 +28,7 @@ TEST(ModelParser, LoadParam) {
   Scope scope;
   auto* v = scope.Var("xxx");
   LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v);
-  const auto& t = v->Get<Tensor>();
+  const auto& t = v->Get<lite::Tensor>();
   LOG(INFO) << "loaded\n";
   LOG(INFO) << t;
 }
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 8b0c1b236f2..d356b68fb91 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -1,12 +1,13 @@
-cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite)
-cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
-cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
-cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
-cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite)
-cc_library(fetch_op_lite SRCS fetch_op.cc DEPS op_lite)
-cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite)
+set(op_DEPS ${tensor_lite} op_lite op_params_lite)
+cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
+cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
+cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
+cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
+cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
+cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
+cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
-cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
+cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite})
 cc_library(ops_lite DEPS
   fc_op_lite
   relu_op_lite
diff --git a/paddle/fluid/lite/operators/fc_op.cc b/paddle/fluid/lite/operators/fc_op.cc
index 03c91d1c36c..87d7f35c771 100644
--- a/paddle/fluid/lite/operators/fc_op.cc
+++ b/paddle/fluid/lite/operators/fc_op.cc
@@ -42,7 +42,7 @@ bool FcOpLite::CheckShape() const {
   CHECK_GT_OR_FALSE(input_dims.size(),
                     static_cast<size_t>(param_.in_num_col_dims));
 
-  param_.in_mat_dims = lite::flatten_to_2d(input_dims, param_.in_num_col_dims);
+  param_.in_mat_dims = input_dims.Flattern2D(param_.in_num_col_dims);
   // CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
 
   return true;
@@ -58,7 +58,7 @@ bool FcOpLite::InferShape() const {
     output_dims[i] = input_dims[i];
   }
   output_dims.back() = w_dims[1];
-  param_.output->Resize(DDim(&output_dims[0], output_dims.size()));
+  param_.output->Resize(lite::DDim(output_dims));
 
   // share LoD
   // param_.output->set_lod(param_.input->lod());
diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h
index d5379b8344a..c24e5cd25df 100644
--- a/paddle/fluid/lite/operators/fc_op.h
+++ b/paddle/fluid/lite/operators/fc_op.h
@@ -52,11 +52,11 @@ class FcOpLite : public OpLite {
     auto bias = op_desc.Input("Bias").front();
     auto out = op_desc.Output("Out").front();
 
-    param_.input = scope->FindVar(input)->GetMutable<Tensor>();
-    param_.w = scope->FindVar(W)->GetMutable<Tensor>();
-    param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
+    param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
+    param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>();
+    param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
     CHECK(scope->FindVar(out));
-    param_.output = scope->FindVar(out)->GetMutable<Tensor>();
+    param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
     param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
 
     CHECK(kernel_);
diff --git a/paddle/fluid/lite/operators/fc_op_test.cc b/paddle/fluid/lite/operators/fc_op_test.cc
index 54914b5ab19..278f8d8f8f9 100644
--- a/paddle/fluid/lite/operators/fc_op_test.cc
+++ b/paddle/fluid/lite/operators/fc_op_test.cc
@@ -24,10 +24,10 @@ TEST(fc_op_lite, test) {
   LOG(INFO) << "\n" << KernelRegistry::Global().DebugString();
   // prepare variables
   Scope scope;
-  auto* x = scope.Var("x")->GetMutable<Tensor>();
-  auto* w = scope.Var("w")->GetMutable<Tensor>();
-  auto* bias = scope.Var("bias")->GetMutable<Tensor>();
-  auto* output = scope.Var("output")->GetMutable<Tensor>();
+  auto* x = scope.Var("x")->GetMutable<lite::Tensor>();
+  auto* w = scope.Var("w")->GetMutable<lite::Tensor>();
+  auto* bias = scope.Var("bias")->GetMutable<lite::Tensor>();
+  auto* output = scope.Var("output")->GetMutable<lite::Tensor>();
   x->Resize({1, 10, 20});
   w->Resize({20, 20});
   bias->Resize({1, 10});
diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc
index 03ea820f49a..26ca59dd0d5 100644
--- a/paddle/fluid/lite/operators/feed_op.cc
+++ b/paddle/fluid/lite/operators/feed_op.cc
@@ -39,13 +39,13 @@ class FeedOp : public OpLite {
     auto feed_var_name = opdesc.Input("X").front();
     auto* feed_var = scope->FindVar(feed_var_name);
     CHECK(feed_var);
-    auto& feed_tensor_list = feed_var->Get<std::vector<Tensor>>();
+    auto& feed_tensor_list = feed_var->Get<std::vector<lite::Tensor>>();
     param_.feed_list = &feed_tensor_list;
 
     auto out_name = opdesc.Output("Out").front();
     auto* out_var = scope->FindVar(out_name);
     CHECK(out_var);
-    param_.out = out_var->GetMutable<Tensor>();
+    param_.out = out_var->GetMutable<lite::Tensor>();
 
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc
index ea86d6a2f75..337a6ecc9d5 100644
--- a/paddle/fluid/lite/operators/fetch_op.cc
+++ b/paddle/fluid/lite/operators/fetch_op.cc
@@ -37,7 +37,7 @@ class FetchOp : public OpLite {
     auto _x = opdesc.Input("X").front();
     auto* x = scope->FindVar(_x);
     CHECK(x);
-    param_.input = &x->Get<Tensor>();
+    param_.input = &x->Get<lite::Tensor>();
 
     auto _out = opdesc.Output("Out").front();
     auto* out = scope->FindVar(_out);
diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc
index c79f16dbff0..b78ae4578a6 100644
--- a/paddle/fluid/lite/operators/mul_op.cc
+++ b/paddle/fluid/lite/operators/mul_op.cc
@@ -45,7 +45,7 @@ bool MulOpLite::InferShape() const {
   }
   out_dims.back() = y_dims[1];
 
-  param_.output->Resize(DDim(&out_dims[0], out_dims.size()));
+  param_.output->Resize(lite::DDim(out_dims));
 
   // share LoD
   // param_.output->set_lod(param_.input->lod());
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index c3f716906a1..ed37d6cef16 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -25,36 +25,36 @@ namespace lite {
 namespace operators {
 
 struct FeedParam {
-  const std::vector<Tensor>* feed_list{};
-  Tensor* out{};
+  const std::vector<lite::Tensor>* feed_list{};
+  lite::Tensor* out{};
   int col;
 };
 
 struct FetchParam {
-  const Tensor* input{};
-  std::vector<Tensor>* fetch_list{};
+  const lite::Tensor* input{};
+  std::vector<lite::Tensor>* fetch_list{};
   int col;
 };
 
 struct FcParam {
-  Tensor* input{};
-  Tensor* w{};
-  Tensor* bias{};
-  Tensor* output{};
-  DDim in_mat_dims;
+  lite::Tensor* input{};
+  lite::Tensor* w{};
+  lite::Tensor* bias{};
+  lite::Tensor* output{};
+  lite::DDim in_mat_dims;
   int in_num_col_dims{1};
 };
 
 struct ReluParam {
-  Tensor* input{};
-  Tensor* output{};
+  lite::Tensor* input{};
+  lite::Tensor* output{};
 };
 
 // For Mul Op
struct MulParam {
-  Tensor* x{};
-  Tensor* y{};
-  Tensor* output{};
+  lite::Tensor* x{};
+  lite::Tensor* y{};
+  lite::Tensor* output{};
 
   int x_num_col_dims{1};
   int y_num_col_dims{1};
@@ -62,8 +62,8 @@ struct MulParam {
 
 // For Scale Op
 struct ScaleParam {
-  Tensor* x{};
-  Tensor* output{};
+  lite::Tensor* x{};
+  lite::Tensor* output{};
 
   float scale{1.};
   float bias{};
@@ -71,8 +71,8 @@ struct ScaleParam {
 };
 
 struct IoCopyParam {
-  const Tensor* x{};
-  Tensor* y{};
+  const lite::Tensor* x{};
+  lite::Tensor* y{};
 };
 
 using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam,
                         ScaleParam, IoCopyParam>;
diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc
--- a/paddle/fluid/lite/operators/relu_op.cc
+++ b/paddle/fluid/lite/operators/relu_op.cc
-  param_.input = const_cast<Tensor*>(
-      &scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>());
+  param_.input = const_cast<lite::Tensor*>(
+      &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
   param_.output =
-      scope->FindVar(opdesc.Output("Out").front())->GetMutable<Tensor>();
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
   CHECK(param_.input);
   CHECK(param_.output);
   kernel_->SetParam(param_);
diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h
index 40290f1fcef..fdd59502fcc 100644
--- a/paddle/fluid/lite/utils/varient.h
+++ b/paddle/fluid/lite/utils/varient.h
@@ -109,10 +109,21 @@ struct variant {
     type_id = typeid(T).hash_code();
   }
   template <typename T>
-  T& get() {
+  const T& get() const {
     // It is a dynamic_cast-like behaviour
     if (type_id == typeid(T).hash_code())
-      return *reinterpret_cast<T*>(&data);
+      return *reinterpret_cast<const T*>(&data);
     else
       LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
                  << typeid(T).name();
+    return *reinterpret_cast<const T*>(&data);
+  }
+
+  template <typename T>
+  T* get_mutable() {
+    // It is a dynamic_cast-like behaviour
+    if (type_id == typeid(T).hash_code())
+      return reinterpret_cast<T*>(&data);
+    else
+      LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
+                 << typeid(T).name();
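
The heart of this patch is the CRTP layout in tensor.h: DDimBase<DDimT> and TensorBase<TensorT> call into the derived class through a static_cast, so helpers like production(), Slice(), and operator<< are written once and reused by both the mobile (DDimLite/TensorLite) and server (DDimHvy/TensorHvy) implementations without virtual dispatch, and compatible_tensor.h can switch the Tensor/DDim aliases with a single #ifdef. The following standalone sketch illustrates the idiom; the names mirror the patch, but the code is deliberately simplified and assumes nothing beyond the standard library.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Shared base: production() is written once here and dispatches to the
    // derived class statically via static_cast (CRTP), not via virtuals.
    template <typename DDimT>
    class DDimBase {
     public:
      using value_type = int64_t;

      value_type production() const {
        value_type res = 1;
        for (size_t i = 0; i < self()->size(); i++) res *= (*self())[i];
        return res;
      }

     private:
      const DDimT *self() const { return static_cast<const DDimT *>(this); }
    };

    // Mobile-style implementation backed by a plain std::vector, as in
    // lite_tensor.h; a server-side DDimHvy wrapping framework::DDim would
    // plug into the same base unchanged.
    class DDimLite : public DDimBase<DDimLite> {
     public:
      explicit DDimLite(const std::vector<int64_t> &x) : data_(x) {}
      value_type operator[](int offset) const { return data_[offset]; }
      size_t size() const { return data_.size(); }

     private:
      std::vector<int64_t> data_;
    };

    int main() {
      DDimLite dims({3, 4, 5});
      std::cout << dims.production() << "\n";  // prints 60
      return 0;
    }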
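The varient.h change splits read and write access: get<T>() becomes const and returns a const reference, while the new get_mutable<T>() returns a mutable pointer, which is what Variable::GetMutable() now uses. Below is a minimal sketch of that split under the simplifying assumption of a fixed two-type variant; the real varient.h is variadic, but the typeid hash-code check is the same idea.

    #include <cassert>
    #include <cstddef>
    #include <typeinfo>

    // Read/write split as in the patched varient.h: get<T>() is const-only,
    // get_mutable<T>() hands out a mutable pointer after the same type check.
    class IntFloatVariant {
     public:
      template <typename T>
      void set(const T &v) {
        *reinterpret_cast<T *>(&data_) = v;
        type_id_ = typeid(T).hash_code();
      }

      template <typename T>
      const T &get() const {  // read-only access
        assert(type_id_ == typeid(T).hash_code() && "unmatched type get");
        return *reinterpret_cast<const T *>(&data_);
      }

      template <typename T>
      T *get_mutable() {  // mutable access, as Variable::GetMutable needs
        assert(type_id_ == typeid(T).hash_code() && "unmatched type get");
        return reinterpret_cast<T *>(&data_);
      }

     private:
      union Storage {
        int i;
        float f;
      } data_{};
      std::size_t type_id_{};
    };

    int main() {
      IntFloatVariant v;
      v.set(3);                    // now holds an int
      *v.get_mutable<int>() += 1;  // mutate in place
      assert(v.get<int>() == 4);   // read through the const accessor
      return 0;
    }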