Commit 706e83af authored by superjomn

make an adapter for TensorLite and framework::LoDTensor and DDim

Parent e88d6418
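For readers skimming the diff: the adapter works by moving the shared DDim/Tensor logic into CRTP base classes (DDimBase, TensorBase) and deriving a light mobile implementation and a LoDTensor-backed server implementation from them. A minimal, self-contained sketch of the DDim side, with the classes trimmed to what the example needs (the real ones in lite/core/tensor.h carry more methods):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// CRTP base: helpers such as production() are written once and statically
// dispatch to the concrete DDim implementation via self().
template <typename DDimT>
class DDimBase {
 public:
  using value_type = int64_t;
  value_type production() const {
    value_type res = 1;
    for (size_t i = 0; i < self()->size(); ++i) res *= (*self())[i];
    return res;
  }

 private:
  const DDimT *self() const { return static_cast<const DDimT *>(this); }
};

// Mobile-side implementation: dimensions stored in a plain std::vector.
class DDimLite : public DDimBase<DDimLite> {
 public:
  explicit DDimLite(const std::vector<int64_t> &x) : data_(x) {}
  int64_t operator[](int i) const { return data_[i]; }
  size_t size() const { return data_.size(); }

 private:
  std::vector<int64_t> data_;
};

int main() {
  DDimLite dims({2, 3, 4});
  std::cout << dims.production() << "\n";  // prints 24
}
```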
...@@ -45,17 +45,17 @@ class LightPredictor { ...@@ -45,17 +45,17 @@ class LightPredictor {
void SaveModel(const std::string& dir); void SaveModel(const std::string& dir);
// Get offset-th col of feed. // Get offset-th col of feed.
Tensor* GetInput(size_t offset) { lite::Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed"); auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope"; CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>(); auto* feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
if (offset >= feed_list->size()) { if (offset >= feed_list->size()) {
feed_list->resize(offset + 1); feed_list->resize(offset + 1);
} }
return &feed_list->at(offset); return &feed_list->at(offset);
} }
const Tensor* GetOutput(size_t offset) { const lite::Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope"; CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_executor.h" #include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", "");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -36,24 +39,22 @@ TEST(CXXApi, test) { ...@@ -36,24 +39,22 @@ TEST(CXXApi, test) {
}); });
#endif #endif
predictor.Build("/home/chunwei/project/models/model2", predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); valid_places);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100}); input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
auto* data = TensorMutableData<float>(input_tensor, TARGET(kHost), auto* data = input_tensor->mutable_data<float>();
product(input_tensor->dims()));
for (int i = 0; i < 100 * 100; i++) { for (int i = 0; i < 100 * 100; i++) {
data[i] = i; data[i] = i;
} }
LOG(INFO) << "input " << input_tensor;
LOG(INFO) << "input " << *input_tensor; LOG(INFO) << "input " << *input_tensor;
predictor.Run(); predictor.Run();
auto* out = predictor.GetOutput(0); auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->memory_size(); LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0]; LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1]; LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims(); LOG(INFO) << "dims " << out->dims();
...@@ -63,8 +64,8 @@ TEST(CXXApi, test) { ...@@ -63,8 +64,8 @@ TEST(CXXApi, test) {
TEST(CXXApi, save_model) { TEST(CXXApi, save_model) {
lite::LightPredictor predictor; lite::LightPredictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}}); std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
predictor.Build("/home/chunwei/project/models/model2", predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); valid_places);
predictor.SaveModel("./optimized_model"); predictor.SaveModel("./optimized_model");
} }
......
...@@ -41,20 +41,21 @@ class LightPredictor { ...@@ -41,20 +41,21 @@ class LightPredictor {
void Run() { program_->Run(); } void Run() { program_->Run(); }
// Get offset-th col of feed. // Get offset-th col of feed.
Tensor* GetInput(size_t offset) { TensorBase* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed"); auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope"; CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>(); auto* feed_list = _feed_list->GetMutable<std::vector<TensorBase>>();
if (offset >= feed_list->size()) { if (offset >= feed_list->size()) {
feed_list->resize(offset + 1); feed_list->resize(offset + 1);
} }
return &feed_list->at(offset); return &feed_list->at(offset);
} }
const Tensor* GetOutput(size_t offset) { const TensorBase* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope"; CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); auto& fetch_list =
*_fetch_list->GetMutable<std::vector<lite::TensorBase>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset); return &fetch_list.at(offset);
} }
......
...@@ -2,22 +2,26 @@ cc_library(lite_gtest_main SRCS lite_gtest_main.cc) ...@@ -2,22 +2,26 @@ cc_library(lite_gtest_main SRCS lite_gtest_main.cc)
cc_library(memory_lite SRCS memory.cc) cc_library(memory_lite SRCS memory.cc)
cc_library(target_wrapper_lite SRCS target_wrapper.cc) cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(tensor_lite SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite) set(tensor_lite lite_tensor)
else() else()
cc_library(tensor_lite DEPS lod_tensor) set(tensor_lite hvy_tensor)
endif() endif()
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc) cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc) cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc) cc_library(scope_lite SRCS scope.cc)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite)
cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite ${tensor_lite} op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework #TODO(Superjomn) remove these dependencies from original framework
) )
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite) cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS tensor_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite})
cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite scope_lite op_registry_lite proto_desc op_lite
ops_lite ops_lite
......
...@@ -14,83 +14,24 @@ ...@@ -14,83 +14,24 @@
#pragma once #pragma once
#include <vector> #include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/lite_tensor.h" #include "paddle/fluid/lite/core/lite_tensor.h"
#else #else
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/lite/core/hvy_tensor.h"
#endif #endif
namespace paddle { namespace paddle {
namespace lite { namespace lite {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using Tensor = details::Tensor; using DDim = lite::DDimLite;
using DDim = details::DDim; using Tensor = lite::TensorLite;
#else
using Tensor = framework::LoDTensor;
using DDim = framework::DDim;
static TargetType TensorGetTarget(const Tensor &x) {
if (platform::is_gpu_place(x.place())) {
return TARGET(kCUDA);
} else if (platform::is_cpu_place(x.place())) {
return TARGET(kX86);
}
return TARGET(kUnk);
}
template <typename T>
T *TensorMutableData(Tensor *x, TargetType target, size_t size) {
if (target == TARGET(kX86) || target == TARGET(kHost)) {
return x->mutable_data<T>(platform::CPUPlace(), memory::Allocator::kDefault,
size);
} else if (target == TARGET(kCUDA)) {
return x->mutable_data<T>(platform::CUDAPlace(),
memory::Allocator::kDefault, size);
}
LOG(FATAL) << "not valid target " << TargetToStr(target);
return nullptr;
}
#endif
static int product(const DDim &dims, int start, int end) {
int res = 1;
for (int i = start; i < end; i++) {
res *= dims[i];
}
return res;
}
static DDim SliceDims(const DDim &dims, int begin, int end) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
return DDim(dims[0] + begin, dims.begin() + end - 1);
#else #else
auto vec = framework::vectorize(dims); using DDim = lite::DDimHvy;
return DDim(&vec[0] + begin, end - begin); using Tensor = lite::TensorHvy;
#endif #endif
}
static std::vector<int64_t> DDimVectorize(const DDim &x) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
return x;
#else
return framework::vectorize(x);
#endif
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
static int product(const DDim &dims) {
return std::accumulate(dims.begin(), dims.end(), 1,
[](int a, int b) { return a * b; });
}
#endif
static DDim flatten_to_2d(const DDim &dims, int col) {
return DDim({product(SliceDims(dims, 0, col)),
product(SliceDims(dims, col, dims.size()))});
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
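The net effect of the rewritten header above is that downstream code spells lite::Tensor and lite::DDim once, and the LITE_WITH_LIGHT_WEIGHT_FRAMEWORK flag decides which implementation backs them. A self-contained sketch of the same pattern with toy stand-ins (the real aliases point at TensorLite/DDimLite and TensorHvy/DDimHvy):

```cpp
#include <iostream>

// Toy stand-ins for the two backends selected by the build flag.
struct TensorLiteToy {
  const char *backend() const { return "lite Buffer"; }
};
struct TensorHvyToy {
  const char *backend() const { return "framework::LoDTensor"; }
};

#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using Tensor = TensorLiteToy;  // mobile build
#else
using Tensor = TensorHvyToy;   // server build
#endif

int main() {
  Tensor t;
  // The same call site compiles unchanged against either backend.
  std::cout << "tensor backed by: " << t.backend() << "\n";
}
```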
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/hvy_tensor.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines the heavy tensor (alias of the LoDTensor in the server
* framework). We derive it from the TensorLite interface, so the lite framework
* can share much code between the server side and mobile side.
*/
#pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/lite/core/tensor.h"
namespace paddle {
namespace lite {
class DDimHvy : public DDimBase<DDimHvy> {
public:
DDimHvy() = default;
explicit DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() {
ConstructFrom(x);
}
explicit DDimHvy(const framework::DDim& x) : data_(x) {}
void ConstructFrom(const std::vector<value_type>& xs) {
data_ = framework::DDim(xs.data(), xs.size());
}
value_type operator[](int offset) const { return data_[offset]; }
std::vector<int64_t> Vectorize() const { return framework::vectorize(data_); }
const framework::DDim& data() const { return data_; }
size_t size() const { return data_.size(); }
bool empty() const { return data_.size() == 0; }
private:
framework::DDim data_;
};
class TensorHvy : public TensorBase<TensorHvy> {
public:
using DDimT = DDimHvy;
using LoDT = framework::LoD;
TargetType target() const {
if (platform::is_gpu_place(data_.place())) {
return TARGET(kCUDA);
} else if (platform::is_cpu_place(data_.place())) {
return TARGET(kX86);
}
LOG(FATAL) << "Unknown place";
return TARGET(kUnk);
}
template <typename T>
T* mutable_data() {
return data_.mutable_data<T>(data_.dims(), platform::CPUPlace());
}
template <typename T>
T* mutable_data(TargetType target) {
if (target == TARGET(kCUDA)) {
return data_.mutable_data<T>(data_.dims(), platform::CUDAPlace());
}
return data_.mutable_data<T>(data_.dims(), platform::CPUPlace());
}
template <typename T>
const T* data() const {
return data_.data<T>();
}
template <typename DimT>
void Resize(const DimT& dims) {
LOG(INFO) << "dims.size " << dims.size();
data_.Resize(framework::make_ddim(dims.Vectorize()));
}
void ShareDataWith(const TensorHvy& other) {
data_.ShareDataWith(other.data_);
}
void CopyDataFrom(const TensorHvy& other) {
data_.ShareDataWith(other.data_);
}
DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); }
const framework::LoD& lod() const { return data_.lod(); }
framework::LoD* mutable_lod() { return data_.mutable_lod(); }
private:
framework::LoDTensor data_;
};
} // namespace lite
} // namespace paddle
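The TensorHvy class above is purely a forwarding wrapper: every method delegates to a member framework::LoDTensor so the lite-facing interface stays identical to TensorLite's. A reduced, runnable illustration of that wrapping pattern, with std::vector<float> standing in for the LoDTensor member (the names here are invented for the sketch):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// CRTP interface trimmed to the one method this example exercises.
template <typename TensorT>
class TensorBase {
 public:
  size_t data_size() const { return const_self()->data_size(); }

 private:
  const TensorT *const_self() const {
    return static_cast<const TensorT *>(this);
  }
};

// "Heavy" tensor: owns no storage logic of its own and forwards to an
// existing container, the way TensorHvy forwards to framework::LoDTensor.
class ToyTensorHvy : public TensorBase<ToyTensorHvy> {
 public:
  void Resize(size_t n) { data_.resize(n); }
  float *mutable_data() { return data_.data(); }
  size_t data_size() const { return data_.size(); }

 private:
  std::vector<float> data_;  // stand-in for framework::LoDTensor
};

int main() {
  ToyTensorHvy t;
  t.Resize(4);
  float *d = t.mutable_data();
  for (size_t i = 0; i < t.data_size(); ++i) d[i] = static_cast<float>(i);
  std::cout << d[3] << "\n";  // prints 3
}
```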
...@@ -52,7 +52,7 @@ class KernelBase { ...@@ -52,7 +52,7 @@ class KernelBase {
} }
template <typename P> template <typename P>
P& Param() const { P& Param() const {
return param_.get<P>(); return *param_.get_mutable<P>();
} }
// This is used in the kernels that takes 'kAny' places and inference the // This is used in the kernels that takes 'kAny' places and inference the
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
google::ParseCommandLineFlags(&argc, &argv, false);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -17,31 +17,7 @@ ...@@ -17,31 +17,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::ostream &operator<<(std::ostream &os, const DDim &dims) { void TensorLite::ShareDataWith(const TensorLite &other) {
if (dims.empty()) {
os << "[]";
return os;
}
os << "[";
for (size_t i = 0; i < dims.size() - 1; i++) {
os << dims[i] << " ";
}
os << dims.back() << "]";
return os;
}
std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
os << "Tensor:" << '\n';
os << "dim: " << tensor.dims() << '\n';
for (int i = 0; i < product(tensor.dims()); i++) {
os << tensor.data<float>()[i] << " ";
}
os << "\n";
return os;
}
void Tensor::ShareDataWith(const Tensor &other) {
buffer_ = other.buffer_; buffer_ = other.buffer_;
dims_ = other.dims_; dims_ = other.dims_;
target_ = other.target_; target_ = other.target_;
...@@ -49,17 +25,17 @@ void Tensor::ShareDataWith(const Tensor &other) { ...@@ -49,17 +25,17 @@ void Tensor::ShareDataWith(const Tensor &other) {
memory_size_ = other.memory_size_; memory_size_ = other.memory_size_;
} }
void *Tensor::mutable_data(size_t memory_size) { void *TensorLite::mutable_data(size_t memory_size) {
buffer_->ResetLazy(target_, memory_size); buffer_->ResetLazy(target_, memory_size);
return buffer_->data(); return buffer_->data();
} }
void *Tensor::mutable_data(TargetType target, size_t memory_size) { void *TensorLite::mutable_data(TargetType target, size_t memory_size) {
target_ = target; target_ = target;
return mutable_data(memory_size); return mutable_data(memory_size);
} }
void Tensor::CopyDataFrom(const Tensor &other) { void TensorLite::CopyDataFrom(const TensorLite &other) {
dims_ = other.dims_; dims_ = other.dims_;
target_ = other.target_; target_ = other.target_;
lod_ = other.lod_; lod_ = other.lod_;
......
...@@ -20,28 +20,49 @@ ...@@ -20,28 +20,49 @@
#include "paddle/fluid/lite/core/memory.h" #include "paddle/fluid/lite/core/memory.h"
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/tensor.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace details {
using DDim = std::vector<int64_t>; class DDimLite : public DDimBase<DDimLite> {
public:
DDimLite() = default;
DDimLite(const std::vector<value_type> &x) : DDimBase<DDimLite>() {
ConstructFrom(x);
}
void ConstructFrom(const std::vector<value_type> &x) { data_ = x; }
value_type operator[](int offset) const { return data_[offset]; }
std::vector<int64_t> Vectorize() { return data_; }
size_t size() const { return data_.size(); }
bool empty() const { return data_.empty(); }
const std::vector<value_type> &data() const { return data_; }
private:
std::vector<value_type> data_;
};
using LoD = std::vector<std::vector<size_t>>; using LoD = std::vector<std::vector<size_t>>;
// A light-weight tensor implementation. // A light-weight tensor implementation.
class Tensor { class TensorLite : public TensorBase<TensorLite> {
public: public:
Tensor() : buffer_(std::make_shared<Buffer>()) {} using DDimT = DDimLite;
TensorLite() : buffer_(std::make_shared<Buffer>()) {}
template <typename T> template <typename T>
const T *data() const { const T *data() const {
return static_cast<const T *>(buffer_->data()); return static_cast<const T *>(buffer_->data());
} }
void Resize(const DDim &ddim) { dims_ = ddim; } void Resize(const DDimLite &ddim) { dims_ = ddim; }
const DDim &dims() const { return dims_; } const DDimLite &dims() const { return dims_; }
const LoD &lod() const { return lod_; } const LoD &lod() const { return lod_; }
LoD *mutable_lod() { return &lod_; } LoD *mutable_lod() { return &lod_; }
...@@ -58,38 +79,34 @@ class Tensor { ...@@ -58,38 +79,34 @@ class Tensor {
bool IsInitialized() const { return buffer_->data(); } bool IsInitialized() const { return buffer_->data(); }
// Other share data to this. // Other share data to this.
void ShareDataWith(const Tensor &other); void ShareDataWith(const TensorLite &other);
void CopyDataFrom(const Tensor &other); void CopyDataFrom(const TensorLite &other);
TargetType target() const { return target_; } TargetType target() const { return target_; }
private: private:
TargetType target_{TargetType::kHost}; TargetType target_{TargetType::kHost};
DDim dims_; DDimLite dims_;
std::shared_ptr<Buffer> buffer_; std::shared_ptr<Buffer> buffer_;
LoD lod_; LoD lod_;
size_t memory_size_{}; size_t memory_size_{};
}; };
template <typename T> template <typename T>
T *Tensor::mutable_data() { T *TensorLite::mutable_data() {
memory_size_ = product(dims_) * sizeof(T); memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target_, memory_size_); buffer_->ResetLazy(target_, memory_size_);
return static_cast<T *>(buffer_->data()); return static_cast<T *>(buffer_->data());
} }
template <typename T> template <typename T>
T *Tensor::mutable_data(TargetType target) { T *TensorLite::mutable_data(TargetType target) {
target_ = target; target_ = target;
memory_size_ = product(dims_) * sizeof(T); memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target, memory_size()); buffer_->ResetLazy(target, memory_size());
return static_cast<T *>(buffer_->data()); return static_cast<T *>(buffer_->data());
} }
std::ostream &operator<<(std::ostream &os, const DDim &dims);
std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
} // namespace details
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
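mutable_data<T>() above computes memory_size_ from dims_.production() and then calls buffer_->ResetLazy(...). A hedged sketch of what a lazily-resetting buffer of that shape can look like; it assumes ResetLazy only reallocates when the request outgrows the current allocation, which is the usual reading of the name but is not shown in this diff:

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Assumed behaviour: keep the existing allocation unless the caller asks
// for more bytes than are currently held.
class ToyBuffer {
 public:
  void ResetLazy(size_t size) {
    if (size > capacity_) {
      std::free(data_);
      data_ = std::malloc(size);
      capacity_ = size;
    }
  }
  void *data() const { return data_; }
  ~ToyBuffer() { std::free(data_); }

 private:
  void *data_{nullptr};
  size_t capacity_{0};
};

int main() {
  ToyBuffer buf;
  buf.ResetLazy(100 * sizeof(float));  // allocates 400 bytes
  void *first = buf.data();
  buf.ResetLazy(10 * sizeof(float));   // smaller request: no reallocation
  std::printf("same allocation: %d\n", first == buf.data());
}
```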
...@@ -39,11 +39,11 @@ TEST(executor, test) { ...@@ -39,11 +39,11 @@ TEST(executor, test) {
op_desc.SetAttr("in_num_col_dims", static_cast<int>(1)); op_desc.SetAttr("in_num_col_dims", static_cast<int>(1));
program.Flush(); program.Flush();
auto* w = scope->Var("w")->GetMutable<Tensor>(); auto* w = scope->Var("w")->GetMutable<TensorBase>();
w->Resize({20, 20}); w->Resize({20, 20});
auto* x = scope->Var("x")->GetMutable<Tensor>(); auto* x = scope->Var("x")->GetMutable<TensorBase>();
x->Resize({1, 10, 20}); x->Resize({1, 10, 20});
auto* bias = scope->Var("bias")->GetMutable<Tensor>(); auto* bias = scope->Var("bias")->GetMutable<TensorBase>();
bias->Resize({1, 20}); bias->Resize({1, 20});
bias->mutable_data<float>(); bias->mutable_data<float>();
......
...@@ -81,8 +81,8 @@ struct Program { ...@@ -81,8 +81,8 @@ struct Program {
CHECK(!exec_scope) << "Duplicate PrepareWorkspace found"; CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
exec_scope = &scope->NewScope(); exec_scope = &scope->NewScope();
// Create Feed and Fetch var. // Create Feed and Fetch var.
scope->Var("feed")->GetMutable<std::vector<Tensor>>(); scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope->Var("fetch")->GetMutable<std::vector<Tensor>>(); scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars.push_back("feed"); tmp_vars.push_back("feed");
tmp_vars.push_back("fetch"); tmp_vars.push_back("fetch");
......
...@@ -28,9 +28,9 @@ Program FakeProgram() { ...@@ -28,9 +28,9 @@ Program FakeProgram() {
std::string w1 = "w" + std::to_string(id); std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id); std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id); std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<Tensor>(); auto w1v = program.scope->Var(w1)->GetMutable<TensorBase>();
auto b1v = program.scope->Var(b1)->GetMutable<Tensor>(); auto b1v = program.scope->Var(b1)->GetMutable<TensorBase>();
auto out1v = program.scope->Var(out1)->GetMutable<Tensor>(); auto out1v = program.scope->Var(out1)->GetMutable<TensorBase>();
lite::OpDesc desc; lite::OpDesc desc;
desc.SetInput("Input", {x}); desc.SetInput("Input", {x});
...@@ -60,7 +60,7 @@ Program FakeProgram() { ...@@ -60,7 +60,7 @@ Program FakeProgram() {
std::string x = "x"; std::string x = "x";
program.tmp_vars.push_back(x); program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<Tensor>(); auto* xv = program.scope->Var(x)->GetMutable<TensorBase>();
xv->Resize({100, 100}); xv->Resize({100, 100});
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
...@@ -81,7 +81,7 @@ class ProgramFaker { ...@@ -81,7 +81,7 @@ class ProgramFaker {
void CreateVars(lite::Scope* scope) { void CreateVars(lite::Scope* scope) {
for (auto& var : tmp_vars_) { for (auto& var : tmp_vars_) {
auto* x = scope->Var(var); auto* x = scope->Var(var);
x->GetMutable<lite::Tensor>(); x->GetMutable<lite::TensorBase>();
} }
for (auto& x : tmp_vars_) { for (auto& x : tmp_vars_) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
* This file defines the general interface for DDim and Tensor, which is used in
* server and mobile framework, to make the framework on the two devices share
* the same code, we clear up the methods and make the different implementations
* looks the same.
*/
#include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
/*
* This class defines the basic interfaces of the DDims for server and mobile.
* For the DDims's implementation is too tedious, we add a simple implementation
* for mobile, and use this interface to share the framework both for mobile and
* server.
*
* The derived should implement following interfaces:
* ConstructFrom
* operator[]
* Vectorize
* size
*/
template <typename DDimT>
class DDimBase {
public:
using value_type = int64_t;
DDimBase() = default;
explicit DDimBase(const std::vector<int64_t> &x) { self()->ConstructFrom(x); }
value_type operator[](int offset) const { return (*self())[offset]; }
std::vector<int64_t> Vectorize() { return self()->Vectorize(); }
size_t size() const { return const_self()->size(); }
bool empty() const { return const_self()->empty(); }
value_type production() const {
value_type res = 1;
for (int i = 0; i < const_self()->size(); i++) {
res *= (*const_self())[i];
}
return res;
}
DDimT Slice(int start, int end) const {
std::vector<value_type> vec;
for (int i = start; i < end; i++) {
vec.push_back((*const_self())[i]);
}
return DDimT(vec);
}
DDimT Flattern2D(int col) const {
return DDimT(std::vector<value_type>(
{Slice(0, col).production(), Slice(col, size()).production()}));
}
friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) {
if (dims.empty()) {
os << "[]";
return os;
}
os << "[";
for (size_t i = 0; i < dims.size() - 1; i++) {
os << dims[i] << " ";
}
if (!dims.empty()) os << dims[dims.size() - 1];
os << "]";
return os;
}
private:
DDimT *self() { return static_cast<DDimT *>(this); }
const DDimT *const_self() const { return static_cast<const DDimT *>(this); }
};
/*
* This class defines the basic interfaces of the tensors implemented for
* server and mobile. It use the CRTR technology to accelerate the runtime
* performance.
*/
template <typename TensorT>
class TensorBase {
public:
TensorBase() = default;
TargetType target() const { return self()->target(); }
template <typename T>
T *mutable_data() {
return self()->template mutable_data<T>();
}
template <typename T>
T *mutable_data(TargetType target) {
return self()->template mutable_data<T>(target);
}
template <typename T>
const T *data() {
return self()->template data<T>();
}
template <typename DimT>
void Resize(const DimT &dims) {
self()->Resize(dims);
}
template <typename DDimT>
DDimT dims() {
return self()->dims();
}
template <typename LoDT>
const LoDT &lod() const {
return const_self()->lod();
}
template <typename LoDT>
LoDT *mutable_lod() {
return self()->mutable_lod();
}
template <typename T>
const T &data() const {
return const_self()->data();
}
size_t data_size() const { return const_self()->dims().production(); }
void ShareDataWith(const TensorBase &other) { self()->ShareDataWith(other); }
void CopyDataFrom(const TensorBase &other) { self()->CopyDataFrom(other); }
friend std::ostream &operator<<(std::ostream &os, const TensorT &tensor) {
os << "Tensor:" << '\n';
os << "dim: " << tensor.dims() << '\n';
for (int i = 0; i < tensor.dims().production(); i++) {
os << tensor.template data<float>()[i] << " ";
}
os << "\n";
return os;
}
private:
TensorT *self() { return static_cast<TensorT *>(this); }
const TensorT *const_self() const {
return static_cast<const TensorT *>(this);
}
};
} // namespace lite
} // namespace paddle
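To make the Slice/production helpers concrete: the fc kernel later in this commit flattens an input of shape {1, 10, 20} at in_num_col_dims = 1, giving a 1 x 200 matrix. A standalone check of that arithmetic, re-implementing the two helpers on a plain vector (the real ones live in DDimBase above):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Dims = std::vector<int64_t>;

// Product of dims[start, end), mirroring DDimBase::Slice(start, end).production().
int64_t Production(const Dims &d, size_t start, size_t end) {
  return std::accumulate(d.begin() + start, d.begin() + end, int64_t{1},
                         std::multiplies<int64_t>());
}

// Mirror of DDimBase::Flattern2D(col): multiply the dims on each side of col.
Dims Flatten2D(const Dims &d, size_t col) {
  return {Production(d, 0, col), Production(d, col, d.size())};
}

int main() {
  Dims x{1, 10, 20};
  Dims mat = Flatten2D(x, 1);
  std::cout << mat[0] << " x " << mat[1] << "\n";  // prints 1 x 200
}
```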
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace lite { namespace lite {
TEST(tensor, test) { TEST(tensor, test) {
Tensor tensor; TensorBase tensor;
tensor.Resize({1, 8}); tensor.Resize({1, 8});
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace lite { namespace lite {
TEST(TypeSystem, test) { TEST(TypeSystem, test) {
ASSERT_TRUE(TypeSystem::Global().Contains<lite::Tensor>()); ASSERT_TRUE(TypeSystem::Global().Contains<lite::TensorBase>());
} }
TEST(TypeSystem, register_new) { TEST(TypeSystem, register_new) {
......
...@@ -29,7 +29,7 @@ class Variable { ...@@ -29,7 +29,7 @@ class Variable {
template <typename T> template <typename T>
T* GetMutable() { T* GetMutable() {
if (!blob_.is<T>()) blob_.set<T>(); if (!blob_.is<T>()) blob_.set<T>();
return &blob_.get<T>(); return blob_.get_mutable<T>();
} }
template <typename T> template <typename T>
...@@ -38,7 +38,7 @@ class Variable { ...@@ -38,7 +38,7 @@ class Variable {
} }
private: private:
variant<int, float, std::string, Tensor> blob_; variant<int, float, std::string, lite::Tensor> blob_;
}; };
} // namespace lite } // namespace lite
......
set(lite_kernel_deps type_system kernel_lite op_registry_lite) set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite})
add_subdirectory(host) add_subdirectory(host)
add_subdirectory(arm) add_subdirectory(arm)
add_subdirectory(cuda) add_subdirectory(cuda)
...@@ -2,7 +2,7 @@ if(NOT LITE_WITH_CUDA) ...@@ -2,7 +2,7 @@ if(NOT LITE_WITH_CUDA)
return() return()
endif() endif()
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${tensor_lite})
cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite) cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite})
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite) nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite)
...@@ -46,12 +46,11 @@ class IoCopyHostToCudaCompute ...@@ -46,12 +46,11 @@ class IoCopyHostToCudaCompute
public: public:
void Run() override { void Run() override {
auto& param = Param<operators::IoCopyParam>(); auto& param = Param<operators::IoCopyParam>();
CHECK(TensorGetTarget(*param.x) == TARGET(kHost) || CHECK(param.x->target() == TARGET(kHost) ||
TensorGetTarget(*param.x) == TARGET(kX86)); param.x->target() == TARGET(kX86));
LOG(INFO) << "copy size " << param.x->memory_size(); LOG(INFO) << "copy size " << param.x->data_size();
auto* data = TensorMutableData<int8_t>(param.y, TARGET(kCUDA), auto* data = param.y->mutable_data<int8_t>(TARGET(kCUDA));
param.x->memory_size()); CopyFromHostSync(data, param.x->data<int8_t>(), param.x->data_size());
CopyFromHostSync(data, param.x->data<int8_t>(), param.x->memory_size());
} }
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override { std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
...@@ -82,11 +81,10 @@ class IoCopyCudaToHostCompute ...@@ -82,11 +81,10 @@ class IoCopyCudaToHostCompute
public: public:
void Run() override { void Run() override {
auto& param = Param<operators::IoCopyParam>(); auto& param = Param<operators::IoCopyParam>();
CHECK(TensorGetTarget(*param.x) == TARGET(kCUDA)); CHECK(param.x->target() == TARGET(kCUDA));
auto* data = TensorMutableData<int8_t>(param.y, TARGET(kHost), auto* data = param.y->mutable_data<float>();
param.x->memory_size()); LOG(INFO) << "copy size " << param.x->data_size();
LOG(INFO) << "copy size " << param.x->memory_size(); CopyToHostSync(data, param.x->data<void>(), param.x->data_size());
CopyToHostSync(data, param.x->data<void>(), param.x->memory_size());
} }
std::string doc() const override { return "Copy IO from CUDA to HOST"; } std::string doc() const override { return "Copy IO from CUDA to HOST"; }
......
...@@ -51,9 +51,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { ...@@ -51,9 +51,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
*/ */
const auto& param = Param<operators::MulParam>(); const auto& param = Param<operators::MulParam>();
TensorMutableData<float>(param.output, TARGET(kCUDA), param.output->mutable_data<float>(TARGET(kCUDA));
product(param.output->dims())); LOG(INFO) << "mul output memory size " << param.output->data_size();
LOG(INFO) << "mul output memory size " << param.output->memory_size();
// mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out); // mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
} }
......
...@@ -29,16 +29,17 @@ void FcCompute::Run() { ...@@ -29,16 +29,17 @@ void FcCompute::Run() {
CHECK_GE(param.input->dims().size(), 2UL); CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL); CHECK_EQ(param.output->dims().size(), 2UL);
fc_compute_eigen(param.input->data<float>(), // x fc_compute_eigen(
product(param.input->dims(), 0, param.in_num_col_dims), param.input->data<float>(), // x
product(param.input->dims(), param.in_num_col_dims, param.input->dims().Slice(0, param.in_num_col_dims).production(),
param.input->dims().size()), param.input->dims()
param.w->data<float>(), // w .Slice(param.in_num_col_dims, param.input->dims().size())
param.w->dims()[1], // w_w .production(),
param.w->dims()[0], // w_h param.w->data<float>(), // w
param.bias->data<float>(), // b param.w->dims()[1], // w_w
TensorMutableData<float>(param.output, TARGET(kHost), param.w->dims()[0], // w_h
product(param.output->dims()))); param.bias->data<float>(), // b
param.output->mutable_data<float>());
} }
// TargetType FcCompute::target() const { return TARGET(kHost); } // TargetType FcCompute::target() const { return TARGET(kHost); }
......
...@@ -23,7 +23,7 @@ namespace kernels { ...@@ -23,7 +23,7 @@ namespace kernels {
namespace host { namespace host {
TEST(fc_compute_naive, test) { TEST(fc_compute_naive, test) {
Tensor x, w, b, out, out1; TensorBase x, w, b, out, out1;
const int batch_size = 2; const int batch_size = 2;
x.Resize({batch_size, 3}); x.Resize({batch_size, 3});
w.Resize({4, 3}); w.Resize({4, 3});
...@@ -79,10 +79,10 @@ TEST(fc_host, compute) { ...@@ -79,10 +79,10 @@ TEST(fc_host, compute) {
FcCompute fc; FcCompute fc;
operators::FcParam param; operators::FcParam param;
Tensor x; TensorBase x;
Tensor w; TensorBase w;
Tensor bias; TensorBase bias;
Tensor output; TensorBase output;
x.Resize({1, 10, 20}); x.Resize({1, 10, 20});
w.Resize({20, 20}); w.Resize({20, 20});
......
...@@ -27,7 +27,9 @@ class FeedCompute ...@@ -27,7 +27,9 @@ class FeedCompute
void Run() override { void Run() override {
auto &param = Param<operators::FeedParam>(); auto &param = Param<operators::FeedParam>();
const Tensor &feed_item = param.feed_list->at(param.col); LOG(INFO) << "feed_list.size: " << param.feed_list->size();
LOG(INFO) << "col " << param.col;
const lite::Tensor &feed_item = (*param.feed_list)[0];
param.out->ShareDataWith(feed_item); param.out->ShareDataWith(feed_item);
LOG(INFO) << "FEED input " << feed_item << " col " << param.col; LOG(INFO) << "FEED input " << feed_item << " col " << param.col;
LOG(INFO) << "FEED output " << *param.out; LOG(INFO) << "FEED output " << *param.out;
......
...@@ -41,18 +41,24 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -41,18 +41,24 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = Param<operators::MulParam>(); auto& param = Param<operators::MulParam>();
core::dim2 x_shape({product(param.x->dims(), 0, param.x_num_col_dims), core::dim2 x_shape(
product(param.x->dims(), param.x_num_col_dims, {static_cast<int>(
param.x->dims().size())}); param.x->dims().Slice(0, param.x_num_col_dims).production()),
static_cast<int>(
core::dim2 y_shape({product(param.y->dims(), 0, param.y_num_col_dims), param.x->dims()
product(param.y->dims(), param.y_num_col_dims, .Slice(param.x_num_col_dims, param.x->dims().size())
param.y->dims().size())}); .production())});
core::dim2 y_shape(
{static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production()),
static_cast<int>(
param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production())});
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, // mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, // param.y->data<float>(), y_shape.x, y_shape.y, //
TensorMutableData<float>(param.output, TARGET(kHost), param.output->mutable_data<float>());
product(param.output->dims())));
LOG(INFO) << "MUL x " << *param.x; LOG(INFO) << "MUL x " << *param.x;
LOG(INFO) << "MUL W " << *param.y; LOG(INFO) << "MUL W " << *param.y;
LOG(INFO) << "MUL out " << *param.output; LOG(INFO) << "MUL out " << *param.output;
......
...@@ -25,10 +25,9 @@ class ReluCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -25,10 +25,9 @@ class ReluCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public: public:
void Run() override { void Run() override {
auto& param = Param<operators::ReluParam>(); auto& param = Param<operators::ReluParam>();
auto n = product(param.input->dims()); auto n = param.input->dims().production();
const float* input = param.input->data<float>(); const float* input = param.input->data<float>();
float* output = TensorMutableData<float>(param.output, TARGET(kHost), float* output = param.output->mutable_data<float>();
product(param.output->dims()));
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]); output[i] = std::max(0.f, input[i]);
} }
......
...@@ -37,10 +37,8 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -37,10 +37,8 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = Param<operators::ScaleParam>(); auto& param = Param<operators::ScaleParam>();
scale_compute(param.x->data<float>(), scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
TensorMutableData<float>(param.output, TARGET(kHost), param.x->dims().production(), param.scale, param.bias,
product(param.output->dims())),
product(param.x->dims()), param.scale, param.bias,
param.bias_after_scale); param.bias_after_scale);
} }
......
...@@ -6,7 +6,7 @@ else() ...@@ -6,7 +6,7 @@ else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite
target_wrapper_host target_wrapper_host
compatible_pb_lite compatible_pb_lite
) )
......
...@@ -58,19 +58,20 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) { ...@@ -58,19 +58,20 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
} }
// read tensor // read tensor
std::vector<int64_t> dims; std::vector<int64_t> dims_vec;
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); std::copy(desc.dims().begin(), desc.dims().end(),
tensor->Resize(lite::DDim(&dims[0], dims.size())); std::back_inserter(dims_vec));
lite::DDim dims(dims_vec);
tensor->Resize(dims);
void *buf; void *buf;
size_t size = product(tensor->dims()) * SizeOfType(desc.data_type()); size_t size = tensor->dims().production() * SizeOfType(desc.data_type());
// alllocate memory // alllocate memory
switch (static_cast<int>(desc.data_type())) { switch (static_cast<int>(desc.data_type())) {
#define DO(desc, type) \ #define DO(desc, type) \
case Type::VarType_Type_##desc: \ case Type::VarType_Type_##desc: \
buf = TensorMutableData<type>(tensor, TensorGetTarget(*tensor), \ buf = tensor->mutable_data<type>(); \
product(tensor->dims()));
break; break;
DO(BOOL, bool); // DO(BOOL, bool);
DO(FP32, float); DO(FP32, float);
DO(INT8, int8_t); DO(INT8, int8_t);
DO(INT16, int16_t); DO(INT16, int16_t);
...@@ -198,7 +199,7 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) { ...@@ -198,7 +199,7 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
auto dims = tensor.dims(); auto dims = tensor.dims();
auto *pb_dims = desc.mutable_dims(); auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0); pb_dims->Resize(static_cast<int>(dims.size()), 0);
auto dims_vec = DDimVectorize(dims); auto dims_vec = dims.Vectorize();
std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin()); std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin());
int32_t size = desc.ByteSize(); int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size)); os.write(reinterpret_cast<const char *>(&size), sizeof(size));
...@@ -206,15 +207,15 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) { ...@@ -206,15 +207,15 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
os.write(out.data(), size); os.write(out.data(), size);
} }
{ // the 3rd field, tensor data { // the 3rd field, tensor data
uint64_t size = tensor.memory_size(); uint64_t size = tensor.data_size();
CHECK_LT(size, std::numeric_limits<std::streamsize>::max()) CHECK_LT(size, std::numeric_limits<std::streamsize>::max())
<< "Index overflow when writing tensor"; << "Index overflow when writing tensor";
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
if (TensorGetTarget(tensor) == TARGET(kCUDA)) { if (tensor.target() == TARGET(kCUDA)) {
std::unique_ptr<char> tmp_buffer(new char[size]); std::unique_ptr<char> tmp_buffer(new char[size]);
TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<float>(), TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<float>(),
tensor.memory_size(), IoDirection::DtoH); tensor.data_size(), IoDirection::DtoH);
os.write(static_cast<const char *>(tmp_buffer.get()), os.write(static_cast<const char *>(tmp_buffer.get()),
static_cast<std::streamsize>(size)); static_cast<std::streamsize>(size));
} else } else
......
...@@ -28,7 +28,7 @@ TEST(ModelParser, LoadParam) { ...@@ -28,7 +28,7 @@ TEST(ModelParser, LoadParam) {
Scope scope; Scope scope;
auto* v = scope.Var("xxx"); auto* v = scope.Var("xxx");
LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v); LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v);
const auto& t = v->Get<Tensor>(); const auto& t = v->Get<TensorBase>();
LOG(INFO) << "loaded\n"; LOG(INFO) << "loaded\n";
LOG(INFO) << t; LOG(INFO) << t;
} }
......
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite) set(op_DEPS ${tensor_lite} op_lite op_params_lite)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite) cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS op_lite) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite})
cc_library(ops_lite DEPS cc_library(ops_lite DEPS
fc_op_lite fc_op_lite
relu_op_lite relu_op_lite
......
...@@ -42,7 +42,7 @@ bool FcOpLite::CheckShape() const { ...@@ -42,7 +42,7 @@ bool FcOpLite::CheckShape() const {
CHECK_GT_OR_FALSE(input_dims.size(), CHECK_GT_OR_FALSE(input_dims.size(),
static_cast<size_t>(param_.in_num_col_dims)); static_cast<size_t>(param_.in_num_col_dims));
param_.in_mat_dims = lite::flatten_to_2d(input_dims, param_.in_num_col_dims); param_.in_mat_dims = input_dims.Flattern2D(param_.in_num_col_dims);
// CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]); // CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
return true; return true;
...@@ -58,7 +58,7 @@ bool FcOpLite::InferShape() const { ...@@ -58,7 +58,7 @@ bool FcOpLite::InferShape() const {
output_dims[i] = input_dims[i]; output_dims[i] = input_dims[i];
} }
output_dims.back() = w_dims[1]; output_dims.back() = w_dims[1];
param_.output->Resize(DDim(&output_dims[0], output_dims.size())); param_.output->Resize(lite::DDim(output_dims));
// share LoD // share LoD
// param_.output->set_lod(param_.input->lod()); // param_.output->set_lod(param_.input->lod());
......
...@@ -52,11 +52,11 @@ class FcOpLite : public OpLite { ...@@ -52,11 +52,11 @@ class FcOpLite : public OpLite {
auto bias = op_desc.Input("Bias").front(); auto bias = op_desc.Input("Bias").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.input = scope->FindVar(input)->GetMutable<Tensor>(); param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.w = scope->FindVar(W)->GetMutable<Tensor>(); param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<Tensor>(); param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out)); CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims")); param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
CHECK(kernel_); CHECK(kernel_);
......
...@@ -24,10 +24,10 @@ TEST(fc_op_lite, test) { ...@@ -24,10 +24,10 @@ TEST(fc_op_lite, test) {
LOG(INFO) << "\n" << KernelRegistry::Global().DebugString(); LOG(INFO) << "\n" << KernelRegistry::Global().DebugString();
// prepare variables // prepare variables
Scope scope; Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>(); auto* x = scope.Var("x")->GetMutable<TensorBase>();
auto* w = scope.Var("w")->GetMutable<Tensor>(); auto* w = scope.Var("w")->GetMutable<TensorBase>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>(); auto* bias = scope.Var("bias")->GetMutable<TensorBase>();
auto* output = scope.Var("output")->GetMutable<Tensor>(); auto* output = scope.Var("output")->GetMutable<TensorBase>();
x->Resize({1, 10, 20}); x->Resize({1, 10, 20});
w->Resize({20, 20}); w->Resize({20, 20});
bias->Resize({1, 10}); bias->Resize({1, 10});
......
...@@ -39,13 +39,13 @@ class FeedOp : public OpLite { ...@@ -39,13 +39,13 @@ class FeedOp : public OpLite {
auto feed_var_name = opdesc.Input("X").front(); auto feed_var_name = opdesc.Input("X").front();
auto* feed_var = scope->FindVar(feed_var_name); auto* feed_var = scope->FindVar(feed_var_name);
CHECK(feed_var); CHECK(feed_var);
auto& feed_tensor_list = feed_var->Get<std::vector<Tensor>>(); auto& feed_tensor_list = feed_var->Get<std::vector<lite::Tensor>>();
param_.feed_list = &feed_tensor_list; param_.feed_list = &feed_tensor_list;
auto out_name = opdesc.Output("Out").front(); auto out_name = opdesc.Output("Out").front();
auto* out_var = scope->FindVar(out_name); auto* out_var = scope->FindVar(out_name);
CHECK(out_var); CHECK(out_var);
param_.out = out_var->GetMutable<Tensor>(); param_.out = out_var->GetMutable<lite::Tensor>();
// NOTE need boost here // NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc // TODO(Superjomn) drop the need of framework::op_desc
......
...@@ -37,7 +37,7 @@ class FetchOp : public OpLite { ...@@ -37,7 +37,7 @@ class FetchOp : public OpLite {
auto _x = opdesc.Input("X").front(); auto _x = opdesc.Input("X").front();
auto* x = scope->FindVar(_x); auto* x = scope->FindVar(_x);
CHECK(x); CHECK(x);
param_.input = &x->Get<Tensor>(); param_.input = &x->Get<lite::Tensor>();
auto _out = opdesc.Output("Out").front(); auto _out = opdesc.Output("Out").front();
auto* out = scope->FindVar(_out); auto* out = scope->FindVar(_out);
......
...@@ -45,7 +45,7 @@ bool MulOpLite::InferShape() const { ...@@ -45,7 +45,7 @@ bool MulOpLite::InferShape() const {
} }
out_dims.back() = y_dims[1]; out_dims.back() = y_dims[1];
param_.output->Resize(DDim(&out_dims[0], out_dims.size())); param_.output->Resize(lite::DDim(out_dims));
// share LoD // share LoD
// param_.output->set_lod(param_.input->lod()); // param_.output->set_lod(param_.input->lod());
......
...@@ -25,36 +25,36 @@ namespace lite { ...@@ -25,36 +25,36 @@ namespace lite {
namespace operators { namespace operators {
struct FeedParam { struct FeedParam {
const std::vector<Tensor>* feed_list{}; const std::vector<lite::Tensor>* feed_list{};
Tensor* out{}; lite::Tensor* out{};
int col; int col;
}; };
struct FetchParam { struct FetchParam {
const Tensor* input{}; const lite::Tensor* input{};
std::vector<Tensor>* fetch_list{}; std::vector<lite::Tensor>* fetch_list{};
int col; int col;
}; };
struct FcParam { struct FcParam {
Tensor* input{}; lite::Tensor* input{};
Tensor* w{}; lite::Tensor* w{};
Tensor* bias{}; lite::Tensor* bias{};
Tensor* output{}; lite::Tensor* output{};
DDim in_mat_dims; lite::DDim in_mat_dims;
int in_num_col_dims{1}; int in_num_col_dims{1};
}; };
struct ReluParam { struct ReluParam {
Tensor* input{}; lite::Tensor* input{};
Tensor* output{}; lite::Tensor* output{};
}; };
// For Mul Op // For Mul Op
struct MulParam { struct MulParam {
Tensor* x{}; lite::Tensor* x{};
Tensor* y{}; lite::Tensor* y{};
Tensor* output{}; lite::Tensor* output{};
int x_num_col_dims{1}; int x_num_col_dims{1};
int y_num_col_dims{1}; int y_num_col_dims{1};
...@@ -62,8 +62,8 @@ struct MulParam { ...@@ -62,8 +62,8 @@ struct MulParam {
// For Scale Op // For Scale Op
struct ScaleParam { struct ScaleParam {
Tensor* x{}; lite::Tensor* x{};
Tensor* output{}; lite::Tensor* output{};
float scale{1.}; float scale{1.};
float bias{}; float bias{};
...@@ -71,8 +71,8 @@ struct ScaleParam { ...@@ -71,8 +71,8 @@ struct ScaleParam {
}; };
struct IoCopyParam { struct IoCopyParam {
const Tensor* x{}; const lite::Tensor* x{};
Tensor* y{}; lite::Tensor* y{};
}; };
using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam, using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam,
......
...@@ -32,10 +32,10 @@ bool ReluOp::InferShape() const { ...@@ -32,10 +32,10 @@ bool ReluOp::InferShape() const {
} }
bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<Tensor *>( param_.input = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>()); &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
param_.output = param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<Tensor>(); scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input); CHECK(param_.input);
CHECK(param_.output); CHECK(param_.output);
kernel_->SetParam(param_); kernel_->SetParam(param_);
......
...@@ -109,10 +109,21 @@ struct variant { ...@@ -109,10 +109,21 @@ struct variant {
type_id = typeid(T).hash_code(); type_id = typeid(T).hash_code();
} }
template <typename T> template <typename T>
T& get() { const T& get() const {
// It is a dynamic_cast-like behaviour // It is a dynamic_cast-like behaviour
if (type_id == typeid(T).hash_code()) if (type_id == typeid(T).hash_code())
return *reinterpret_cast<T*>(&data); return *reinterpret_cast<const T*>(&data);
else
LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
<< typeid(T).name();
return *reinterpret_cast<const T*>(&data);
}
template <typename T>
T* get_mutable() {
// It is a dynamic_cast-like behaviour
if (type_id == typeid(T).hash_code())
return reinterpret_cast<T*>(&data);
else else
LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
<< typeid(T).name(); << typeid(T).name();
......
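The get()/get_mutable() split above mirrors what std::variant offers with std::get and std::get_if; a usage sketch written against std::variant as a stand-in, since the rest of the hand-rolled variant class is outside this hunk:

```cpp
#include <iostream>
#include <variant>

struct FeedParam { int col{0}; };

int main() {
  std::variant<int, float, FeedParam> param = FeedParam{};

  // Const read path, analogous to the new const T& get() const.
  const FeedParam &p = std::get<FeedParam>(param);
  std::cout << p.col << "\n";  // prints 0

  // Mutable path, analogous to T* get_mutable().
  FeedParam *pm = std::get_if<FeedParam>(&param);
  pm->col = 3;
  std::cout << std::get<FeedParam>(param).col << "\n";  // prints 3
}
```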