From 958511160bc42fee48c9ad775dfb08e5198bf3e9 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 13:40:44 +0800 Subject: [PATCH] add simple add_op_functor --- paddle/framework/ddim.cc | 12 ++++++++ paddle/framework/ddim.h | 8 +----- paddle/framework/tensor.h | 47 +++++++++++++++++++++++++++++-- paddle/framework/tensor_types.h | 14 +++++++++ paddle/operators/add_op_functor.h | 35 +++++++++++++++++++++++ 5 files changed, 107 insertions(+), 9 deletions(-) create mode 100644 paddle/operators/add_op_functor.h diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 3f949a6595e..9431645cf5b 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -1,4 +1,5 @@ #include "paddle/framework/ddim.h" +#include "paddle/framework/enforce.h" namespace paddle { namespace framework { @@ -220,5 +221,16 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } +template +Eigen::DSizes ToEigenDSizes(DDim dims) const { + int rank = paddle::framework::arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") + Eigen::DSizes dsizes; + for (int d = 0; d < paddle::framework::arity(dims); d++) { + dsizes[d] = dims[d]; + } + return dsizes; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 053a09d63a2..a83a367196d 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,13 +93,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); template -Eigen::DSizes ToEigenDSizes(DDim dims) const { - Eigen::DSizes dsizes; - for (int d = 0; d < paddle::framework::arity(dims); d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} +Eigen::DSizes ToEigenDSizes(DDim dims) const; } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 81af4306114..0fa74e7ab13 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -57,18 +57,61 @@ class Tensor { DDim dim() const { return dims_; } + size_t NumElements() const { return product(dims_); } + + template + typename TTypes::Tensor Tensor::shaped(DDim new_dims) { + Eigen::array dims = + paddle::framework::ToEigenDSizes(new_dims); + return typename TTypes::Tensor(data(), dims); + } + template - typename TTypes::ConstantTensor Tensor::tensor() { + typename TTypes::Tensor Tensor::tensor() { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } + // flat to rank = 1 + template + typename TTypes::Flat flat() { + return shaped({NumElements()}); + } + + // to TensorType Vec + template + typename TTypes::Vec vec() { + return tensor(); + } + + // to TensorType Matrix + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + // const versions of all the methods above. template - typename TTypes::Tensor Tensor::tensor() { + typename TTypes::ConstantTensor Tensor::tensor() const { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } + template + typename TTypes::ConstFlat flat() const { + return shaped({NumElements()}); + } + + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + private: // Placeholder hides type T, so it doesn't appear as a template // parameter of Variable. diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h index 26de25b7c26..4bf27a377e8 100644 --- a/paddle/framework/tensor_types.h +++ b/paddle/framework/tensor_types.h @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include "unsupported/Eigen/CXX11/Tensor" diff --git a/paddle/operators/add_op_functor.h b/paddle/operators/add_op_functor.h new file mode 100644 index 00000000000..904f24b0305 --- /dev/null +++ b/paddle/operators/add_op_functor.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor_types.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Add { + void Operator()(const Device& d, + typename TTypes::ConstTensor input1, + typename TTypes::ConstTensor input2, + typename TTypes::Tensor output) { + output.device(d) = input1 + input2; + } +}; +} // namespace functor +} // namespace operators +} // namespace paddle -- GitLab