diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e7d1c7203aa3422b3f5dab4a83da4e175219ba81..eb3416462324edf6f6e76e32d7400d1fd774b9bd 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,7 @@ +# ddim lib cc_library(enforce SRCS enforce.cc DEPS glog) cc_test(enforce_test SRCS enforce_test.cc DEPS enforce) -cc_library(ddim SRCS ddim.cc) +cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory) diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 3976c6c0299c489764c7ccc209bef0a84736be12..070850375d1bd3a61b98184495c979573bf9542c 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -1,11 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include #include #include #include - #include "paddle/framework/dim.h" +#include "paddle/framework/enforce.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace framework { @@ -104,6 +119,17 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +template +Eigen::DSizes ToEigenDSizes(const DDim& dims) { + int rank = arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); + Eigen::DSizes dsizes; + for (int d = 0; d < rank; d++) { + dsizes[d] = dims[d]; + } + return dsizes; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 36479830535cdd49c93d965e6b68981012097b71..1e57e9a20f3eecfac266d67276347ad4b5b780f9 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -19,6 +19,20 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +Eigen::DefaultDevice* KernelContext::GetEigenDevice< + platform::CPUPlace, Eigen::DefaultDevice>() const { + return device_context_.get_eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice* +KernelContext::GetEigenDevice() const { + return device_context_.get_eigen_device(); +} +#endif + const std::string& OperatorBase::Input(const std::string& name) const { auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2081b8a05c197f3fe1451f7e58d2e6f1748120a3..1071fcc0e2884beb4ce9ba46429ae87e9d72c4c1 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -31,6 +31,21 @@ limitations under the License. */ namespace paddle { namespace framework { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class OperatorBase; using OperatorPtr = std::shared_ptr; /** @@ -131,6 +146,13 @@ class KernelContext { return res; } + template ::EigenDeviceType> + DeviceType* GetEigenDevice() const; + + platform::Place GetPlace() const { return device_context_.GetPlace(); } + const OperatorBase& op_; const std::shared_ptr& scope_; const platform::DeviceContext& device_context_; @@ -144,6 +166,7 @@ class OpKernel { * device resource such as CUDA stream, cublas handle, etc. from * KernelContext. User should construct it before run the Operator. */ + virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 3dcd8d08970e16539cadeef23ef07f153483937d..4f07350e59dea72431417876f41f172e51ea53f9 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -20,8 +20,10 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" +#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" #include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace pybind { @@ -43,6 +45,13 @@ class Tensor { reinterpret_cast(holder_->ptr()) + offset_); } + template + T* raw_data() const { + CheckDims(); + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + + offset_); + } + template T* mutable_data(DDim dims, platform::Place place) { set_dims(dims); @@ -77,6 +86,66 @@ class Tensor { offset_); } + template + typename TTypes::Tensor shaped(DDim new_dims) { + Eigen::array dims = + paddle::framework::ToEigenDSizes(new_dims); + return typename TTypes::Tensor(raw_data(), dims); + } + + template + typename TTypes::Tensor tensor() { + return typename TTypes::Tensor( + raw_data(), paddle::framework::ToEigenDSizes(dims_)); + } + + // flat to rank = 1 + template + typename TTypes::Flat flat() { + return shaped(make_ddim({static_cast(product(dims_))})); + } + + // to TensorType Vec + template + typename TTypes::Vec vec() { + return tensor(); + } + + // to TensorType Matrix + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + // const versions of all the methods above. + template + typename TTypes::Tensor shaped(DDim new_dims) const { + Eigen::array dims = + paddle::framework::ToEigenDSizes(new_dims); + return typename TTypes::Tensor(data(), dims); + } + + template + typename TTypes::ConstantTensor tensor() const { + return typename TTypes::Tensor( + data(), paddle::framework::ToEigenDSizes(dims_)); + } + + template + typename TTypes::ConstFlat flat() const { + return shaped(make_ddim({static_cast(product(dims_))})); + } + + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + template void ShareDataFrom(const Tensor& src) { src.CheckDims(); diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h new file mode 100644 index 0000000000000000000000000000000000000000..4bf27a377e828a56f9679e6698d314457d7caf0b --- /dev/null +++ b/paddle/framework/tensor_types.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// Helper to define Tensor types given that the scalar is of type T. +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstTensor; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, + Eigen::Aligned> + Scalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>, + Eigen::Aligned> + ConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstFlat; + typedef Eigen::TensorMap, + Eigen::Aligned> + Vec; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstMatrix; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 355c92a50481fb00e81da94381fa1944f1825ed7..41d044cdb72b5fb2a7f8654e8ad103778e0857d1 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -1,20 +1,20 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ -#include -#include -#include +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" namespace paddle { namespace operators { @@ -53,5 +53,6 @@ The equation is: Out = X + Y } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); -REGISTER_OP_CPU_KERNEL( - add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); +typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float> + AddKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float); diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 5979345fffd68d71ba09dc96874d8ff9471bdbcc..0edf142ee4e5f359ea14be02dbf3f7f8855f6db1 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,5 +1,6 @@ -#include -#include +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" +typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float; REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + AddKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 000564f66dff2b20d17cee10e44aaca114c2c908..e08b3fb18775e2536a13bc838f40472c5c3e7ff7 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -1,15 +1,36 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once -#include -#include +#include "glog/logging.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { -template +template class AddKernel : public framework::OpKernel { public: - void Compute(const framework::KernelContext &context) const override { - LOG(INFO) << "Add kernel in " << typeid(Place).name(); + void Compute(const framework::KernelContext& context) const override { + auto input0 = context.Input(0)->Get(); + auto input1 = context.Input(1)->Get(); + auto* output = context.Output(0)->GetMutable(); + + output->mutable_data(context.GetPlace()); + + output->flat().device(*(context.GetEigenDevice())) = + input0.flat() + input1.flat(); } }; diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc index f554ac1bef3255f136ad4407a7a1096bdc2b1db5..53b354fedcacf2176aed8b504daf2046bdf96bb6 100644 --- a/paddle/operators/add_op_test.cc +++ b/paddle/operators/add_op_test.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include #define private public #include diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 960ef0a5955bfe5f7d33b7c8e4524176b0dbfda6..9c1d94e9e703caf2db92ca4a8eac975317e6b945 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -15,14 +15,15 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::DefaultDevice* DeviceContext::get_eigen_device() + const { + return reinterpret_cast(this)->eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::GpuDevice* DeviceContext::get_eigen_device() const { + return reinterpret_cast(this)->eigen_device(); } #endif diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 51c8e1391324649d8e845e902a5632f6bca1fa58..f226a75c20b7a75e5f884cd158d139ebb8b34e47 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,9 +20,9 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif -#include #include -#include +#include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { @@ -33,17 +33,14 @@ class DeviceContext { virtual Place GetPlace() const = 0; template - DeviceType* get_eigen_device(); + DeviceType* get_eigen_device() const; }; class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice* eigen_device() { - if (!eigen_device_) { - eigen_device_.reset(new Eigen::DefaultDevice()); - } - return eigen_device_.get(); - } + CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } + + Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Place GetPlace() const override { Place retv = CPUPlace(); @@ -92,7 +89,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } + Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } cublasHandle_t cublas_handle() { if (!blas_handle_) {