提交 3208914b 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #2805 from QiJune/tensor_to_EigenTensor

Add method converting Tensor to Eigen TensorMap
# ddim lib
cc_library(enforce SRCS enforce.cc DEPS glog) cc_library(enforce SRCS enforce.cc DEPS glog)
cc_test(enforce_test SRCS enforce_test.cc DEPS enforce) cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory) cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <initializer_list> #include <initializer_list>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "paddle/framework/dim.h" #include "paddle/framework/dim.h"
#include "paddle/framework/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -104,6 +119,17 @@ int arity(const DDim& ddim); ...@@ -104,6 +119,17 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&); std::ostream& operator<<(std::ostream&, const DDim&);
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
int rank = arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < rank; d++) {
dsizes[d] = dims[d];
}
return dsizes;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -19,6 +19,20 @@ limitations under the License. */ ...@@ -19,6 +19,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto it = in_out_idxs_->find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
......
...@@ -31,6 +31,21 @@ limitations under the License. */ ...@@ -31,6 +31,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class OperatorBase; class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>; using OperatorPtr = std::shared_ptr<OperatorBase>;
/** /**
...@@ -131,6 +146,13 @@ class KernelContext { ...@@ -131,6 +146,13 @@ class KernelContext {
return res; return res;
} }
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_; const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_; const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
...@@ -144,6 +166,7 @@ class OpKernel { ...@@ -144,6 +166,7 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from * device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator. * KernelContext. User should construct it before run the Operator.
*/ */
virtual void Compute(const KernelContext& context) const = 0; virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernel() {}
......
...@@ -20,8 +20,10 @@ limitations under the License. */ ...@@ -20,8 +20,10 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h" #include "paddle/framework/enforce.h"
#include "paddle/framework/tensor_types.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -43,6 +45,13 @@ class Tensor { ...@@ -43,6 +45,13 @@ class Tensor {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
} }
template <typename T>
T* raw_data() const {
CheckDims<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T> template <typename T>
T* mutable_data(DDim dims, platform::Place place) { T* mutable_data(DDim dims, platform::Place place) {
set_dims(dims); set_dims(dims);
...@@ -77,6 +86,66 @@ class Tensor { ...@@ -77,6 +86,66 @@ class Tensor {
offset_); offset_);
} }
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(
raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
// flat to rank = 1
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}
// to TensorType Vec
template <typename T>
typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}
// to TensorType Matrix
template <typename T>
typename TTypes<T>::Matrix matrix() {
return tensor<T, 2>();
}
// const versions of all the methods above.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) const {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}
template <typename T>
typename TTypes<T>::ConstVec vec() const {
return tensor<T, 1>();
}
template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, 2>();
}
template <typename T> template <typename T>
void ShareDataFrom(const Tensor& src) { void ShareDataFrom(const Tensor& src) {
src.CheckDims<T>(); src.CheckDims<T>();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Tensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstTensor;
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Scalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>,
Eigen::Aligned>
ConstScalar;
// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Flat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Vec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstVec;
// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Matrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstMatrix;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/op_registry.h> #include "paddle/operators/add_op.h"
#include <paddle/framework/tensor.h> #include "paddle/framework/op_registry.h"
#include <paddle/operators/add_op.h> #include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -53,5 +53,6 @@ The equation is: Out = X + Y ...@@ -53,5 +53,6 @@ The equation is: Out = X + Y
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL( typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); AddKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
#include <paddle/operators/add_op.h> #include "paddle/operators/add_op.h"
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(add_two, REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace>); AddKernel_GPU_float);
\ No newline at end of file \ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place> template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext &context) const override { void Compute(const framework::KernelContext& context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name(); auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
input0.flat<T>() + input1.flat<T>();
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#define private public #define private public
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
......
...@@ -15,14 +15,15 @@ namespace paddle { ...@@ -15,14 +15,15 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() { Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device(); const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
#endif #endif
......
...@@ -20,9 +20,9 @@ limitations under the License. */ ...@@ -20,9 +20,9 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include <paddle/platform/place.h>
#include <memory> #include <memory>
#include <unsupported/Eigen/CXX11/Tensor> #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -33,17 +33,14 @@ class DeviceContext { ...@@ -33,17 +33,14 @@ class DeviceContext {
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename DeviceType>
DeviceType* get_eigen_device(); DeviceType* get_eigen_device() const;
}; };
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
Eigen::DefaultDevice* eigen_device() { CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
if (!eigen_device_) {
eigen_device_.reset(new Eigen::DefaultDevice()); Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
}
return eigen_device_.get();
}
Place GetPlace() const override { Place GetPlace() const override {
Place retv = CPUPlace(); Place retv = CPUPlace();
...@@ -92,7 +89,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -92,7 +89,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; } cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册