diff --git a/CMakeLists.txt b/CMakeLists.txt index dcff6b54cafce35846627e78cfcdac65fae7e686..2a6b0a20e441676c85c9ed8f8ad1a6e7abdf1ea8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ # limitations under the License cmake_minimum_required(VERSION 3.0) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index eb3416462324edf6f6e76e32d7400d1fd774b9bd..a00b9c81906fb1194b51efc50e6255f092875281 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -4,8 +4,11 @@ cc_test(enforce_test SRCS enforce_test.cc DEPS enforce) cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) + cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) +cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) + cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) proto_library(attr_type SRCS attr_type.proto) diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h index c0c33d81149ac2fc2a9a57d90931ef32375fe1d0..f2d88f3cb00e20f548a5cd412b515e843491a76d 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attr_checker.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "paddle/framework/enforce.h" @@ -41,6 +42,35 @@ class DefaultValueSetter { T default_value_; }; +template +class EnumInContainer { + public: + explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} + void operator()(T& val) const { + PADDLE_ENFORCE(container_.find(val) != container_.end(), + "Value %s is not in enum container %s", val, + ContainerDebugString()); + } + + private: + std::string ContainerDebugString() const { + std::ostringstream sout; + sout << "["; + size_t cnt = 0; + for (auto& v : container_) { + sout << v; + ++cnt; + if (cnt != container_.size()) { + sout << " ,"; + } + } + sout << "]"; + return sout.str(); + } + + std::unordered_set container_; +}; + // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -50,6 +80,11 @@ class TypedAttrChecker { public: TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} + TypedAttrChecker& InEnum(const std::unordered_set& range) { + value_checkers_.push_back(EnumInContainer(range)); + return *this; + } + TypedAttrChecker& LargerThan(const T& lower_bound) { value_checkers_.push_back(LargerThanChecker(lower_bound)); return *this; diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 070850375d1bd3a61b98184495c979573bf9542c..06c4c583b3afa9b472561e5d8166cef3398c57f4 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -119,17 +119,6 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -template -Eigen::DSizes ToEigenDSizes(const DDim& dims) { - int rank = arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); - Eigen::DSizes dsizes; - for (int d = 0; d < rank; d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h new file mode 100644 index 0000000000000000000000000000000000000000..4ba4fd4d110330805faf2468bd406cb23c6f1b1c --- /dev/null +++ b/paddle/framework/eigen.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +template +struct EigenDim { + using Type = Eigen::DSizes; + + static Type From(const DDim& dims) { + PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); + Type ret; + for (int d = 0; d < arity(dims); d++) { + ret[d] = dims[d]; + } + return ret; + } +}; + +// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +template +struct EigenTensor { + // TODO(qijun) Now, default type in unaligned, and we will make a benchmark on + // the speed of aligned and unaligned version in future. + using Type = Eigen::TensorMap>; + + using ConstType = + Eigen::TensorMap>; + + static Type From(Tensor& tensor, DDim dims) { + return Type(tensor.data(), EigenDim::From(dims)); + } + + static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); } + + static ConstType From(const Tensor& tensor, DDim dims) { + return ConstType(tensor.data(), EigenDim::From(dims)); + } + + static ConstType From(const Tensor& tensor) { + return From(tensor, tensor.dims_); + } +}; + +template +struct EigenVector : public EigenTensor { + // Flatten is to reshape a Tensor into a one dimension EigenVector + static typename EigenTensor::Type Flatten(Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } + + static typename EigenTensor::ConstType Flatten(const Tensor& tensor) { + return EigenTensor::From( + tensor, make_ddim({static_cast(product(tensor.dims_))})); + } +}; + +template +using EigenMatrix = EigenTensor; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9fa728e49a0dcc781e520a22c1ee5f921c4c733 --- /dev/null +++ b/paddle/framework/eigen_test.cc @@ -0,0 +1,101 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/framework/eigen.h" +#include + +namespace paddle { +namespace framework { + +TEST(EigenDim, From) { + EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3})); + ASSERT_EQ(1, ed[0]); + ASSERT_EQ(2, ed[1]); + ASSERT_EQ(3, ed[2]); +} + +TEST(Eigen, Tensor) { + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenTensor::Type et = EigenTensor::From(t); + + ASSERT_EQ(1, et.dimension(0)); + ASSERT_EQ(2, et.dimension(1)); + ASSERT_EQ(3, et.dimension(2)); + + for (int i = 0; i < 1; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < 3; k++) { + ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f); + } + } + } +} + +TEST(Eigen, VectorFrom) { + Tensor t; + float* p = t.mutable_data(make_ddim({6}), platform::CPUPlace()); + for (int i = 0; i < 6; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::From(t); + + ASSERT_EQ(6, ev.dimension(0)); + + for (int i = 0; i < 6; i++) { + ASSERT_NEAR(i, ev(i), 1e-6f); + } +} + +TEST(Eigen, VectorFlatten) { + Tensor t; + float* p = t.mutable_data(make_ddim({1, 2, 3}), platform::CPUPlace()); + for (int i = 0; i < 1 * 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenVector::Type ev = EigenVector::Flatten(t); + + ASSERT_EQ(1 * 2 * 3, ev.dimension(0)); + + for (int i = 0; i < 1 * 2 * 3; i++) { + ASSERT_NEAR(i, ev(i), 1e-6f); + } +} + +TEST(Eigen, Matrix) { + Tensor t; + float* p = t.mutable_data(make_ddim({2, 3}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3; i++) { + p[i] = static_cast(i); + } + + EigenMatrix::Type em = EigenMatrix::From(t); + + ASSERT_EQ(2, em.dimension(0)); + ASSERT_EQ(3, em.dimension(1)); + + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 3; j++) { + ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f); + } + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 7311cda9a9ad282b21711d8eb0b9ba1cf9542296..501536657d76cc50b1cc4104007edd4b47758aea 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -19,7 +19,10 @@ namespace paddle { namespace framework { -void PlainNet::CompleteAddOp() { +void PlainNet::CompleteAddOp(bool calc) { + add_op_done_ = true; + if (!calc) return; + std::unordered_set input_set; std::unordered_set output_set; std::unordered_set temp_output; @@ -52,7 +55,15 @@ void PlainNet::CompleteAddOp() { } attrs_["temporary_index"] = tmp_index; - add_op_done_ = true; +} + +std::string PlainNet::DebugString() const { + std::ostringstream os; + os << this->type_ << ":" << std::endl; + for (auto& op : ops_) { + os << "\t" << op->DebugString() << std::endl; + } + return os.str(); } } // namespace framework diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 19a1620e29b86fbccfc112a5f85a1784a197dd0b..19c5fa223b4e75b1f06ca14ded053cebfd8bffe2 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include -#include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" @@ -41,7 +40,7 @@ namespace framework { class Net : public OperatorBase { public: virtual void AddOp(const OperatorPtr& op) = 0; - virtual void CompleteAddOp() = 0; + virtual void CompleteAddOp(bool calc) = 0; }; using NetPtr = std::shared_ptr; @@ -86,7 +85,9 @@ class PlainNet : public Net { ops_.push_back(op); } - void CompleteAddOp() override; + void CompleteAddOp(bool calculate = true) override; + + std::string DebugString() const override; std::vector ops_; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 1071fcc0e2884beb4ce9ba46429ae87e9d72c4c1..5f046d6293d5dbb9fd594b0c13aa8d62012cf915 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -72,7 +72,7 @@ class OperatorBase { return boost::get(attrs_.at(name)); } - std::string DebugString() const; + virtual std::string DebugString() const; /// Init will be called after CreateOperator, you can put some initialization /// logic here. diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 4f07350e59dea72431417876f41f172e51ea53f9..39e0f9f7103dae4f710a80e8c33702094e1ab590 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" -#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -35,6 +34,15 @@ struct CastToPyBufferImpl; namespace framework { class Tensor { + template + friend struct paddle::pybind::details::CastToPyBufferImpl; + + template + friend struct EigenTensor; + + template + friend struct EigenVector; + public: Tensor() : offset_(0) {} @@ -46,7 +54,7 @@ class Tensor { } template - T* raw_data() const { + T* data() { CheckDims(); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); @@ -71,14 +79,14 @@ class Tensor { holder_.reset(new PlaceholderImpl( boost::get(place), product(dims_) * sizeof(T))); } else if (platform::is_gpu_place(place)) { -#ifdef __CUDACC__ +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); +#else holder_.reset(new PlaceholderImpl( boost::get(place), product(dims_) * sizeof(T))); -#else - PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device."); #endif } else { - PADDLE_ENFORCE(true, "Unknown 'place'."); + PADDLE_THROW("Unknown 'place'."); } offset_ = 0; } @@ -86,66 +94,6 @@ class Tensor { offset_); } - template - typename TTypes::Tensor shaped(DDim new_dims) { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(raw_data(), dims); - } - - template - typename TTypes::Tensor tensor() { - return typename TTypes::Tensor( - raw_data(), paddle::framework::ToEigenDSizes(dims_)); - } - - // flat to rank = 1 - template - typename TTypes::Flat flat() { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - // to TensorType Vec - template - typename TTypes::Vec vec() { - return tensor(); - } - - // to TensorType Matrix - template - typename TTypes::Matrix matrix() { - return tensor(); - } - - // const versions of all the methods above. - template - typename TTypes::Tensor shaped(DDim new_dims) const { - Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(data(), dims); - } - - template - typename TTypes::ConstantTensor tensor() const { - return typename TTypes::Tensor( - data(), paddle::framework::ToEigenDSizes(dims_)); - } - - template - typename TTypes::ConstFlat flat() const { - return shaped(make_ddim({static_cast(product(dims_))})); - } - - template - typename TTypes::ConstVec vec() const { - return tensor(); - } - - template - typename TTypes::ConstMatrix matrix() const { - return tensor(); - } - template void ShareDataFrom(const Tensor& src) { src.CheckDims(); @@ -251,8 +199,6 @@ class Tensor { std::shared_ptr holder_; // holds the memory block if allocated. DDim dims_; size_t offset_; // marks the begin of tensor data area. - template - friend struct paddle::pybind::details::CastToPyBufferImpl; }; } // namespace framework diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h deleted file mode 100644 index 4bf27a377e828a56f9679e6698d314457d7caf0b..0000000000000000000000000000000000000000 --- a/paddle/framework/tensor_types.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace framework { - -// Helper to define Tensor types given that the scalar is of type T. -template -struct TTypes { - // Rank- tensor of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Tensor; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstTensor; - - // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, - Eigen::Aligned> - Scalar; - typedef Eigen::TensorMap, - Eigen::RowMajor, IndexType>, - Eigen::Aligned> - ConstScalar; - - // Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Flat; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstFlat; - typedef Eigen::TensorMap, - Eigen::Aligned> - Vec; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstVec; - - // Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap, - Eigen::Aligned> - Matrix; - typedef Eigen::TensorMap< - Eigen::Tensor, Eigen::Aligned> - ConstMatrix; -}; - -} // namespace framework -} // namespace paddle diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index a5b14c0c71c18da1bb0b506c663f8680b1c3830a..2bec00cdb2d32d01a5a24e662bcca07f4154939c 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -36,6 +36,7 @@ if(WITH_GPU) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) add_simple_unittest(RowConvOpTest) + add_simple_unittest(CropOpTest) endif() add_simple_unittest(ConvOpTest) diff --git a/paddle/function/CropOp.cpp b/paddle/function/CropOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f12ee43e3d72f9ac776eaff93914228850694dd2 --- /dev/null +++ b/paddle/function/CropOp.cpp @@ -0,0 +1,177 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropOp.h" +#include "paddle/function/TensorShape.h" +#include "paddle/math/Vector.h" + +namespace paddle { + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < outC; c++) { + for (int h = 0; h < outH; h++) { + int outoff = ((n * outC + c) * outH + h) * outW; + int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop; + memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real)); + } + } + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = + conf.get>("crop_corner"); + int cCrop = crop_corner[1]; + int hCrop = crop_corner[2]; + int wCrop = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + for (int n = 0; n < num; n++) { + for (int c = 0; c < inC; c++) { + for (int h = 0; h < inH; h++) { + int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop; + int inoff = ((n * inC + c) * inH + h) * inW; + CpuVector inG = CpuVector(inW, const_cast(inGrad + inoff)); + CpuVector outG = CpuVector(inW, outGrad + outoff); + outG += inG; + } + } + } +} + +/** + * \brief Crop input according to the specify corner and shape. + * The input and output is a 4D tensor. In CropFunc, we only + * crop the 2nd to 4th dimension. + * + * Argument in this Function: + * \param pad_ A struct object contains the cropping corner and shape. + * \param inputs A 4D tensor, only one input. + * \param outputs A 4D tensor, the output value after cropping. + * + * For example, + * Input(2,2,2,3) = [ + * [ [[1,2,3], [3,4,5]], + * [[2,3,5], [1,6,7]] ], + * [ [[4,3,1], [1,8,7]], + * [[3,8,9], [2,3,5]] ] + * ] # the input shape is (2,2,2,3) + * + * pad_: if corner = (0,1,1) and crop_shape = (2,1,2) + * Output(2,2,1,2) = [ + * [ [[4,5]], + * [[6,7]] ], + * [ [[8,7]], + * [[3,5]] ] + * ] # the input shape is (2,2,2,3) + */ +template +class CropFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + + TensorShape inShape = inputs[0].shape(); + TensorShape outShape = outputs[0].shape(); + + Crop(outputs[0].data(), + inputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +/** + * \brief The backward propagation of cropping Function. + * + * Argument in this Function: + * \param crop_ The same meaning as it in CropFunc. + * \param inputs The gradient with respect to the output value of CropFunc. + * \param outputs The gradient with respect to the input value of CropFunc. + */ + +template +class CropGradFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override { conf_ = config; } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(1UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + TensorShape outShape = outputs[0].shape(); + TensorShape inShape = inputs[0].shape(); + + CropGrad(inputs[0].data(), + outputs[0].data(), + inShape, + outShape, + conf_); + } + +private: + FuncConfig conf_; +}; + +REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); +REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/CropOp.h b/paddle/function/CropOp.h new file mode 100644 index 0000000000000000000000000000000000000000..87986fbdc7e33aeb24d947e82a5d67ba23f532de --- /dev/null +++ b/paddle/function/CropOp.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +/** + * \brief This funtion crops inputs according to the specify start point and + *shape. + * + * \param[out] outputs save results. + * \param[in] inputs input data. + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); + +/** + * \brief Cropping operation backward. + * + * \param[out] inGrad gradients of previous layer + * \param[in] outGrad output gradient + * \param[in] inShape the shape of input tensor. + * \param[in] conf the cropping config + */ +template +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf); +} // namespace paddle diff --git a/paddle/function/CropOpGpu.cu b/paddle/function/CropOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..37ce6de0647e5e06a231710b5a53089533de2407 --- /dev/null +++ b/paddle/function/CropOpGpu.cu @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "CropOp.h" + +namespace paddle { + +__global__ void KeCrop(real* outputs, const real* inputs, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % outW; + const int h = (idx / outW) % outH; + const int c = (idx / outW / outH) % outC; + const int n = idx / outW / outH / outC; + + const int off = ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w; + outputs[idx] = inputs[off]; + } +} + +template <> +void Crop(real* outputs, + const real* inputs, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = inShape[0]; + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + size_t nth = num * outC * outH * outW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCrop<<>> + (outputs, inputs, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("Crop"); +} + +__global__ void KeCropDiff(const real* inGrad, real* outGrad, + int inC, int inH, int inW, + int cropC, int cropH, int cropW, + int outC, int outH, int outW, int nthreads) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < nthreads) { + const int w = idx % inW; + const int h = (idx / inW) % inH; + const int c = (idx / inW / inH) % inC; + const int n = idx / inW / inH / inC; + + const int off = ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w; + + outGrad[off] += inGrad[idx]; + } +} + +template <> +void CropGrad(const real* inGrad, + real* outGrad, + const TensorShape inShape, + const TensorShape outShape, + const FuncConfig& conf) { + std::vector crop_corner = conf.get>("crop_corner"); + int cropC = crop_corner[1]; + int cropH = crop_corner[2]; + int cropW = crop_corner[3]; + + int num = outShape[0]; + int outC = outShape[1]; + int outH = outShape[2]; + int outW = outShape[3]; + + int inC = inShape[1]; + int inH = inShape[2]; + int inW = inShape[3]; + + size_t nth = num * inC * inH * inW; + int blockSize = 1024; + int gridSize = (nth + blockSize - 1) / blockSize; + + KeCropDiff <<>> + (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW, + outC, outH, outW, nth); + CHECK_SYNC("CropGrad"); +} + +} // namespace paddle diff --git a/paddle/function/CropOpTest.cpp b/paddle/function/CropOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f11abfdf6f752857e0a75c62fb2b5c089c206d9 --- /dev/null +++ b/paddle/function/CropOpTest.cpp @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +TEST(Crop, real) { + for (size_t numSamples : {5, 32}) { + for (size_t channels : {5, 5, 32}) { + for (size_t imgSizeH : {5, 33, 100}) { + for (size_t imgSizeW : {5, 32, 96}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; + for (bool test_grad : {false, true}) { + CpuGpuFuncCompare compare( + test_grad ? "CropGrad" : "Crop", + FuncConfig() + .set>("crop_corner", {0, 1, 1, 1}) + .set>("crop_shape", {0, 2, 3, 3})); + TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; + TensorShape outDims{numSamples, 2, 3, 3}; + compare.addInputs( + BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims)); + compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, + test_grad ? inDims : outDims, + test_grad ? ADD_TO : ASSIGN_TO), + test_grad ? ADD_TO : ASSIGN_TO); + compare.run(); + } + } + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.cpp b/paddle/gserver/layers/CropLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69ad913420bdb6e1b2ed0618b7f9b78d7477be99 --- /dev/null +++ b/paddle/gserver/layers/CropLayer.cpp @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "CropLayer.h" +#include "paddle/utils/Stat.h" +namespace paddle { + +REGISTER_LAYER(crop, CropLayer); + +bool CropLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + CHECK_LE(static_cast(inputLayers_.size()), 2); + CHECK_GE(static_cast(inputLayers_.size()), 1); + crop_axis_ = config_.axis(); + for (int i = 0; i < config_.offset_size(); i++) { + crop_offsets_.push_back(config_.offset(i)); + } + + // 1. get input_0 shape + auto& input0_img_conf = config_.inputs(0).image_conf(); + inDims_ = TensorShape({0, + input0_img_conf.channels(), + input0_img_conf.has_img_size_y() + ? input0_img_conf.img_size_y() + : input0_img_conf.img_size(), + input0_img_conf.img_size()}); + // 2. get target dims from config + if (config_.inputs_size() == 1) { + targetDims_ = TensorShape({config_.shape(0), + config_.shape(1), + config_.shape(2), + config_.shape(3)}); + } else { + // 2. get input_1 shape + auto& input1_img_conf = config_.inputs(1).image_conf(); + targetDims_ = TensorShape({0, + input1_img_conf.channels(), + input1_img_conf.has_img_size_y() + ? input1_img_conf.img_size_y() + : input1_img_conf.img_size(), + input1_img_conf.img_size()}); + } + + // 3. get final crop corner + int dimSize = 4; + crop_corner_ = {0, 0, 0, 0}; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + if (crop_offsets_.size() > 1) { + crop_corner_[i] = crop_offsets_[i - crop_axis_]; + } else { + crop_corner_[i] = crop_offsets_[0]; + } + } + } + + outDims_ = TensorShape(4); + + createFunction( + forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_)); + createFunction( + backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_)); + + return true; +} + +void CropLayer::setOutDims() { + MatrixPtr input = inputLayers_[1]->getOutputValue(); + size_t batchSize = input->getHeight(); + // get target dims from input_1 + if (config_.inputs_size() == 2) { + targetDims_.setDim(0, batchSize); + int ch = config_.inputs(0).image_conf().channels(); + if (ch != 0) targetDims_.setDim(1, ch); + int h = inputLayers_[1]->getOutput().getFrameHeight(); + if (h != 0) targetDims_.setDim(2, h); + int w = inputLayers_[1]->getOutput().getFrameWidth(); + if (w != 0) targetDims_.setDim(3, w); + } + // get final crop shape from target dims and crop axis + std::vector crop_shape; + int dimSize = 4; + for (int i = 0; i < dimSize; i++) { + if (i >= crop_axis_) { + crop_shape.push_back(targetDims_[i]); + } else { + crop_shape.push_back(inDims_[i]); + } + } + + outDims_.reshape( + {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]}); + output_.setFrameHeight(crop_shape[2]); + output_.setFrameWidth(crop_shape[3]); +} + +void CropLayer::setInDims() { + MatrixPtr input = inputLayers_[0]->getOutputValue(); + size_t batchSize = input->getHeight(); + inDims_.setDim(0, batchSize); + int h = inputLayers_[0]->getOutput().getFrameHeight(); + if (h != 0) inDims_.setDim(2, h); + int w = inputLayers_[0]->getOutput().getFrameWidth(); + if (w != 0) inDims_.setDim(3, w); +} + +void CropLayer::forward(PassType passType) { + Layer::forward(passType); + setInDims(); + setOutDims(); + int size = outDims_[1] * outDims_[2] * outDims_[3]; + resetOutput(outDims_[0], size); + MatrixPtr outV = getOutputValue(); + REGISTER_TIMER_INFO("CropForward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), inDims_); + outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); +} + +void CropLayer::backward(const UpdateCallback& callback) { + (void)callback; + REGISTER_TIMER_INFO("CropBackward", getName().c_str()); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outDims_); + outputs.addArg(*getInputGrad(0), inDims_, ADD_TO); + backward_[0]->calc(inputs, outputs); +} +} // namespace paddle diff --git a/paddle/gserver/layers/CropLayer.h b/paddle/gserver/layers/CropLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..6b6202621023575c1c83049ecbd019656c726e3f --- /dev/null +++ b/paddle/gserver/layers/CropLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief This layer crop input according to the specify conf. + * input_0: input to be cropped + * input_1: optional reference input + * axis: start dimension to be croped + * offset: offset of cropping in each dimension + * shape: if reference input layer was not setted, + * crop input as this shape conf + */ +class CropLayer : public Layer { +public: + explicit CropLayer(const LayerConfig& config) : Layer(config) {} + + ~CropLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + void setOutDims(); + void setInDims(); + + int32_t crop_axis_; + std::vector crop_offsets_; + std::vector crop_corner_; + TensorShape inDims_; + TensorShape targetDims_; + TensorShape outDims_; +}; +} // namespace paddle diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index 4b92b5d163ad107c0783beae45f8c936112fcccf..d5621412caee843e24a0d0c9b7096402765738c7 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -359,12 +359,11 @@ void Layer::backwardActivation() { /* Do error clipping */ if (config_.error_clipping_threshold() > 0.0f) { if (FLAGS_log_error_clipping) { - CpuVector outGradVec(0, nullptr); - outGradVec.subVecFrom( - output_.grad->getData(), 0, output_.grad->getElementCnt()); - real maxAbsGrad = outGradVec.getAbsMax(); + VectorPtr outGradVec = Vector::create( + output_.grad->getData(), output_.grad->getElementCnt(), useGpu_); + real maxAbsGrad = outGradVec->getAbsMax(); if (maxAbsGrad > config_.error_clipping_threshold()) { - real avgAbsGrad = outGradVec.getAbsSum() / outGradVec.getSize(); + real avgAbsGrad = outGradVec->getAbsSum() / outGradVec->getSize(); LOG(INFO) << " layer=" << config_.name() << " need clipping," << " max error=" << maxAbsGrad << " avg error=" << avgAbsGrad; } diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 92f6cbcfe5a0e23c5939b1689a3e339367450387..a43adc7ce7db937bd62ea9bf1533b8a5899c259a 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -56,7 +56,7 @@ add_test(NAME test_DetectionOutput add_unittest_without_exec(test_ConvUnify test_ConvUnify.cpp LayerGradUtil.cpp) - + add_test(NAME test_ConvUnify COMMAND test_ConvUnify) ################# test_BatchNorm ####################### diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 67251f08e34faff57d9e6fd6a1163ba655619a8b..9af083468c0f01218117211f9e4931ca0669e96a 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1802,6 +1802,34 @@ TEST(Layer, RowConvLayer) { } } +TEST(Layer, CropLayer) { + TestConfig config; + // config input_0 + config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ImageConfig* img = input->mutable_image_conf(); + img->set_channels(4); + img->set_img_size(16); + config.layerConfig.set_axis(2); + config.layerConfig.add_offset(0); + config.layerConfig.add_offset(0); + + // config input_1 + config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0}); + input = config.layerConfig.add_inputs(); + img = input->mutable_image_conf(); + img->set_channels(2); + img->set_img_size(8); + + // config crop layer + config.layerConfig.set_type("crop"); + config.layerConfig.set_name("cropLayer"); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "crop", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f47c3a42083f289d6c99fe6df62e3478e0363e31..bc64bfd7ec2ed27835e5a3f9135343aeb3d4a580 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -27,7 +27,8 @@ function(op_library TARGET) endif() list(LENGTH cu_srcs cu_srcs_len) - if (${cu_srcs_len} EQUAL 0) + list(LENGTH op_library_DEPS dep_len) + if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0) message(WARNING "The op library ${TARGET} not support GPU!") endif() @@ -47,3 +48,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) + +op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op + softmax_op net) diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e08b3fb18775e2536a13bc838f40472c5c3e7ff7..39d54a63bd16cdafeec1cfcd86ef5d142382e880 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "glog/logging.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { @@ -29,8 +30,10 @@ public: output->mutable_data(context.GetPlace()); - output->flat().device(*(context.GetEigenDevice())) = - input0.flat() + input1.flat(); + framework::EigenVector::Flatten(*output).device( + *(context.GetEigenDevice())) = + framework::EigenVector::Flatten(input0) + + framework::EigenVector::Flatten(input1); } }; diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..01e96f4c4817466e3266ca57a0d0ae2368b3e097 --- /dev/null +++ b/paddle/operators/fc_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +class FullyConnectedOp : public framework::PlainNet { +public: + void Init() override { + AddOp(framework::OpRegistry::CreateOp("mul", + { + Input("X"), Input("W"), + }, + {Output("before_act")}, + {})); + auto b = Input("b"); + if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { + AddOp(framework::OpRegistry::CreateOp("rowwise_add", + {Output("before_act"), Input("b")}, + {Output("before_act")}, + {})); + } + + auto activation = GetAttr("activation"); + AddOp(framework::OpRegistry::CreateOp( + activation, {Output("before_act")}, {Output("Y")}, {})); + CompleteAddOp(false); + } +}; + +class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FullyConnectedOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input of fc operator"); + AddInput("W", "the weight of fc operator"); + AddInput("b", "the bias of fc operator"); + + AddOutput("Y", "the output of fc operator"); + AddOutput( + "before_act", "the before activation output of fc operator", true); + AddAttr("activation", "The activation key for fc layer") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "softmax"}); + + //! TODO(yuyang18): Complete comment; + AddComment("FullyConnected Operator"); + } +}; +} // namespace operators +} // namespace paddle + +USE_OP(mul); +USE_OP(rowwise_add); +USE_OP(sigmoid); +USE_OP(softmax); + +REGISTER_OP(fc, + paddle::operators::FullyConnectedOp, + paddle::operators::FullyConnectedOpMaker); diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 00b14a94321990baef6de35df547eed04b3da04f..29fb29c7c14f699e6114cc25c265ea8d85bce4d7 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op mul_op rowwise_add_op sigmoid_op softmax_op) + add_op fc_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fc9c6544c3cbf5a804b2d052f738bd483d6bf41b..7e84550f770e8dba998ce7ff91b9d774acbffc3e 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -26,10 +27,7 @@ namespace py = pybind11; namespace pd = paddle::framework; USE_OP(add_two); -USE_OP(softmax); -USE_OP(mul); -USE_OP(rowwise_add); -USE_OP(sigmoid); +USE_OP_WITHOUT_KERNEL(fc); PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) { self.mutable_data(paddle::platform::CPUPlace()); }) .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray); + .def("set", paddle::pybind::PyTensorSetFromArray) + .def("shape", + [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); py::class_(m, "Variable", R"DOC(Variable Class. @@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. - m.def("get_all_op_protos", []() -> std::vector { + m.def("get_all_op_protos", []() -> std::vector { auto& protos = pd::OpRegistry::protos(); - std::vector ret_values; + std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), "OpProto must all be initialized"); - ret_values.emplace_back(); - PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), + std::string str; + PADDLE_ENFORCE(it->second.SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.push_back(py::bytes(str)); } return ret_values; }); @@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle. .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("temp", pd::OperatorBase::TMP_VAR_NAME); + py::class_(m, "DeviceContext") + .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }); + py::class_(m, "Operator") .def("__str__", &pd::OperatorBase::DebugString) - .def_static("create", [](const std::string& protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }); + .def_static("create", + [](py::bytes protobin) { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }) + .def("infer_shape", &pd::OperatorBase::InferShape) + .def("run", &pd::OperatorBase::Run) + .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); return m.ptr(); } diff --git a/paddle/scripts/travis/check_style.sh b/paddle/scripts/travis/check_style.sh index 8049aeb7b00870220e59c981addf6d70a66877c7..ec499a839ac6593bac788f4cca5e33afbed73010 100755 --- a/paddle/scripts/travis/check_style.sh +++ b/paddle/scripts/travis/check_style.sh @@ -1,7 +1,7 @@ #!/bin/bash function abort(){ echo "Your change doesn't follow PaddlePaddle's code style." 1>&2 - echo "Please use pre-commit to reformat your code and git push again." 1>&2 + echo "Please use pre-commit to check what is wrong." 1>&2 exit 1 } @@ -19,7 +19,8 @@ ln -sf $TRAVIS_BUILD_DIR $GOPATH/src/github.com/PaddlePaddle/Paddle cd $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd - if ! pre-commit run -a ; then - git diff --exit-code + git diff + exit 1 fi trap : 0 diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 37cd16c79890738f6d8966579e15686c653d4df3..83f72c137bdf5e55f28be908321bd2ccd6c906fe 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -472,10 +472,16 @@ message LayerConfig { // blank label used in ctc loss optional uint32 blank = 52 [default = 0]; - // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which + // stride parameter for seqlastins layer, AverageLayer, MaxLayer, which // controls the scope of pooling operation. can be set > 0. // leave empty or set to -1 to disable this stride pooling. optional int32 seq_pool_stride = 53 [default = -1]; + + // for crop layer + optional int32 axis = 54 [default = 2]; + repeated uint32 offset = 55; + repeated uint32 shape = 56; + } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 826ba2834a820d11e69feec5569ef3537194e3c3..ab81e67579e39a34e3ace18d14434eb86b66fa5b 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1575,7 +1575,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase): @config_layer('fc') class FCLayer(LayerBase): - def __init__(self, name, size, inputs, bias=True, **xargs): + def __init__(self, + name, + size, + inputs, + bias=True, + error_clipping_threshold=None, + **xargs): super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs) for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) @@ -1592,6 +1598,8 @@ class FCLayer(LayerBase): self.create_input_parameter(input_index, psize, dims, sparse, format) self.create_bias_parameter(bias, self.config.size) + if error_clipping_threshold is not None: + self.config.error_clipping_threshold = error_clipping_threshold @config_layer('selective_fc') @@ -1990,6 +1998,23 @@ class PadLayer(LayerBase): self.config.size = out_ch * out_h * out_w +@config_layer('crop') +class CropLayer(LayerBase): + def __init__(self, name, inputs, axis, offset, shape, **xargs): + super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs) + self.config.axis = axis + self.config.offset.extend(offset) + self.config.shape.extend(shape) + + # get channel, width and height from input_0 layer + input_layer = self.get_input_layer(0) + image_conf = self.config.inputs[0].image_conf + image_conf.img_size = input_layer.width + image_conf.img_size_y = input_layer.height + image_conf.channels = input_layer.size / (input_layer.width * + input_layer.height) + + @config_layer('batch_norm') class BatchNormLayer(LayerBase): layer_type = 'batch_norm' diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 78aa0778f8d1dca9fae82f0411be5a00e636cbc9..fdb6f83f2ba510232714fb8a9c7c1af837a753ff 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -127,6 +127,7 @@ __all__ = [ 'dropout_layer', 'prelu_layer', 'gated_unit_layer', + 'crop_layer', ] @@ -218,6 +219,7 @@ class LayerType(object): SMOOTH_L1 = 'smooth_l1' PRELU = 'prelu' + CROP_LAYER = 'crop' @staticmethod def is_layer_type(type_name): @@ -5970,3 +5972,52 @@ def gated_unit_layer(input, name="%s_gated_act" % name, input=dotmul_operator(input_proj, gate), layer_attr=layer_attr) + + +@wrap_name_default() +@layer_support() +def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): + """ + The crop layer crops images by offset and shape. User can set crop shape by + args 'shape' explicitly or by reference input layer. + + The example usage is: + + .. code-block:: python + crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) + + :param input: The input layer.If two inputs were setted, + the second input will be regarded as reference input + :type input: LayerOutput or Sequence + :param offset: The crop offset + :type offset: Sequence + :param axis: start axis to be cropped. To image input layer: + - 0: batch size + - 1: channels + - 2: height + - 3: width + :type partial_sum: int + :param shape: The shape to be cropped. Default is None. + :type shape: Sequence | None + :param name: Name of this layer. + :type name: basestring + :return: LayerOutput object. + :rtype: LayerOutput + """ + if isinstance(input, LayerOutput): + input = [input] + else: + assert isinstance(input, collections.Sequence) + l = Layer( + inputs=[x.name for x in input], + axis=axis, + offset=offset, + shape=shape, + name=name, + type=LayerType.CROP_LAYER, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name=name, + layer_type=LayerType.CROP_LAYER, + parents=input, + size=l.config.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_crop.py b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..8314a7e9a5586647c70ff010156817110919c72b --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_crop.py @@ -0,0 +1,21 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2016, height=48, width=42) +refernce_data = data_layer(name='data', size=768, height=16, width=16) + +conv = img_conv_layer( + input=data, + filter_size=3, + num_channels=1, + num_filters=16, + padding=1, + act=LinearActivation(), + bias_attr=True) + +pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling()) + +crop = crop_layer(input=[pool, refernce_data], axis=2) + +outputs(pad) diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 2e4beb6882789249db09705f3f4d6c5c19e492cd..90830515c1e8e6f5260cfca631e02a3a52cedbe5 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -26,8 +26,9 @@ import sentiment import wmt14 import mq2007 import flowers +import voc2012 __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007', 'flowers' + 'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012' ] diff --git a/python/paddle/v2/dataset/tests/voc2012_test.py b/python/paddle/v2/dataset/tests/voc2012_test.py new file mode 100644 index 0000000000000000000000000000000000000000..31e72ebf5eac0508d12783f9ceaa6eef0fa6d353 --- /dev/null +++ b/python/paddle/v2/dataset/tests/voc2012_test.py @@ -0,0 +1,42 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.voc2012 +import unittest + + +class TestVOC(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + for l in reader(): + self.assertEqual(l[0].size, 3 * l[1].size) + sum += 1 + return sum + + def test_train(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.train()) + self.assertEqual(count, 2913) + + def test_test(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.test()) + self.assertEqual(count, 1464) + + def test_val(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.val()) + self.assertEqual(count, 1449) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/dataset/voc2012.py b/python/paddle/v2/dataset/voc2012.py new file mode 100644 index 0000000000000000000000000000000000000000..617e212d67fbe37f9d9663e9c83c62045411fa77 --- /dev/null +++ b/python/paddle/v2/dataset/voc2012.py @@ -0,0 +1,85 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image dataset for segmentation. +The 2012 dataset contains images from 2008-2011 for which additional +segmentations have been prepared. As in previous years the assignment +to training/test sets has been maintained. The total number of images +with segmentation has been increased from 7,062 to 9,993. +""" + +import tarfile +import io +import numpy as np +from paddle.v2.dataset.common import download +from paddle.v2.image import * +from PIL import Image + +__all__ = ['train', 'test', 'val'] + +VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\ +VOCtrainval_11-May-2012.tar' + +VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd' +SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt' +DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg' +LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png' + +CACHE_DIR = 'voc2012' + + +def reader_creator(filename, sub_name): + + tarobject = tarfile.open(filename) + name2mem = {} + for ele in tarobject.getmembers(): + name2mem[ele.name] = ele + + def reader(): + set_file = SET_FILE.format(sub_name) + sets = tarobject.extractfile(name2mem[set_file]) + for line in sets: + line = line.strip() + data_file = DATA_FILE.format(line) + label_file = LABEL_FILE.format(line) + data = tarobject.extractfile(name2mem[data_file]).read() + label = tarobject.extractfile(name2mem[label_file]).read() + data = Image.open(io.BytesIO(data)) + label = Image.open(io.BytesIO(label)) + data = np.array(data) + label = np.array(label) + yield data, label + + return reader + + +def train(): + """ + Create a train dataset reader containing 2913 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'trainval') + + +def test(): + """ + Create a test dataset reader containing 1464 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'train') + + +def val(): + """ + Create a val dataset reader containing 1449 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'val') diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py index c2a7ae7692b08762ffbc91726be7bfa90e8ddedb..7248c3f52a9902e8c08ac2f1405801a5710459e5 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -217,6 +217,10 @@ def create_op_creation_method(op_proto): return core.Operator.create(opdesc.SerializeToString()) __impl__.__doc__ = get_docstring_from_op_proto(op_proto) + __impl__.all_input_args = [var.name for var in op_proto.inputs] + __impl__.all_output_args = [var.name for var in op_proto.outputs] + __impl__.all_attr_args = [attr.name for attr in op_proto.attrs] + return __impl__ diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4ce2bef6fcc4b8ddf5a6de3809a1891bce590aab..f71009aa8569beae330b18171043d456b59bca8d 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py) + test_tensor.py test_fc_op.py test_add_two_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..237f9b7eb0d525a2c8431523a2d90b7e32493d53 --- /dev/null +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -0,0 +1,50 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class OpTestMeta(type): + def __new__(cls, name, bases, attrs): + obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs) + + def test_all(self): + func = getattr(creation.op_creations, self.type, None) + self.assertIsNotNone(func) + + scope = core.Scope(None) + kwargs = dict() + + for in_name in func.all_input_args: + if hasattr(self, in_name): + kwargs[in_name] = in_name + var = scope.create_var(in_name).get_tensor() + arr = getattr(self, in_name) + var.set_dims(arr.shape) + var.set(arr) + else: + kwargs[in_name] = "@EMPTY@" + + for out_name in func.all_output_args: + if hasattr(self, out_name): + kwargs[out_name] = out_name + scope.create_var(out_name).get_tensor() + + for attr_name in func.all_attr_args: + if hasattr(self, attr_name): + kwargs[attr_name] = getattr(self, attr_name) + + op = func(**kwargs) + + op.infer_shape(scope) + + ctx = core.DeviceContext.cpu_context() + op.run(scope, ctx) + + for out_name in func.all_output_args: + actual = numpy.array(scope.get_var(out_name).get_tensor()) + expect = getattr(self, out_name) + numpy.testing.assert_almost_equal(actual, expect) + + obj.test_all = test_all + return obj diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a06d7a78ecf838a49e5f2808d3686c6b92faa8ce --- /dev/null +++ b/python/paddle/v2/framework/tests/test_add_two_op.py @@ -0,0 +1,17 @@ +import unittest +from op_test_util import OpTestMeta +import numpy + + +class TestAddOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "add_two" + self.X = numpy.random.random((342, 345)).astype("float32") + self.Y = numpy.random.random((342, 345)).astype("float32") + self.Out = self.X + self.Y + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..59e7e61249e2a7d49a17e5d87209f03b8f35f730 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -0,0 +1,43 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class TestFc(unittest.TestCase): + def test_fc(self): + scope = core.Scope(None) + x = scope.create_var("X") + x_tensor = x.get_tensor() + x_tensor.set_dims([1000, 784]) + x_tensor.alloc_float() + + w = scope.create_var("W") + w_tensor = w.get_tensor() + w_tensor.set_dims([784, 100]) + w_tensor.alloc_float() + + w_tensor.set(numpy.random.random((784, 100)).astype("float32")) + + # Set a real numpy array here. + # x_tensor.set(numpy.array([])) + + op = creation.op_creations.fc(X="X", Y="Y", W="W") + + for out in op.outputs(): + if scope.get_var(out) is None: + scope.create_var(out).get_tensor() + + tensor = scope.get_var("Y").get_tensor() + op.infer_shape(scope) + self.assertEqual([1000, 100], tensor.shape()) + + ctx = core.DeviceContext.cpu_context() + + op.run(scope, ctx) + + # After complete all ops, check Y is expect or not. + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index b1041f6102a56f5a200aa909e77729095c052f31..65a26940d4d703ea4fbb5022523a90716982ec10 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -20,6 +20,7 @@ setup_requires=["requests", "matplotlib", "rarfile", "scipy>=0.19.0", + "Pillow", "nltk"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: