From 5017b154689bd8cb595c1d37a54cb2fd072488bc Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 15:37:42 +0800 Subject: [PATCH] refactor tensor mutable_data --- paddle/framework/operator.h | 14 +++++++------- paddle/framework/tensor.h | 22 ++++++++++------------ paddle/platform/device_context.h | 4 ++-- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c48d990eb2..e6cae9c32b 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,17 +14,17 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include -#include #include #include #include #include +#include "paddle/framework/attr_checker.h" +#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 30e00d0e0f..7ba4b29e7c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -62,21 +62,19 @@ class Tensor { !(holder_->place() == place) /* some versions of boost::variant don't have operator!= */ || holder_->size() < numel_ * sizeof(T) + offset_) { + if (platform::is_cpu_place(place)) { + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); + } #ifdef __CUDACC__ - switch (place.which()) { - case 0: - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); - break; - - case 1: - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); - break; + else if (platform::is_gpu_place(place)) { + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); } #else - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); + else if (platform::is_gpu_place(place)) { + PADDLE_ENFORCE(true, "GPU not support!"); + } #endif offset_ = 0; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 5f8ad15951..f226a75c20 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,9 +20,9 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif -#include #include -#include +#include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { -- GitLab