提交 5017b154 编写于 作者: Q qijun

refactor tensor mutable_data

上级 65dbeb6a
......@@ -14,17 +14,17 @@ limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
......
......@@ -62,21 +62,19 @@ class Tensor {
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
}
#ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
}
#else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE(true, "GPU not support!");
}
#endif
offset_ = 0;
}
......
......@@ -20,9 +20,9 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <paddle/platform/place.h>
#include <memory>
#include <unsupported/Eigen/CXX11/Tensor>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册