diff --git a/doc/design/model_format.md b/doc/design/model_format.md index db8c36e5f5dca94b516aad2134c1bdc8ccc6c744..e29129fddf775939c9f7a8b49d850d523e6e5a45 100644 --- a/doc/design/model_format.md +++ b/doc/design/model_format.md @@ -2,35 +2,35 @@ ## Motivation -The model is the output of training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, we need to make the model format must be self-completed and do not expose any training source code. +A model is an output of the training process. One complete model consists of two parts, the **topology** and the **parameters**. In order to support industrial deployment, the model format must be self-complete and must not expose any training source code. -As a result, In PaddlePaddle, the **topology** represents as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model, we must support large size parameter, and efficient serialization/deserialization. +As a result, In PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model. We must support large size parameters and efficient serialization/deserialization of parameters. ## Implementation -The topology is saved as a plain text, in detail, a self-contain protobuf file. +The topology is saved as a plain text in a detailed self-contain protobuf file. -The parameters are saved as a binary file. As we all know, the protobuf message has the limits of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We do a (benchmark experiment)[https://github.com/PaddlePaddle/Paddle/pull/4610], its result shows protobuf is not fit in this scene. +The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task. -As a result, we design a particular format for tensor serialization. By default, arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of (LoDTensorDesc)[https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99]. We save the DescProto as the byte string header, it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). Tensor stores value in a continuous memory buffer, for speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is, +As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is, -|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**| +The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format. + +|field name | type | description | +| --- | --- | --- | +| version | uint32_t | Version of saved file. Always 0 now. | +| tensor desc length | uint32_t | TensorDesc(Protobuf message) length in bytes. | +| tensor desc | void* | TensorDesc protobuf binary message | +| tensor data | void* | Tensor's data in binary format. The length of `tensor_data` is decided by `TensorDesc.dims()` and `TensorDesc.data_type()` | +| lod_level | uint64_t | Level of LoD | +| length of lod[0] | uint64_t | [Optional] length of lod[0] in bytes. | +| data of lod[0] | uint64_t* | [Optional] lod[0].data() | +| ... | ... | ... | -In detail, tensor's byte view as the table shows. Note that all the signed value written in little-endian. -```text -[offset] [type] [description] -0004 4 bytes integer HeaderLength, the length of LoDTensorDesc -0008 4 bytes integer ContentLength, the length of LodTensor Buffer -0009 1 bytes char TensorDesc -00010 1 bytes char TensorDesc -... -00100 1 bytes char TensorValue -00101 1 bytes char TensorValue -00102 1 bytes char TensorValue .. -... -``` ## Summary -We introduce the model format, the `ProgramDesc` describe the **topology**, and a bunch of particular format binary tensors describes the **parameters**. +- We introduce a model format. +- The model represented by its forward-pass computation procedure is saved in a **ProgramDesc** protobuf message. +- A bunch of specified format binary tensors describe the **parameters**. diff --git a/doc/design/regularization.md b/doc/design/regularization.md index 703a9fbdd4392aa7f44733cce2da19caa1b51e4a..21280ac898feb4dd5e5a5d9e88d121e856850f0b 100644 --- a/doc/design/regularization.md +++ b/doc/design/regularization.md @@ -1,7 +1,7 @@ # Regularization in PaddlePaddle ## Introduction to Regularization -A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**. +A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. A frequently faced problem is the problem of **overfitting**, where the model does not make reliable predictions on new unseen data. **Regularization** is the process of introducing additional information in order to prevent overfitting. This is usually done by adding extra penalties to the loss function that restricts the parameter spaces that an optimization algorithm can explore. ### Parameter Norm Penalties Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows: @@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe ##### L1 Regularization
-A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html). +A much more detailed mathematical background of regularization can be found [here](http://www.deeplearningbook.org/contents/regularization.html). +## Regularization Survey -## How to do Regularization in PaddlePaddle - -On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization: - -1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows: - ```python - opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2) - ``` - At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet: - ```python - if weight_decay != 0: - d_p.add_(weight_decay, p.data) - ``` - This is a very restyrictive way of doing regularization and does not give the users enough flexibility. - - **Advantages**: - - It is easy to implement for us. - - Faster execution of backward. However, it can be done manually by advanced users too. - - **Disadvantages**: - - Not flexible for other regularizations such as L1/L0 regularization. - - Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized. - - Tightly coupled optimizer and regularization implementation. - - -2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer. - - **Advantages**: - - Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization. - - Makes it easy for the users to customize and extend the framework. - - **Disadvantages**: - - Implementation requires comprehensive design and time. +A detailed survey of regularization in various deep learning frameworks can be found [here](https://github.com/PaddlePaddle/Paddle/wiki/Regularization-Survey). ## Proposal for Regularization in PaddlePaddle ### Low-Level implementation -In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations: +In the new design, we propose to create new operations for regularization. For now, we can add 2 ops that correspond to the most frequently used regularizations: - L2_regularization_op - L1_regularization_op -These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties. +These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate CPU and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes other than L1 and L2 norm penalties. The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API. @@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat #### High-level API -In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers). +In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we also need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers). diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 85374a476d51dc4c0e22793e8b53d6d7ba21c8da..0a77859d6148f636dacef2c6759fc00d387f5d5d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,5 @@ # ddim lib proto_library(framework_proto SRCS framework.proto) -proto_library(saver_proto SRCS framework.proto saver.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) @@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index c25a62c2b11ead614d93a4be8d63d40d0cc0165a..bafb4fbd480bf2a28e3aa3dc615a310f80cec493 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -15,6 +15,7 @@ #pragma once #include #include "paddle/framework/framework.pb.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index f53dd1c1858b45d39692eb683bc1dd9ee75b88fb..584308a5388da0d02d29f71a28097b02b6ea825f 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" -#include "paddle/framework/saver.pb.h" #include "paddle/memory/memcpy.h" #include "paddle/memory/memory.h" @@ -106,6 +105,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const { return lod_[level][idx + 1] - lod_[level][idx]; } +size_t LoDTensor::NumInstancesInElement(size_t level, size_t idx) const { + PADDLE_ENFORCE_LT(level, NumLevels()); + PADDLE_ENFORCE_LT(idx, NumElements(level)); + auto abs_lod = ToAbsOffset(lod()); + size_t begin = abs_lod[level][idx]; + size_t end = abs_lod[level][idx + 1]; + return end - begin; +} + void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) { auto new_lod = framework::SliceLevels(lod_, level_begin, level_end); lod_ = new_lod; @@ -117,144 +125,15 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, PADDLE_ENFORCE_LT(elem_begin, NumElements(level)); PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1); + auto abs_lod = framework::ToAbsOffset(lod()); auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end); lod_ = new_lod; -} - -std::string LoDTensor::SerializeToString() const { - LoDTensorProto desc; - - // set data_type - if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL); - if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16); - if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32); - if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64); - // FIXME(dzh): there is no fp16 in standard c++ - - if (this->type() == typeid(float)) // NOLINT - desc.set_data_type(DataType::FP32); - if (this->type() == typeid(double)) // NOLINT - desc.set_data_type(DataType::FP64); - - for (int i = 0; i < dims().size(); ++i) { - desc.add_dims(dims()[i]); - } - - // set lod information - desc.set_lod_level(this->NumLevels()); - for (size_t i = 0; i < this->NumLevels(); ++i) { - LoDInfo* lod = desc.add_levels(); - for (size_t j = 0; j < lod_[i].size(); ++j) { - lod->add_level(lod_[i][j]); - } - } - - desc.set_version(0); - - std::string desc_bytes = desc.SerializeAsString(); - - // FIXME(dzh) : implement fix chunk size buffer. - size_t DESC_SIZE = desc_bytes.size(); - size_t DATA_SIZE = holder_->size() - offset_; - - const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t); - char* buffer = - static_cast(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE)); - - // format: desc_size data_size, desc_bytes, data_bytes. - platform::CPUPlace src_place; - platform::CPUPlace dst_place; - - memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t)); - memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE, - sizeof(size_t)); - memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place, - desc_bytes.c_str(), desc_bytes.size()); - - PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!"); - platform::Place place = holder_->place(); - int element_width = holder_->size() / this->numel(); - - if (platform::is_cpu_place(place)) { - memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), - boost::get(place), - static_cast(holder_->ptr()) + offset_ / element_width, - DATA_SIZE); - } -#ifdef PADDLE_WITH_GPU - if (platform::is_gpu_place(place)) { - memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), - boost::get(place), - static_cast(holder_->ptr()) + offset_ / element_width, - DATA_SIZE); - } -#endif - - std::string ret(buffer, BUFFER_SIZE); - memory::Free(platform::CPUPlace(), buffer); - return ret; + // slice the underlying tensor + size_t begin = abs_lod[level][elem_begin]; + size_t end = abs_lod[level][elem_end]; + PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty."); + ShareDataWith(Slice(begin, end)); } - -void LoDTensor::DeserializeFromString(const std::string& s, - const platform::Place& dst_place) { - size_t DESC_SIZE, BUFFER_SIZE; - platform::CPUPlace src_place; - - memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t)); - memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t), - sizeof(size_t)); - - const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2; - - // parse LoDTensorDesc - LoDTensorProto desc; - desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE); - - std::vector dims; - std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); - this->Resize(make_ddim(dims)); - - // parse data type - void* ptr = nullptr; - if (desc.data_type() == DataType::BOOL) - ptr = this->mutable_data(dst_place); - if (desc.data_type() == DataType::INT16) - ptr = this->mutable_data(dst_place); - if (desc.data_type() == DataType::INT32) - ptr = this->mutable_data(dst_place); - if (desc.data_type() == DataType::INT64) - ptr = this->mutable_data(dst_place); - // FIXME(dzh): there is no fp16 in standard c++ - - if (desc.data_type() == DataType::FP32) - ptr = this->mutable_data(dst_place); - if (desc.data_type() == DataType::FP64) - ptr = this->mutable_data(dst_place); - - LoD lod; - std::vector levels; - for (int i = 0; i < desc.levels().size(); ++i) { - auto current_level = desc.levels()[i].level(); - std::copy(current_level.begin(), current_level.end(), - std::back_inserter(levels)); - lod.emplace_back(levels); - levels.clear(); - } - - this->set_lod(lod); - - if (platform::is_cpu_place(dst_place)) { - memory::Copy(boost::get(dst_place), ptr, src_place, - s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); - } -#ifdef PADDLE_WITH_GPU - if (platform::is_gpu_place(dst_place)) { - memory::Copy(boost::get(dst_place), ptr, src_place, - s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); - } -#endif -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index f78a751c53621aa103026b5d8a251966685822bb..f4fe4cdac6019a1899fd3db8e1b6ca588be0d436 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -85,7 +85,9 @@ class LoDTensor : public Tensor { void set_lod(const LoD& lod) { lod_ = lod; } - LoD lod() const { return lod_; } + const LoD& lod() const { return lod_; } + + LoD* mutable_lod() { return &lod_; } /* * Get the start offset and end offset of an element from LoD. @@ -122,6 +124,12 @@ class LoDTensor : public Tensor { */ size_t NumElements(size_t level, size_t idx) const; + /* + * Get the number of instances in the underlying tensor in the `idx`-th + * element. + */ + size_t NumInstancesInElement(size_t level, size_t idx) const; + /* * Shrink levels[level_begin:level_end] */ @@ -133,29 +141,45 @@ class LoDTensor : public Tensor { */ void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end); - /** - * @brief Serialize tensor to char bytes. - * Please check model_format.md for the format detail. - * NOTE: GPUTensor will copy data to cpu implicitly. - * @return return string - */ - - // FIXME(dzh) : Currently, this interface should only be used in - // save/restore model and checkpoint. ParameterServer do not use shape - // information to do the optimization, as a result, when we serialize - // parameter/gradient to string, we should serialize the tensor - // to string in the ps trainer instead of LoDTensor. - std::string SerializeToString() const; - - /** - * @brief Deserialize char bytes to tensor. - * @return return string - */ - void DeserializeFromString(const std::string& s, - const platform::Place& dst_place); - private: LoD lod_; }; + +/* + * Expand the `source` to fit the LoD of `lod`. For example, a `source` + * LoDTensor is + * - LoD: [0, 2] + * - tensor: [a0, a1] + * a `lod` is + * - LoD: [0 3 5] + * returns a new LoDTensor + * - [a0 a0 a0 a1 a1] + */ +template +LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level, + const platform::Place& place) { + LoD abs_lod = ToAbsOffset(lod); + const auto& lod_level = lod[level]; + size_t num_instances = source.dims()[0]; + + // new tensor + LoDTensor tensor; + tensor.set_lod(lod); + auto dims = source.dims(); + dims[0] = lod_level.back(); + tensor.Resize(dims); + tensor.mutable_data(place); + + PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1); + for (size_t ins = 0; ins < num_instances; ins++) { + for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) { + tensor.Slice(elem, elem + 1) + .CopyFrom(source.Slice(ins, ins + 1), platform::CPUPlace(), + platform::CPUDeviceContext()); + } + } + return tensor; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index b984d620717453456fb15620b4d10c4268be8a94..aa2f6c993d41ae98e0769d470dccad3b410da53e 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -92,11 +92,14 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { size_t level = 0; LoDTensor new_lod_tensor = lod_tensor_; new_lod_tensor.ShrinkInLevel(level, 0, 1); - EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); - EXPECT_EQ(new_lod_tensor.NumElements(0), 1UL); - EXPECT_EQ(new_lod_tensor.NumElements(1), 2UL); - EXPECT_EQ(new_lod_tensor.NumElements(2), 5UL); - ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); + ASSERT_EQ(new_lod_tensor.NumLevels(), 3UL); + ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL); + ASSERT_EQ(new_lod_tensor.NumElements(1), 2UL); + ASSERT_EQ(new_lod_tensor.NumElements(2), 5UL); + ASSERT_EQ(new_lod_tensor.dims()[0], 12); + for (int i = 0; i < 12 * 128; i++) { + ASSERT_EQ(new_lod_tensor.data()[i], i); + } level = 1; new_lod_tensor = lod_tensor_; @@ -104,23 +107,41 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 3UL); - ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); + ASSERT_EQ(new_lod_tensor.dims()[0], 7); + for (int i = 5 * 128; i < 12 * 128; i++) { + ASSERT_EQ(new_lod_tensor.data()[i - 5 * 128], i); + } + + LoDTensor t1; + t1.set_lod(lod_tensor_.lod()); + t1.ShareDataWith(lod_tensor_); + + LoDTensor t2; + t2.set_lod(lod_tensor_.lod()); + t2.ShareDataWith(lod_tensor_); + + t1.ShrinkInLevel(0, 1, 2); + t2.ShrinkInLevel(0, 0, 1); + EXPECT_NE(t1.data(), t2.data()); + EXPECT_NE(t1.data(), lod_tensor_.data()); } -TEST_F(LoDTensorTester, SerializeDeserialize) { - LoDTensor new_lod_tensor = lod_tensor_; - float* src_ptr = lod_tensor_.data(); - std::string s = lod_tensor_.SerializeToString(); - LoDTensor dst; - dst.DeserializeFromString(s, platform::CPUPlace()); - float* dst_ptr = dst.data(); - for (int i = 0; i < kLodTensorSize; ++i) { - EXPECT_EQ(dst_ptr[i], src_ptr[i]); +TEST(LodExpand, test) { + LoD lod{{0, 2}}; + LoDTensor tensor; + tensor.set_lod(lod); + tensor.Resize({2, 1}); + tensor.mutable_data(platform::CPUPlace()); + tensor.data()[0] = 0; + tensor.data()[1] = 1; + + LoD target; + target.emplace_back(std::vector{0, 3, 5}); + auto new_tensor = LodExpand(tensor, target, 0UL, platform::CPUPlace()); + std::vector result{{0, 0, 0, 1, 1}}; + for (size_t i = 0; i < 5; i++) { + ASSERT_EQ(new_tensor.data()[i], result[i]); } - - ASSERT_EQ(dst.NumElements(0), 2UL); - ASSERT_EQ(dst.NumElements(1), 3UL); - ASSERT_EQ(dst.NumElements(2), 8UL); } } // namespace framework diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index 11659be02ac340728150cf0a6438db8626c8e611..c79c4d0c721f9e568c937cb9e524e925fcdc83d0 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -47,31 +47,4 @@ TEST(LoDTensor, LoDInGPU) { for (size_t i = 0; i < src_lod[0].size(); ++i) { CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); } -} - -TEST(LoDTensor, SerializeDeserialize) { - paddle::framework::LoDTensor lod_tensor; - paddle::platform::GPUPlace place(0); - - paddle::framework::LoD src_lod; - src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14}); - - lod_tensor.Resize({14, 16}); - lod_tensor.mutable_data(place); - - lod_tensor.set_lod(src_lod); - CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL); - CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL); - - test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size()); - cudaDeviceSynchronize(); - - std::string s = lod_tensor.SerializeToString(); - paddle::framework::LoDTensor dst; - dst.DeserializeFromString(s, place); - paddle::framework::LoD dst_lod = dst.lod(); - - for (size_t i = 0; i < dst_lod[0].size(); ++i) { - CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2); - } -} +} \ No newline at end of file diff --git a/paddle/framework/saver.proto b/paddle/framework/saver.proto deleted file mode 100644 index 90a191a6a79250761489b68916b1fa09116830f2..0000000000000000000000000000000000000000 --- a/paddle/framework/saver.proto +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -syntax = "proto2"; -option optimize_for = LITE_RUNTIME; -package paddle.framework; - -import "framework.proto"; - -/** - * This file contains necessary information for model, checkpoint. - * etc. - */ - -message LoDInfo { repeated int64 level = 1; } - -/** - * Save the LoDTensorDesc information through LoDTensorProto, its data memory - * is copyed to c buffer immediately. See model_format.md for details. - */ - -message LoDTensorProto { - optional DataType data_type = 1; - repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] - repeated LoDInfo levels = 3; - optional int32 lod_level = 4 [ default = 0 ]; - optional int32 version = 5; -} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index e31472327dbca45dc12ea2c9e494beddd36860dc..9d2dc6a32bb2d4f6368fd9c7264c55fb9588819c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -132,6 +132,8 @@ class Tensor { std::type_index type() const { return holder_->type(); } + size_t memory_size() const; + private: inline void check_memory_size() const; diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc index 4c82c3638351c41df26503e2a26b5a4bb5822a67..0947e33548130a923e998f8bad68db00097af909 100644 --- a/paddle/framework/tensor_array.cc +++ b/paddle/framework/tensor_array.cc @@ -20,6 +20,8 @@ #include #include +#include "paddle/framework/eigen.h" + namespace paddle { namespace framework { @@ -104,10 +106,10 @@ void TensorArray::Write(size_t index, const LoDTensor& value) { values_.resize(index + 1); } + values_[index].set_lod(value.lod()); values_[index].Resize(value.dims()); - values_[index].mutable_data(platform::CPUPlace()); - values_[index].CopyFrom(value, platform::CPUPlace(), - platform::CPUDeviceContext()); + values_[index].mutable_data(value.place()); + values_[index].CopyFrom(value, value.place(), platform::CPUDeviceContext()); } void TensorArray::WriteShared(size_t index, const LoDTensor& value) { @@ -116,6 +118,7 @@ void TensorArray::WriteShared(size_t index, const LoDTensor& value) { values_.resize(index + 1); } + values_[index].set_lod(value.lod()); values_[index].ShareDataWith(value); } @@ -144,6 +147,155 @@ DySeqMetaBatch TensorArray::Unpack(const LoDTensor& source, int level, return unpacker.meta; } +LoDTensor TensorArray::LodPack(size_t level) const { + PADDLE_ENFORCE_GT(size(), 0UL, "no time step exists"); + // the levels should be no less than 2 + LoDTensor merged; + const LoDTensor *pre, *cur; + pre = &Read(0); + + for (size_t step = 1; step < size(); step++) { + cur = &Read(step); + PADDLE_ENFORCE_GT(cur->NumLevels(), 0); + PADDLE_ENFORCE_GT(pre->NumLevels(), 0); + PADDLE_ENFORCE_EQ(pre->NumLevels(), cur->NumLevels()); + PADDLE_ENFORCE_EQ(pre->NumElements(level), cur->NumElements(level)); + + merged = LodPackTwo(*pre, *cur, level); + pre = &merged; + } + return merged; +} + +/* + * NOTE currently, only the lowest level supports packing. + * The lowest LoD will be changed, while the relative offsets in levels above + * stay unchanged. + * + * previous step : [0] [1] [3] + * current step: [0 1 2] [2 3] [] + * packed to + * [0 0] [0 1] [0 2] [1 2] [1 3] [3] + */ +LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur, + size_t level) const { + PADDLE_ENFORCE_EQ(pre.NumLevels(), cur.NumLevels()); + PADDLE_ENFORCE_EQ(pre.NumLevels(), level + 1, + "Only the lowest LoD level supports pack temporarily."); + // calculate the result tensor's shape first + size_t num_instances = 0; + for (size_t elem = 0; elem < pre.NumElements(level); elem++) { + size_t prefix_size = pre.NumElements(level, elem); + size_t num_candidates = cur.NumElements(level, elem); + if (num_candidates > 0) { + num_instances += num_candidates * (prefix_size + 1); + } else { + num_instances += prefix_size; + } + } + + auto res_dims = pre.dims(); + res_dims[0] = num_instances; + LoDTensor result; + result.Resize(res_dims); + result.mutable_data(cur.place()); + + Vector last_lod_level; + // copy data + size_t index = 0; + last_lod_level.push_back(index); + for (size_t elem = 0; elem < pre.NumElements(level); elem++) { + size_t prefix_size = pre.NumElements(level, elem); + size_t num_candidates = cur.NumElements(level, elem); + + // slice the prefix Tensor + LoDTensor prefix = pre; + prefix.ShrinkInLevel(level, elem, elem + 1); + LoDTensor candidate = cur; + if (num_candidates > 0) { + candidate.ShrinkInLevel(level, elem, elem + 1); + } else { // just push prefix + result.Slice(index, index + prefix_size) + .CopyFrom(prefix, result.place(), platform::CPUDeviceContext()); + index += prefix_size; + last_lod_level.push_back(index); + } + for (size_t candi = 0; candi < num_candidates; candi++) { + // TODO(superjom) support GPU + result.Slice(index, index + prefix_size) + .CopyFrom(prefix, result.place(), platform::CPUDeviceContext()); + index += prefix_size; + // copy candidate record + result.Slice(index, index + 1) + .CopyFrom(candidate.Slice(candi, candi + 1), result.place(), + platform::CPUDeviceContext()); + index++; + last_lod_level.push_back(index); + } + } + + // update lod + auto lod = cur.lod(); + lod.back() = last_lod_level; + result.set_lod(lod); + return result; +} + +/* + * source [0 1 2] [3 4] [5 6 7] will be transformd to a list of LoDTensors such + * as + * [0 3 5] [1 4 6] [2 7] with 1-level LoDs: + * - [0 1 2 3] + * - [0 1 2 3] + * - [0 1 1 2], the [1,1) here means the second sequence is empty + * + * NOTE Unpack a LoDTensor in this approach may result in a big LoD. + */ +void TensorArray::LodUnpack(const LoDTensor& source, size_t level) { + PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1, + "only the lowest LoD level supports unpack."); + const size_t non_empty_instances = source.dims()[0]; + size_t index = 0; + Vector lowest_lod_level; + lowest_lod_level.push_back(index); + + for (size_t step = 0; step < non_empty_instances; step++) { + size_t num_instances = 0; + for (size_t id = 0; id < source.NumElements(level); id++) { + auto instance = source; + instance.ShrinkInLevel(level, id, id + 1); + if (static_cast(instance.dims()[0]) > step) { + num_instances++; + index++; + } + lowest_lod_level.push_back(index); + } + + // create tensor for this time step + LoDTensor tensor; + auto dims = source.dims(); + dims[0] = num_instances; + // set lod + auto lod = source.lod(); + lod.back() = lowest_lod_level; + tensor.set_lod(lod); + + index = 0; + for (size_t id = 0; id < source.NumElements(level); id++) { + auto instance = source; + instance.ShrinkInLevel(level, id, id + 1); + if (static_cast(instance.dims()[0]) > step) { + // copy this instance + tensor.Slice(index, index + 1) + .CopyFrom(instance.Slice(step, step + 1), tensor.place(), + platform::CPUDeviceContext()); + index++; + } + } + Write(step, tensor); + } +} + LoDTensor TensorArray::Stack() const { LoDTensor result; if (size() == 0) return result; diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h index 046ecb5221b7ed9d88e5017348ee8fcde23c7677..78fad8cab7e27a7f07ca542c2a083460ee9e2b79 100644 --- a/paddle/framework/tensor_array.h +++ b/paddle/framework/tensor_array.h @@ -86,6 +86,16 @@ class TensorArray { */ DySeqMetaBatch Unpack(const LoDTensor &source, int level, bool length_desend); + /* + * Pack an array of LoDTensors to a LoDTensor. + */ + LoDTensor LodPack(size_t level) const; + + /* + * Unpack a LoDTensor to an array of LoDTensors. + */ + void LodUnpack(const LoDTensor &source, size_t level); + /* * Pack the values into a tensor with rank one higher than each tensor in * values. @@ -111,6 +121,9 @@ class TensorArray { protected: void Unstack(const LoDTensor &source, bool data_shared) const; + LoDTensor LodPackTwo(const LoDTensor &pre, const LoDTensor &cur, + size_t level) const; + private: mutable std::vector values_; }; // class TensorArray diff --git a/paddle/framework/tensor_array_test.cc b/paddle/framework/tensor_array_test.cc index 9470ac5e6ed714d5ba63f3743e683af7f8edd4b0..83b52b442daf9b2f1fc40f23e458fcb67c5040e8 100644 --- a/paddle/framework/tensor_array_test.cc +++ b/paddle/framework/tensor_array_test.cc @@ -126,5 +126,57 @@ TEST_F(TensorArrayTester, size) { ASSERT_EQ(ta.size(), static_cast(batch_size)); } +TEST(TensorArray, LodPack) { + // three time steps, each step stores a LoDTensors + // - [0] [1] + // - [2 3], [4 5] + // - [6 7] [] [8], [9, 10] + // try to get a LoDTensor with content: + // - [0 2 6] + // - [0 2 7] + // - [0 3] + // - [1 4 8] + // - [1 5 9] + // - [1 5 10] + std::array tensors; + tensors[0].Resize(make_ddim({2, 1})); + tensors[1].Resize(make_ddim({4, 1})); + tensors[2].Resize(make_ddim({5, 1})); + int index = 0; + for (auto& t : tensors) { + t.mutable_data(platform::CPUPlace()); + for (int i = 0; i < t.dims()[0]; i++) { + t.data()[i] = index; + index++; + } + } + + std::array lods; + std::vector> levels{ + {0, 1, 2}, {0, 2, 4}, {0, 2, 2, 3, 5}}; + for (int i = 0; i < 3; i++) { + lods[i].emplace_back(levels[i].begin(), levels[i].end()); + } + + TensorArray ta; + for (int i = 0; i < 3; i++) { + tensors[i].set_lod(lods[i]); + ta.Write(i, tensors[i]); + } + + auto merged = ta.LodPack(0); + + std::vector target_tensor_data{{0, 2, 6, // 0 + 0, 2, 7, // 1 + 0, 3, // 2 + 1, 4, 8, // 3 + 1, 5, 9, // 5 + 1, 5, 10}}; + EXPECT_EQ(merged.dims()[0], (int)target_tensor_data.size()); + for (size_t i = 0; i < target_tensor_data.size(); i++) { + EXPECT_EQ(target_tensor_data[i], merged.data()[i]); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index f6e801bbb4a056b5590da95a4b140cb90638f322..29ac683f48fcde4dd3b5ad7f04b5d1d7434706ba 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); PADDLE_ENFORCE_GE( - holder_->size(), numel() * SizeOfType(type()) + offset_, + holder_->size(), memory_size() + offset_, "Tensor's dims_ is out of bound. Call Tensor::mutable_data " "first to re-allocate memory.\n" "or maybe the required data-type mismatches the data already stored."); } +inline size_t Tensor::memory_size() const { + return holder_ == nullptr ? 0UL : numel() * SizeOfType(type()); +} + template inline const T* Tensor::data() const { check_memory_size(); diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index a80f0e66b5a59bf95efc200d159ad5dd9cf4111a..cde5ec2413ad01a0396e19fa617688af0eafbc75 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -46,6 +46,8 @@ class Variable { std::type_index(typeid(T)) == std::type_index(holder_->Type()); } + void Clear() { holder_.reset(); } + private: struct Placeholder { virtual ~Placeholder() {} diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f577616230be65e9581cf8f3ed5f63a77c7c3e21 --- /dev/null +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -0,0 +1,318 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MKLDNNBatchNormLayer.h" + +using namespace mkldnn; // NOLINT +typedef memory::format format; + +namespace paddle { + +REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer); + +const real MKLDNNBatchNormLayer::EPS = 1E-5; + +bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + if (!MKLDNNLayer::init(layerMap, parameterMap)) { + return false; + } + + // first one is input layer + // the other two are created in config_parser.py saving moving mean and var + CHECK_EQ(inputLayers_.size(), 3U); + CHECK_EQ(inputLayers_.size(), parameters_.size()); + CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size())); + + const ImageConfig& conf = config_.inputs(0).image_conf(); + ic_ = conf.channels(); + ih_ = inputLayers_[0]->getOutput().getFrameHeight(); + iw_ = inputLayers_[0]->getOutput().getFrameWidth(); + if (iw_ == 0 && ih_ == 0) { + iw_ = conf.img_size(); + ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size(); + } + oc_ = ic_; + oh_ = ih_; + ow_ = iw_; + if (config_.has_use_global_stats()) { + useGlobalStats_ = config_.use_global_stats(); + } + movingAvgFraction_ = config_.moving_average_fraction(); + VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use") + << " --- global stats"; + VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_; + + initWeight(); + movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0)); + movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0)); + return true; +} + +void MKLDNNBatchNormLayer::initWeight() { + weight_.reset(new Weight(1, oc_, parameters_[0])); + if (biasParameter_.get() != NULL) { + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + } + CHECK_EQ(weight_ != nullptr, biases_ != nullptr) + << "only support have both weight and bias, or neither"; + if (weight_ && weight_->getW()) { + CHECK(biases_ && biases_->getW()); + valueScaleShift_ = Matrix::create(2, oc_, false, false); + valueScaleShift_->zeroMem(); + VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0)); + VectorPtr shift( + new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_)); + const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE); + const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE); + scale->copyFrom(*wgt); + shift->copyFrom(*bias); + wgt->setData(valueScaleShift_->getData()); + bias->setData(valueScaleShift_->getData() + oc_); + } + if (weight_ && weight_->getWGrad()) { + CHECK(biases_ && biases_->getWGrad()); + gradScaleShift_ = Matrix::create(2, oc_, false, false); + gradScaleShift_->zeroMem(); + const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT); + const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT); + wgt->setData(gradScaleShift_->getData()); + bias->setData(gradScaleShift_->getData() + oc_); + } +} + +void MKLDNNBatchNormLayer::convertWeightsFromPaddle() { + if (hasInitedWgt_) { + return; + } + // prepare mean and var if necessary + if (useGlobalStats_) { + CHECK(mean_); + CHECK(var_); + mean_->copyFrom(*(movingMean_->getW())); + var_->copyFrom(*(movingVar_->getW())); + } + hasInitedWgt_ = true; +} + +void MKLDNNBatchNormLayer::calMovingMeanAndVar() { + // calculating and saving moving mean and variance + CHECK_EQ(useGlobalStats_, false); + movingMean_->getW()->add( + *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_); + // here var is v^2 + movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_); +} + +void MKLDNNBatchNormLayer::reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { + reshapeInput(bs, ih, iw); + oh = ih; + ow = ow; + // ic_ and oc can not be changed + CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + << "Input channel can not be changed"; + reshapeOutput(oh, ow); + resizeOutput(bs, oc * oh * ow); + printSizeInfo(); +} + +void MKLDNNBatchNormLayer::resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + // In training phase, it will always calculate mean and var, + // so useGlobalStats must be false. + // In scoring phase, it depends on useGlobalStats choice. + if (passType_ != PASS_TEST && useGlobalStats_ == true) { + LOG(WARNING) << "use_global_stats is invalid setting in training phase"; + useGlobalStats_ = false; + } + + resetFwdBuffers(in, wgt, out); + + resetFwdPD(fwdPD_, in, wgt, out); + + resetFwdPipeline(pipeline, fwdPD_, in, wgt, out); +} + +void MKLDNNBatchNormLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + std::shared_ptr pd; + + resetBwdBuffers(in, wgt, out); + + resetBwdPD(pd, in, wgt, out); + + resetBwdPipeline(pipeline, pd, in, wgt, out); +} + +void MKLDNNBatchNormLayer::forward(PassType passType) { + MKLDNNLayer::forward(passType); + + // calculate and save moving mean and variance + if (passType_ != PASS_TEST) { + calMovingMeanAndVar(); + } +} + +void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) { + weight_->getParameterPtr()->incUpdate(callback); + if (biases_ && biases_->getWGrad()) { + biases_->getParameterPtr()->incUpdate(callback); + } +} + +void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out) { + resetInValue(in); + + memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; + CHECK(in); + auto outPD = + MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_); + resetOutValue(out, outPD); + + if (valueScaleShift_) { + auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_); + resetWithMatrix(wgt, valueScaleShift_, pd); + } + if (passType_ != PASS_TEST || useGlobalStats_) { + auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_); + mean_ = MKLDNNMatrix::create(pd); + var_ = MKLDNNMatrix::create(pd); + } +} + +void MKLDNNBatchNormLayer::resetFwdPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr out) { + flags_ = 0u; + prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring + : prop_kind::forward_training; + if (useGlobalStats_) { + flags_ = (flags_ | batch_normalization_flag::use_global_stats); + } + if (wgt) { + flags_ = (flags_ | batch_normalization_flag::use_scale_shift); + } + auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_); + pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_)); + // TODO(TJ): use check macro + CHECK(out); + CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc()); + if (wgt) { + CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc()); + } + if (passType_ != PASS_TEST || useGlobalStats_) { + CHECK(mean_); + CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); + CHECK(var_); + CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); + } +} + +void MKLDNNBatchNormLayer::resetFwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out) { + if (passType_ == PASS_TEST) { + if (useGlobalStats_) { + fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, + *in, + (const primitive::at)(*mean_), + (const primitive::at)(*var_), + *wgt, + *out) + : new bn_fwd(*pd, + *in, + (const primitive::at)(*mean_), + (const primitive::at)(*var_), + *out)); + } else { + fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out) + : new bn_fwd(*pd, *in, *out)); + } + } else { + CHECK_EQ(useGlobalStats_, false) + << "useGlobalStats should be false in training"; + fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_) + : new bn_fwd(*pd, *in, *out, *mean_, *var_)); + } + pipeline.push_back(*fwd_); +} + +void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out) { + CHECK(inVal_ && outVal_); + resetOutGrad(out, outVal_->getPrimitiveDesc()); + resetInGrad(in, inVal_->getPrimitiveDesc()); + if (gradScaleShift_) { + CHECK(wgtVal_); + resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc()); + } +} + +void MKLDNNBatchNormLayer::resetBwdPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out) { + pd = nullptr; + if (in == nullptr) { + return; + } + CHECK(out); + CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc()); + auto md = in->getMemoryDesc(); + auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_); + pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); + // TODO(TJ): use check macro + CHECK(wgt); + CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc()); + CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc()); + CHECK(mean_); + CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); + CHECK(var_); + CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); +} + +void MKLDNNBatchNormLayer::resetBwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out) { + if (pd == nullptr) { + return; + } + CHECK(inVal_); + bwdData_.reset( + wgt && wgtVal_ + ? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt) + : new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in)); + pipeline.push_back(*bwdData_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.h b/paddle/gserver/layers/MKLDNNBatchNormLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..456c0424ecb8dde17f98a900c5d77268cc672e34 --- /dev/null +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "MKLDNNLayer.h" +#include "mkldnn.hpp" + +namespace paddle { +typedef mkldnn::batch_normalization_forward bn_fwd; +typedef mkldnn::batch_normalization_backward bn_bwd; + +/** + * @brief A subclass of MKLDNNLayer BatchNorm layer. + * + * The config file api is mkldnn_batch_norm + */ +class MKLDNNBatchNormLayer : public MKLDNNLayer { +protected: + // save forward primitive_desc, which can be used backward + std::shared_ptr fwdPD_; + + // Epsilon value used in the batch normalization formula. + static const real EPS; + // weight and bias in paddle + std::unique_ptr weight_; + std::unique_ptr biases_; + // mkldnn use a large buffer store both scale and shift + // which are weight and bias in paddle corresponding. + MatrixPtr valueScaleShift_; + MatrixPtr gradScaleShift_; + // Moving average of mean. + std::unique_ptr movingMean_; + // Moving average of variance. + std::unique_ptr movingVar_; + + // if useGlobalStats_ is true, will use the loaded mean and variance. + // otherwise, calculate mean and variance in every mini-batch. + bool useGlobalStats_; + // used in MKLDNN primitive desc + unsigned flags_; + // use to compute moving mean and variance. + real movingAvgFraction_; + // whether the weight has been init + bool hasInitedWgt_; + + // local mean and variance + // when useGlobalStats_ they are loaded from moving mean and variance + // when do not useGlobalStats_ they are calculated from this mini-batch + MKLDNNMatrixPtr mean_; + MKLDNNMatrixPtr var_; + +public: + explicit MKLDNNBatchNormLayer(const LayerConfig& config) + : MKLDNNLayer(config), useGlobalStats_(true), hasInitedWgt_(false) {} + + ~MKLDNNBatchNormLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + + void reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; + + void resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void updateWeights(const UpdateCallback& callback) override; + + void convertWeightsFromPaddle() override; + +protected: + void initWeight(); + /** + * cal moving mean and variance. + * moving = moving * AvgFraction + local * (1 - AvgFraction) + */ + void calMovingMeanAndVar(); + /** + * Forward functions: reset buffers(input, weight, output), + * reset primitive descriptor, + * reset pipeline. + */ + void resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out); + void resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr out); + void resetFwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out); + + /** + * Backward functions: reset buffers(input, weight, output), + * reset primitive descriptor, + * reset pipeline. + */ + void resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out); + void resetBwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out); + void resetBwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& out); +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index 0a19fe23336ea943cb8a572dc40f8c0fbbd7236a..73b7e8857f35d194e71b2b5b341f89b77fd1f8b0 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() { // init randome parameters of ref, and copy to mkldnn void MKLDNNTester::randomWgtDatas() { EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size()); + const bool isBN = refLayer_->getType() == "batch_norm"; for (size_t i = 0; i < parameters_[REF].size(); ++i) { const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE); const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE); parameters_[REF][i]->randomize(); + if (isBN && i == 2) { + // this param is moving average in batch norm, which must larger than 0 + real offset = fabs(refValue->getMin()) + 1.0; + refValue->add(offset); + } dnnValue->copyFrom(*refValue); VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName(); @@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() { void MKLDNNTester::checkBackwardData() { VLOG(MKLDNN_TESTS) << "Check Backward Data"; - // TODO(TJ): uncomment me when batch norm ready - // const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm"; + const bool isBN = refLayer_->getType() == "batch_norm"; for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) { const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad(); const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad(); @@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() { double delta = compareMatrix(dnnDiff, refDiff); EXPECT_LE(fabs(delta), eps_); - // TODO(TJ): uncomment me when batch norm ready - // if (isBN) { - // // the other two inputs in batch norm are for moving mean and var - // break; - // } + if (isBN) { + // the other two inputs in batch norm are for moving mean and var + // do not have grad to compare + break; + } } } @@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) { void MKLDNNTester::runOnce() { // test forward randomBotDatas(); - dnnLayer_->forward(PASS_TRAIN); - refLayer_->forward(PASS_TRAIN); + dnnLayer_->forward(passType_); + refLayer_->forward(passType_); checkForward(); + if (passType_ == PASS_TEST) { + return; + } + // test backward // simple updater UpdateCallback updateCallback = [](Parameter* para) { @@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn, size_t batchSize, size_t inputImgH, size_t inputImgW, + PassType passType, bool printDetails, size_t iter, float epsilon) { @@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn, ih_ = inputImgH; iw_ = inputImgW; + passType_ = passType; log_ = printDetails; iter_ = iter; eps_ = epsilon; diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index c385d1c72717d120211f167b5c5eb9a557da3714..19d8848f74f2ee4a809e42164a0eb180abd2a4e1 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -62,12 +62,15 @@ protected: float eps_; /// input image size, default 1 size_t ih_, iw_; + /// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass) + PassType passType_; public: explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) { iter_ = iter; eps_ = epsilon; log_ = false; + passType_ = PASS_TRAIN; } ~MKLDNNTester() {} @@ -78,6 +81,7 @@ public: size_t batchSize, size_t inputImgH = 1, size_t inputImgW = 1, + PassType passType = PASS_TRAIN, bool printDetails = false, size_t iter = 3, float epsilon = 1e-4); diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index 6cb4ca5e08eab5b979e404c9e09dcfec11086c22..85d4f437c2664135a7975c6ed3270d8f1ddbeaf4 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) { testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2}); } +struct testBatchNormDesc { + int bs; + int ic; + int ih, iw; +}; + +static void getMKLDNNBatchNormConfig(TestConfig& cfg, + const testBatchNormDesc& pm) { + cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw); + cfg.layerConfig.set_type("mkldnn_batch_norm"); + cfg.biasSize = pm.ic; + cfg.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + /* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw), + /* size of weight= */ size_t(pm.ic)}); + cfg.inputDefs.push_back( + {INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)}); + cfg.inputDefs.back().isStatic = true; + cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)}); + cfg.inputDefs.back().isStatic = true; + LayerInputConfig* input = cfg.layerConfig.add_inputs(); + // TODO(TJ): uncomment me when refine and support comparing all zeroes vector + // cfg.layerConfig.set_active_type("relu"); + cfg.layerConfig.add_inputs(); + cfg.layerConfig.add_inputs(); + ImageConfig* img_conf = input->mutable_image_conf(); + img_conf->set_channels(pm.ic); + img_conf->set_img_size_y(pm.ih); + img_conf->set_img_size(pm.iw); +} + +void testBatchNormLayer(const testBatchNormDesc& pm) { + TestConfig dnnConfig; + getMKLDNNBatchNormConfig(dnnConfig, pm); + TestConfig refConfig = dnnConfig; + refConfig.layerConfig.set_type("batch_norm"); + // for PASS_TRAIN, use_global_stats always should be false, and batchsize != 1 + VLOG(MKLDNN_TESTS) << "check train phase"; + dnnConfig.layerConfig.set_use_global_stats(false); + refConfig.layerConfig.set_use_global_stats(false); + MKLDNNTester tester; + tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN); + // for PASS_TEST, check use_global_stats true and false, and batchsize 1 + VLOG(MKLDNN_TESTS) << "check test phase"; + for (auto useGS : {false, true}) { + dnnConfig.layerConfig.set_use_global_stats(useGS); + refConfig.layerConfig.set_use_global_stats(useGS); + MKLDNNTester tester; + for (auto bs : {pm.bs, 1}) { + tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST); + } + } +} + +TEST(MKLDNNLayer, BatchNormLayer) { + testBatchNormLayer({4, 10, 6, 6}); + testBatchNormLayer({16, 32, 16, 16}); +} + struct testActDesc { int bs, ic, ih, iw; }; diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index fe755d096da9713e39581a909e5d21aa93d69f0f..2b62d4e11ac7276924947ab47360ffca84240aea 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -91,6 +91,11 @@ public: const MKLDNNMatrixPtr& dst, bool checkData = true); + void copyFrom(const Matrix& src) { + // TODO(TJ): reorder data if this format is not nchw or x + m_->copyFrom(src); + } + public: /** * Reorder this MKLDNNMatrix from other format. diff --git a/paddle/math/RowBuffer.h b/paddle/math/RowBuffer.h index 9ef5b89680b00981188d78cb312dc75e2c0a79ee..e457d71f1b357aecae48107688499edd7271a5db 100644 --- a/paddle/math/RowBuffer.h +++ b/paddle/math/RowBuffer.h @@ -60,7 +60,7 @@ public: */ inline real* get(int row) const { if (preallocatedBuf_) { - CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize()); + CHECK_LE((row)*width_ * sizeof(real), preallocatedBuf_->getSize()); return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_; } else { CHECK_LE((row + 1) * width_, rowStore_.size()); diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 9b36182c2b619317da31310141823442d8fd3f94..29c20e18601b71bac5201df8ff0c7ce0bed702dc 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); #endif - } // namespace memory } // namespace paddle diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d2d70d8be71208cfa9673f6a6936b1bca16d7426..1ca4ba29d7f1b5e4aeecf7d352f68c1717f288a4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -82,7 +82,7 @@ function(op_library TARGET) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") endif() - + # reduce_op contains several operators if ("${TARGET}" STREQUAL "reduce_op") set(pybind_flag 1) @@ -148,3 +148,4 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) +cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6ba6ee12ec7b0a5dc2ffcdfd7519377c8f32fef8 --- /dev/null +++ b/paddle/operators/batch_norm_op.cu @@ -0,0 +1,262 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/batch_norm_op.h" + +#include +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using CudnnDataType = platform::CudnnDataType; + +void ExtractNCWHD(const framework::DDim &dims, + const TensorFormat &tensor_format, int *N, int *C, int *H, + int *W, int *D) { + *N = dims[0]; + *C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1]; + *H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1]; + *W = dims.size() > 3 + ? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2]) + : 1; + *D = dims.size() > 4 + ? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3]) + : 1; +} + +template +class BatchNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + double epsilon = static_cast(ctx.Attr("epsilon")); + const float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const std::string tensor_format_str = + ctx.Attr("tensor_format"); + const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); + + // Get the size for each dimension. + // NCHW [batch_size, in_channels, in_height, in_width] + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, + "The Input dim size should be between 3 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); + + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_; + + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#if CUDNN_VERSION_MIN(7, 0, 0) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + + VLOG(1) << "Setting descriptors."; + std::vector dims; + std::vector strides; + if (tensor_format == TensorFormat::NCHW) { + dims = {N, C, H, W, D}; + strides = {C * H * W * D, H * W * D, W * D, D, 1}; + } else { + dims = {N, C, H, W, D}; + strides = {H * W * D * C, 1, W * D * C, D * C, C}; + } + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + + auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + // alloc memory + y->mutable_data(ctx.GetPlace()); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + + math::SetConstant functor; + functor(ctx.device_context(), saved_mean, 0); + functor(ctx.device_context(), saved_variance, 0); + // FIXME(qiao) should not set zero self + functor(ctx.device_context(), mean_out, 0); + functor(ctx.device_context(), variance_out, 0); + + auto handle = ctx.cuda_device_context().cudnn_handle(); + + // Now, depending on whether we are running test or not, we have two paths. + if (is_test) { + // only when test we use input to do computation. + const auto *est_mean = ctx.Input("Mean"); + const auto *est_var = ctx.Input("Variance"); + // Run inference mode. + PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(est_mean->dims()[0], C); + PADDLE_ENFORCE_EQ(est_var->dims()[0], C); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardInference( + handle, + // Note: PERSISTENT not implemented for inference + CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, x->template data(), + data_desc_, y->template mutable_data(ctx.GetPlace()), + bn_param_desc_, scale->template data(), bias->template data(), + est_mean->template data(), est_var->template data(), epsilon)); + } else { + // Run training mode. + // obtain running mean and running inv var, and see if we need to + // initialize them. + double this_factor = 1. - momentum; + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining( + handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), + data_desc_, x->template data(), data_desc_, + y->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data(), bias->template data(), this_factor, + mean_out->template mutable_data(ctx.GetPlace()), + variance_out->template mutable_data(ctx.GetPlace()), epsilon, + saved_mean->template mutable_data(ctx.GetPlace()), + saved_variance->template mutable_data(ctx.GetPlace()))); + } + + // clean when exit. + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + } +}; + +template +class BatchNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + double epsilon = static_cast(ctx.Attr("epsilon")); + const std::string tensor_format_str = + ctx.Attr("tensor_format"); + const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str); + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + + const auto &x_dims = x->dims(); + + PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5, + "The Input dim size should be between 3 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D); + + PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(scale->dims()[0], C); + + // ------------------- cudnn descriptors --------------------- + cudnnTensorDescriptor_t data_desc_; + cudnnTensorDescriptor_t bn_param_desc_; + cudnnBatchNormMode_t mode_; + + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); + if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { + LOG(ERROR) << "Provided epsilon is smaller than " + << "CUDNN_BN_MIN_EPSILON. Setting it to " + << "CUDNN_BN_MIN_EPSILON instead."; + } + epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); +#if CUDNN_VERSION_MIN(7, 0, 0) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + + std::vector dims = {N, C, H, W, D}; + std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C}; + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + data_desc_, CudnnDataType::type, + x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( + bn_param_desc_, data_desc_, mode_)); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + d_x->mutable_data(ctx.GetPlace()); + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + + const auto *saved_mean = ctx.Input("SavedMean"); + const auto *saved_var = ctx.Input("SavedVariance"); + const void *saved_mean_data = saved_mean->template data(); + const void *saved_var_data = saved_var->template data(); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + ctx.cuda_device_context().cudnn_handle(), mode_, + CudnnDataType::kOne(), CudnnDataType::kZero(), + CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, + x->template data(), data_desc_, d_y->template data(), data_desc_, + d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data(), + d_scale->template mutable_data(ctx.GetPlace()), + d_bias->template mutable_data(ctx.GetPlace()), epsilon, + saved_mean_data, saved_var_data)); + + // clean when exit. + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); + CUDNN_ENFORCE( + platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(batch_norm, + ops::BatchNormKernel); +REGISTER_OP_GPU_KERNEL( + batch_norm_grad, + ops::BatchNormGradKernel); diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..58c9f1cd2c79c150aaed7753641f6ad6120dd0f5 --- /dev/null +++ b/paddle/operators/fill_constant_batch_size_like_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_constant_batch_size_like_op.h" + +namespace paddle { +namespace operators { + +class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Input"), + "Input(Input) of FillConstantBatchSizeLikeOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FillConstantBatchSizeLikeOp should not be null."); + + auto &shape = ctx->Attrs().Get>("shape"); + PADDLE_ENFORCE_GT(shape.size(), 0); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto dims = framework::make_ddim(shape_int64); + + dims[0] = ctx->GetInputDim("Input")[0]; + ctx->SetOutputDim("Out", dims); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return static_cast(ctx.Attr("data_type")); + } +}; + +class FillConstantBatchSizeLikeOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("data_type", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::DataType::FP32); + AddAttr>("shape", "(vector) The shape of the output"); + AddAttr("value", "(float, default 0) The value to be filled") + .SetDefault(0.0f); + AddInput("Input", + "(Tensor) Tensor " + "whose first dimension is used to specify the batch_size"); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fill_constant_batch_size_like, + ops::FillConstantBatchSizeLikeOp, + ops::FillConstantBatchSizeLikeOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant_batch_size_like, + ops::FillConstantBatchSizeLikeOpKernel, + ops::FillConstantBatchSizeLikeOpKernel); diff --git a/paddle/operators/fill_constant_batch_size_like_op.cu b/paddle/operators/fill_constant_batch_size_like_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..cfa5df001e9d6c606751e3ca3cddda02812ef180 --- /dev/null +++ b/paddle/operators/fill_constant_batch_size_like_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_constant_batch_size_like_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + fill_constant_batch_size_like, + ops::FillConstantBatchSizeLikeOpKernel, + ops::FillConstantBatchSizeLikeOpKernel); diff --git a/paddle/operators/fill_constant_batch_size_like_op.h b/paddle/operators/fill_constant_batch_size_like_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a360e6683ec7204ea5bdbe27ca88a0ac51c983ac --- /dev/null +++ b/paddle/operators/fill_constant_batch_size_like_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto value = ctx.Attr("value"); + + auto out_eigen = framework::EigenVector::Flatten(*out); + auto place = ctx.GetEigenDevice(); + out_eigen.device(place) = out_eigen.constant(static_cast(value)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d4eff0c35af520dd27b9eb197937026a8fbdff9 --- /dev/null +++ b/paddle/operators/load_op.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" + +#include + +namespace paddle { +namespace operators { + +class LoadOp : public framework::OperatorBase { + public: + LoadOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto filename = Attr("file_path"); + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", + filename); + + auto out_var_name = Output("Out"); + auto *out_var = scope.FindVar(out_var_name); + PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", + out_var_name); + + auto *tensor = out_var->GetMutable(); + + uint32_t version; + fin.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + framework::TensorDesc desc; + { // int32_t size + // proto buffer + int32_t size; + fin.read(reinterpret_cast(&size), sizeof(size)); + std::unique_ptr buf(new char[size]); + fin.read(reinterpret_cast(buf.get()), size); + PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), + "Cannot parse tensor desc"); + } + { // read tensor + std::vector dims; + dims.reserve(static_cast(desc.dims().size())); + std::copy(desc.dims().begin(), desc.dims().end(), + std::back_inserter(dims)); + tensor->Resize(framework::make_ddim(dims)); + + void *buf; + platform::Place cpu = platform::CPUPlace(); + switch (desc.data_type()) { + case framework::FP32: + buf = tensor->mutable_data(cpu); + break; + case framework::FP64: + buf = tensor->mutable_data(cpu); + break; + case framework::INT32: + buf = tensor->mutable_data(cpu); + break; + case framework::INT64: + buf = tensor->mutable_data(cpu); + break; + default: + PADDLE_THROW("DataType %d not supported", desc.data_type()); + } + fin.read(static_cast(buf), tensor->memory_size()); + } + { // read lod + uint64_t lod_level; + fin.read(reinterpret_cast(&lod_level), sizeof(lod_level)); + auto &lod = *tensor->mutable_lod(); + lod.resize(lod_level); + for (uint64_t i = 0; i < lod_level; ++i) { + uint64_t size; + fin.read(reinterpret_cast(&size), sizeof(size)); + std::vector tmp(size / sizeof(size_t)); + fin.read(reinterpret_cast(tmp.data()), + static_cast(size)); + lod[i] = tmp; + } + } + + auto place = dev_ctx.GetPlace(); + if (platform::is_gpu_place(place)) { + // copy CPU to GPU + framework::LoDTensor cpu_tensor; + cpu_tensor.ShareDataWith(*tensor); + cpu_tensor.set_lod(tensor->lod()); + + // reset tensor + out_var->Clear(); + tensor = out_var->GetMutable(); + tensor->set_lod(cpu_tensor.lod()); + tensor->CopyFrom(cpu_tensor, place, dev_ctx); + } + } +}; + +class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + LoadOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "The tensor need to be loaded"); + AddComment(R"DOC(Load Operator +Load operator will load a tensor variable from disk file. +)DOC"); + AddAttr("file_path", + "Variable will be loaded from \"file_path\".") + .AddCustomChecker( + [](const std::string &path) { return !path.empty(); }); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker); diff --git a/paddle/operators/save_load_op_test.cc b/paddle/operators/save_load_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe2b15ec09c6d29ad5f78e5c36f534c6a88497e6 --- /dev/null +++ b/paddle/operators/save_load_op_test.cc @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/framework/op_registry.h" + +USE_NO_KERNEL_OP(save); +USE_NO_KERNEL_OP(load); + +TEST(SaveLoadOp, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + paddle::platform::CPUDeviceContext ctx(place); + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({10, 10}); + paddle::framework::LoD expect_lod; + expect_lod.resize(1); + expect_lod[0].push_back(0); + expect_lod[0].push_back(1); + expect_lod[0].push_back(2); + expect_lod[0].push_back(3); + + tensor->set_lod(expect_lod); + int* expect = tensor->mutable_data(place); + for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) { + expect[i] = static_cast(i); + } + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", std::string("tensor.save")}); + + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", {{"X", {"test_var"}}}, {}, attrs); + save_op->Run(scope, ctx); + + auto load_var = scope.Var("out_var"); + auto target = load_var->GetMutable(); + auto load_op = paddle::framework::OpRegistry::CreateOp( + "load", {}, {{"Out", {"out_var"}}}, attrs); + load_op->Run(scope, ctx); + int* actual = target->data(); + for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) { + EXPECT_EQ(expect[i], actual[i]); + } + auto& actual_lod = target->lod(); + EXPECT_EQ(expect_lod.size(), actual_lod.size()); + for (size_t i = 0; i < expect_lod.size(); ++i) { + for (size_t j = 0; j < expect_lod[i].size(); ++j) { + EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]); + } + } +} \ No newline at end of file diff --git a/paddle/operators/save_op.cc b/paddle/operators/save_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..490256dfa1cf9b891713dac264e9260906ce1025 --- /dev/null +++ b/paddle/operators/save_op.cc @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +// TODO(yuyang18): If the functions below are needed by other files, move them +// to paddle::filesystem namespace. +constexpr char kSEP = '/'; +static bool FileExists(const std::string &filepath) { + struct stat buffer; + return (stat(filepath.c_str(), &buffer) == 0); +} + +static std::string DirName(const std::string &filepath) { + auto pos = filepath.rfind(kSEP); + if (pos == std::string::npos) { + return ""; + } + return filepath.substr(0, pos); +} + +static void MkDir(const char *path) { + if (mkdir(path, 0755)) { + PADDLE_ENFORCE_EQ(errno, EEXIST, "%s mkdir failed!", path); + } +} + +static void MkDirRecursively(const char *fullpath) { + if (*fullpath == '\0') return; // empty string + if (FileExists(fullpath)) return; + + MkDirRecursively(DirName(fullpath).c_str()); + MkDir(fullpath); +} + +class SaveOp : public framework::OperatorBase { + public: + SaveOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto filename = Attr("file_path"); + auto overwrite = Attr("overwrite"); + + if (FileExists(filename) && !overwrite) { + PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", + filename, overwrite); + } + + MkDirRecursively(DirName(filename).c_str()); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + + auto iname = Input("X"); + auto *var = scope.FindVar(iname); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", + iname); + + PADDLE_ENFORCE(var->IsType(), + "SaveOp only support LoDTensor, %s has wrong type", iname); + + auto &tensor = var->Get(); + + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + fout.write(reinterpret_cast(&version), sizeof(version)); + } + { // the 2nd field, tensor description + // int32_t size + // void* protobuf message + framework::TensorDesc desc; + desc.set_data_type(framework::ToDataType(tensor.type())); + auto dims = framework::vectorize(tensor.dims()); + auto *pb_dims = desc.mutable_dims(); + pb_dims->Resize(static_cast(dims.size()), 0); + std::copy(dims.begin(), dims.end(), pb_dims->begin()); + int32_t size = desc.ByteSize(); + fout.write(reinterpret_cast(&size), sizeof(size)); + auto out = desc.SerializeAsString(); + fout.write(out.data(), size); + } + { // the 3rd field, tensor data + uint64_t size = tensor.memory_size(); + auto *data_ptr = tensor.data(); + PADDLE_ENFORCE(size < std::numeric_limits::max(), + "Index overflow when writing tensor"); + if (platform::is_gpu_place(tensor.place())) { +#ifdef PADDLE_WITH_CUDA + constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB + std::unique_ptr buf(new char[kBufSize]); + auto &gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + uintptr_t data = reinterpret_cast(data_ptr); + while (size != 0) { + size_t size_to_write = std::min(kBufSize, static_cast(size)); + memory::Copy(cpu, buf.get(), + boost::get(tensor.place()), + reinterpret_cast(data), size_to_write, + gpu_dev_ctx.stream()); + gpu_dev_ctx.Wait(); + fout.write(buf.get(), size_to_write); + data += size_to_write; + size -= size_to_write; + } +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + fout.write(static_cast(data_ptr), + static_cast(size)); + } + } + { // the 4th field, lod information + // uint64_t lod_level + // uint64_t lod_level_1 size in byte. + // int* lod_level_1 data + // ... + auto lod = tensor.lod(); + uint64_t size = lod.size(); + fout.write(reinterpret_cast(&size), sizeof(size)); + + for (auto &each : lod) { + size = each.size() * sizeof(framework::LoD::value_type::value_type); + fout.write(reinterpret_cast(&size), sizeof(size)); + fout.write(reinterpret_cast(each.data()), + static_cast(size)); + } + } + } +}; + +class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + SaveOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The tensor need to be saved"); + AddComment(R"DOC(Save operator +Save operator will serialize and write a tensor variable to disk file. +)DOC"); + AddAttr("overwrite", "Overwrite the output file if exist") + .SetDefault(true); + AddAttr("file_path", + "Variable will be saved to \"file_path\".") + .AddCustomChecker( + [](const std::string &path) { return !path.empty(); }); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker); diff --git a/paddle/operators/save_restore_op.cc b/paddle/operators/save_restore_op.cc deleted file mode 100644 index 314e4e927924bf0442b7afe0184bf344e24c1521..0000000000000000000000000000000000000000 --- a/paddle/operators/save_restore_op.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -#include - -namespace paddle { -namespace operators { - -using framework::Tensor; -using framework::LoDTensor; - -inline static std::string VarToFileName(const std::string& folder_path, - const std::string& var_name) { - return folder_path + "/__" + var_name + "__"; -} - -class SaveOp : public framework::OperatorBase { - public: - SaveOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void Run(const framework::Scope& scope, - const platform::DeviceContext& dev_ctx) const override { - const auto& var_names = this->Inputs("X"); - for (const auto& name : var_names) { - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), - "Can not find variable '%s' in the scope.", name); - } - std::string folder_path = this->Attr("folderPath"); - PADDLE_ENFORCE(!folder_path.empty(), - "'folderPath' of SaveOp shouldn't be empty."); - - VLOG(1) << "Save variables to folder: " << folder_path; - for (const auto& name : var_names) { - std::string file_name = VarToFileName(folder_path, name); - std::ofstream fout(file_name, std::ofstream::out); - PADDLE_ENFORCE(fout.is_open(), "Fail to create file %s.", file_name); - const LoDTensor& tensor = scope.FindVar(name)->Get(); - std::string bytes = tensor.SerializeToString(); - fout << bytes; - fout.close(); - } - VLOG(1) << "Compelete saving variables. Items count: " << var_names.size(); - } -}; - -class SaveOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SaveOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "(tensor), the tensor count can be 1~INT_MAX, tensors names which " - "values will be saved.") - .AsDuplicable(); - AddAttr("folderPath", "the folderPath for save model."); - AddComment(R"DOC( -Save the input tensors to a binary file based on input tensor names and absolute path. - -All the inputs can carry the LoD (Level of Details) information, -or not. -)DOC"); - } -}; - -class RestoreOp : public framework::OperatorBase { - public: - RestoreOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void Run(const framework::Scope& scope, - const platform::DeviceContext& dev_ctx) const override { - const auto& var_names = this->Outputs("Out"); - for (const auto& name : var_names) { - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), - "Can not find variable '%s' in the scope.", name); - } - std::string folder_path = this->Attr("folderPath"); - PADDLE_ENFORCE(!folder_path.empty(), - "'folderPath' of RestoreOp shouldn't be empty."); - - VLOG(1) << "Try loading variables from folder: " << folder_path; - - for (const auto& name : var_names) { - std::string file_name = VarToFileName(folder_path, name); - std::ifstream fin(file_name, std::ifstream::in); - PADDLE_ENFORCE(fin.is_open(), "Fail to open file %s.", file_name); - const size_t kBufferSize = 4096; // equal to linux page size - char buffer[kBufferSize]; - std::string cache; - while (!fin.eof()) { - fin.read(buffer, kBufferSize); - cache.append(buffer, fin.gcount()); - } - LoDTensor* tensor = scope.FindVar(name)->GetMutable(); - tensor->DeserializeFromString(cache, dev_ctx.GetPlace()); - fin.close(); - } - VLOG(1) << "Complete loading variables."; - } -}; - -class RestoreOpMaker : public framework::OpProtoAndCheckerMaker { - public: - RestoreOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddOutput("Out", - "(tensor), the tensor count can be 1~INT_MAX, tensors which " - "values will be restores.") - .AsDuplicable(); - AddAttr("folderPath", "the folderPath for model file."); - AddAttr("data_type", "output tensor data type") - .SetDefault(framework::DataType::FP32); - AddComment(R"DOC( -Restore the tensors from model file based on absolute path. - -All the tensors outputs may carry the LoD (Level of Details) information, -or not. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OPERATOR(save, paddle::operators::SaveOp, - paddle::framework::EmptyGradOpMaker, - paddle::operators::SaveOpMaker); - -REGISTER_OPERATOR(restore, paddle::operators::RestoreOp, - paddle::framework::EmptyGradOpMaker, - paddle::operators::RestoreOpMaker); diff --git a/paddle/operators/squared_l2_norm_op.cc b/paddle/operators/squared_l2_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..42ad87e65a85355e1b9b927dcef9ebbb88cde717 --- /dev/null +++ b/paddle/operators/squared_l2_norm_op.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/squared_l2_norm_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class SquaredL2NormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + + ctx->SetOutputDim("Out", {1}); + } +}; + +class SquaredL2NormGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SquaredL2NormOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor) The input of squared_l2_norm op."); + AddOutput("Out", "(Float) The output of squared_l2_norm op."); + AddComment(R"DOC( +SquaredL2Norm Operator. + +Computes the squared L2 norm of a tensor. + +Out = sum (X ** 2) + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(squared_l2_norm, ops::SquaredL2NormOp, ops::SquaredL2NormOpMaker, + squared_l2_norm_grad, ops::SquaredL2NormGradOp); +REGISTER_OP_CPU_KERNEL( + squared_l2_norm, + ops::SquaredL2NormKernel); +REGISTER_OP_CPU_KERNEL( + squared_l2_norm_grad, + ops::SquaredL2NormGradKernel); diff --git a/paddle/operators/squared_l2_norm_op.cu b/paddle/operators/squared_l2_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..d384e9c28c9150fa901404478739ff809f29126f --- /dev/null +++ b/paddle/operators/squared_l2_norm_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/squared_l2_norm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + squared_l2_norm, + ops::SquaredL2NormKernel); +REGISTER_OP_GPU_KERNEL( + squared_l2_norm_grad, + ops::SquaredL2NormGradKernel); diff --git a/paddle/operators/squared_l2_norm_op.h b/paddle/operators/squared_l2_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c8d37ac40c1533a77acf78e6a42e1659555127e1 --- /dev/null +++ b/paddle/operators/squared_l2_norm_op.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +// Out = sum(square(X)) +template +class SquaredL2NormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const framework::Tensor *X = context.Input("X"); + framework::Tensor *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto out = framework::EigenVector::Flatten(*Out); + auto place = context.GetEigenDevice(); + + out.device(place) = x.square().sum(); + } +}; + +// dX = X +template +class SquaredL2NormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const framework::Tensor *X = context.Input("X"); + const framework::Tensor *dOut = + context.Input(framework::GradVarName("Out")); + PADDLE_ENFORCE(dOut->numel() == 1, + "Squared L2 Norm Gradient should be scalar"); + framework::Tensor *dX = + context.Output(framework::GradVarName("X")); + dX->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto dout = framework::EigenVector::Flatten(*dOut); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + Eigen::DSizes x_dsize(X->numel()); + dx.device(place) = (dout.broadcast(x_dsize) * x) * static_cast(2.0); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc index bf2540ecb092437e57a5970264559dc3c6ab4167..1090419083c8b8cf60eca02791ef673287f4a9a4 100644 --- a/paddle/optimizer/sgd_optimizer.cc +++ b/paddle/optimizer/sgd_optimizer.cc @@ -44,7 +44,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) { this->lr_policy_->DeserializeState(lr_state.SerializeAsString()); num_sample_passed_ = state.num_sample_passed(); ProtoToTensor(state.parameter(), parameter_); - if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_); + if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_); } } // namespace optimizer diff --git a/paddle/optimizer/tensor.h b/paddle/optimizer/tensor.h index 80a8c93081ea7758d3b5ba016a14d424954db913..86fa625e01b981f0377bd699d191fc865ee89784 100644 --- a/paddle/optimizer/tensor.h +++ b/paddle/optimizer/tensor.h @@ -15,7 +15,8 @@ template class TensorT { public: TensorT(size_t size) : height_(1), width_(size) { - data_ptr_ = std::shared_ptr(new T[size], std::default_delete()); + // new T[size]() initializes all element to zero value. + data_ptr_ = std::shared_ptr(new T[size](), std::default_delete()); data_ = data_ptr_.get(); } diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h index 0c5719ef5162546578253e383209b1893c0cd71f..ce3421a3cb840e4c1e872eea12dedc1150c85962 100644 --- a/paddle/platform/cudnn_helper.h +++ b/paddle/platform/cudnn_helper.h @@ -22,6 +22,47 @@ limitations under the License. */ namespace paddle { namespace platform { +inline const char* cudnnGetErrorString(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; + default: + return "Unknown cudnn error number"; + } +} + +#define CUDNN_VERSION_MIN(major, minor, patch) \ + (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + +#define CUDNN_ENFORCE(condition) \ + do { \ + cudnnStatus_t status = condition; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \ + PADDLE_THROW("cuDNN call failed"); \ + } \ + } while (false) + enum class DataLayout { kNHWC, kNCHW, @@ -40,12 +81,30 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + typedef const float ScalingParamType; + static ScalingParamType* kOne() { + static ScalingParamType v = 1.0; + return &v; + } + static ScalingParamType* kZero() { + static ScalingParamType v = 0.0; + return &v; + } }; template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + typedef const double ScalingParamType; + static ScalingParamType* kOne() { + static ScalingParamType v = 1.0; + return &v; + } + static ScalingParamType* kZero() { + static ScalingParamType v = 0.0; + return &v; + } }; inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { diff --git a/paddle/platform/dynload/cudnn.h b/paddle/platform/dynload/cudnn.h index 0120625b7c14448f1b8deb88c24a3ee06eaf4f01..b2d69da93bcd4a5c8e694a18ca648ddc4bd947af 100644 --- a/paddle/platform/dynload/cudnn.h +++ b/paddle/platform/dynload/cudnn.h @@ -83,6 +83,7 @@ extern void* cudnn_dso_handle; __macro(cudnnDestroyConvolutionDescriptor); \ __macro(cudnnSetConvolutionNdDescriptor); \ __macro(cudnnGetConvolutionNdDescriptor); \ + __macro(cudnnDeriveBNTensorDescriptor); \ __macro(cudnnCreate); \ __macro(cudnnDestroy); \ __macro(cudnnSetStream); \ diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp index 54063a809a4f9e558f8d364f5c437f2b6d98925b..9562c649867a8f82f0262a049398b2f17026a983 100644 --- a/paddle/pserver/ParameterClient2.cpp +++ b/paddle/pserver/ParameterClient2.cpp @@ -186,6 +186,7 @@ void ParameterClient2::sendParallel(int tid, parameter->getMat(recvParameterType).get()); CHECK(recvMat); size_t width = parameter->getConfig().dims(1); + // TODO(wuyi): need add lock here? may also cause resize. buf = recvMat->getLocalRow(block.begin_pos() / width); } /// sparse_id is not useful while receiving data since sparse data @@ -265,9 +266,9 @@ void ParameterClient2::prepareSendData( uint64_t beginDim = 0; uint64_t endDim = 0; - // FIXME(typhoonzero): let it resize first - prefetchMat->getLocalRow(nLocalBlocks + 1); - sendMat->getLocalRow(nLocalBlocks + 1); + // HACK(typhoonzero): let it resize first + prefetchMat->getLocalRow(nLocalBlocks); + sendMat->getLocalRow(nLocalBlocks); for (size_t row = 0; row < nLocalBlocks; ++row) { int64_t blockId = localIndices[row]; // local row -> sparse row diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp index 35dcb235e7e8b65f7d1623a1ec66d963b1283385..7d5216a9669195eeed442828b9be5d379d069c3e 100644 --- a/paddle/trainer/NewRemoteParameterUpdater.cpp +++ b/paddle/trainer/NewRemoteParameterUpdater.cpp @@ -43,11 +43,6 @@ void NewRemoteParameterUpdater::init( const std::vector ¶meters) { ParameterUpdater::init(parameters); - for (auto ¶ : parameters_) { - para->getBuf(PARAMETER_VALUE)->zeroMem(); - para->getBuf(PARAMETER_GRADIENT)->zeroMem(); - } - // create parameter server client. if (useEtcd_) { parameterClient_ = @@ -109,6 +104,8 @@ void NewRemoteParameterUpdater::init( LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: " << trainerConfig_.learning_rate_schedule() << ", set to const"; optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const); + optimizerConfigV2.mutable_const_lr()->set_learning_rate( + trainerConfig_.learning_rate()); } // overwrite optimizerConfigV2 for per-parameter(layer) configs diff --git a/paddle/trainer/tests/sample_trainer_config_branch_net.conf b/paddle/trainer/tests/sample_trainer_config_branch_net.conf index a073708a184d6392a4eead69272e684013f1c751..3d8fb77a11958218091d2ee72e1d5a40ad1d9f5b 100644 --- a/paddle/trainer/tests/sample_trainer_config_branch_net.conf +++ b/paddle/trainer/tests/sample_trainer_config_branch_net.conf @@ -89,6 +89,36 @@ tmp = img_pool_layer(input=tmp, padding=1, pool_type=MaxPooling()) +tmp = img_conv_layer(input=tmp, + filter_size=3, + num_filters=32, + padding=1, + shared_biases=True, + act=LinearActivation(), + bias_attr=False) + +tmp = batch_norm_layer(input=tmp, + use_global_stats=False, + act=ReluActivation()) + +c1 = img_conv_layer(input=tmp, + filter_size=1, + num_filters=32, + padding=0, + shared_biases=True, + act=ReluActivation()) + +c2 = img_conv_layer(input=tmp, + filter_size=3, + num_filters=32, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = addto_layer(input=[c1, c2], + act=ReluActivation(), + bias_attr=False) + tmp = fc_layer(input=tmp, size=64, bias_attr=False, act=TanhActivation()) diff --git a/paddle/trainer/tests/sample_trainer_config_simple_net.conf b/paddle/trainer/tests/sample_trainer_config_simple_net.conf index 2ba71884d0953dc721808732fde12e695c6a757d..c615b5622b7e50b7aa99a9fcf9f63d7b4351417c 100644 --- a/paddle/trainer/tests/sample_trainer_config_simple_net.conf +++ b/paddle/trainer/tests/sample_trainer_config_simple_net.conf @@ -38,9 +38,14 @@ tmp = img_pool_layer(input=tmp, tmp = img_conv_layer(input=tmp, filter_size=3, - num_filters=64, + num_filters=32, padding=1, shared_biases=True, + act=LinearActivation(), + bias_attr=False) + +tmp = batch_norm_layer(input=tmp, + use_global_stats=False, act=ReluActivation()) tmp = img_pool_layer(input=tmp, diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 09c92d3513e86a7657880c01736f5f41f53cfcf6..e88e962cff5bbfcb8be1014dbaab85568d2625ff 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2420,6 +2420,7 @@ class BatchNormLayer(LayerBase): # If not use is_static, even set learning_rate = 0, decay_rate = 0, # these paras will change if set average_window in configure. use_gpu = bool(int(g_command_config_args.get("use_gpu", 0))) + use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0))) is_shared = True if not use_gpu else False for i in xrange(2): inputs.append( @@ -2433,11 +2434,17 @@ class BatchNormLayer(LayerBase): parallel_nn = bool(int(g_command_config_args.get("parallel_nn", 0))) cudnn_version = int(g_command_config_args.get("cudnn_version", 0)) - # Automatically select cudnn_batch_norm for GPU and batch_norm for CPU. - # Also based on cudnn version. + # Automatically select cudnn_batch_norm for GPU, batch_norm for CPU + # and mkldnn_batch_norm for MKLDNN. Also based on cudnn version. + if batch_norm_type == "mkldnn_batch_norm": + config_assert(use_mkldnn, "mkldnn_batch_norm only support MKLDNN") use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \ + not use_mkldnn and batch_norm_type != "mkldnn_batch_norm" and \ ((not parallel_nn) or self.config.device > -1) - self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm" + if use_cudnn: + self.layer_type = "cudnn_batch_norm" + else: + self.layer_type = "mkldnn_batch_norm" if use_mkldnn else "batch_norm" super(BatchNormLayer, self).__init__( name, self.layer_type, 0, inputs=inputs, **xargs) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 09315b9d9224076d91c16a6c0b949d4ab289bf70..cc1b34df9e7cf8d17bafeb57624548de017066e9 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3014,16 +3014,19 @@ def batch_norm_layer(input, :param input: batch normalization input. Better be linear activation. Because there is an activation inside batch_normalization. :type input: LayerOutput - :param batch_norm_type: We have batch_norm and cudnn_batch_norm. batch_norm - supports both CPU and GPU. cudnn_batch_norm requires - cuDNN version greater or equal to v4 (>=v4). But - cudnn_batch_norm is faster and needs less memory - than batch_norm. By default (None), we will - automaticly select cudnn_batch_norm for GPU and - batch_norm for CPU. Otherwise, select batch norm - type based on the specified type. If you use cudnn_batch_norm, + :param batch_norm_type: We have batch_norm, mkldnn_batch_norm and cudnn_batch_norm. + batch_norm supports CPU, MKLDNN and GPU. cudnn_batch_norm + requires cuDNN version greater or equal to v4 (>=v4). + But cudnn_batch_norm is faster and needs less + memory than batch_norm. mkldnn_batch_norm requires + enable use_mkldnn. By default (None), we will + automaticly select cudnn_batch_norm for GPU, + mkldnn_batch_norm for MKLDNN and batch_norm for CPU. + Otherwise, select batch norm type based on the + specified type. If you use cudnn_batch_norm, we suggested you use latest version, such as v5.1. :type batch_norm_type: None | string, None or "batch_norm" or "cudnn_batch_norm" + or "mkldnn_batch_norm" :param act: Activation Type. Better be relu. Because batch normalization will normalize input near zero. :type act: BaseActivation @@ -3063,6 +3066,7 @@ def batch_norm_layer(input, else: num_channels = input.size assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \ + (batch_norm_type == "mkldnn_batch_norm") or \ (batch_norm_type == "cudnn_batch_norm") l = Layer( name=name, diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 053ae151c571e5557c9f2f9f4ec866f546a77797..e31e501ce93c5dc20693a8724ee7dd864f9aef55 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -65,7 +65,14 @@ def download(url, module_name, md5sum): os.makedirs(dirname) filename = os.path.join(dirname, url.split('/')[-1]) - if not (os.path.exists(filename) and md5file(filename) == md5sum): + retry = 0 + retry_limit = 3 + while not (os.path.exists(filename) and md5file(filename) == md5sum): + if retry < retry_limit: + retry += 1 + else: + raise RuntimeError("Cannot download {0} within retry limit {2}". + format(url, retry_limit)) print "Cache file %s not found, downloading %s" % (filename, url) r = requests.get(url, stream=True) total_length = r.headers.get('content-length') diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index b3f8be8be9ac5c0c6c15646d39d4796df0fd87e2..8f28d3e76688234747c75dda53e7316a202dfd14 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -261,7 +261,7 @@ class Operator(object): self.desc.set_attr(attr_name, attrs[attr_name]) self.desc.check_attrs() - no_kernel_op_set = {'feed', 'fetch', 'save', 'restore'} + no_kernel_op_set = {'feed', 'fetch', 'save', 'load'} if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc) diff --git a/python/paddle/v2/framework/optimizer.py b/python/paddle/v2/framework/optimizer.py index a86908c64897eb4e01f3c99a66b4da27a5f3394b..e9df5483e243843992f48c7af2d1f017dfa8857c 100644 --- a/python/paddle/v2/framework/optimizer.py +++ b/python/paddle/v2/framework/optimizer.py @@ -4,7 +4,8 @@ import paddle.v2.framework.framework as framework from paddle.v2.framework.backward import append_backward_ops __all__ = [ - 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer' + 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', + 'AdamaxOptimizer' ] @@ -211,13 +212,14 @@ class MomentumOptimizer(Optimizer): """ _velocity_acc_str = "velocity" - def __init__(self, learning_rate, momentum): + def __init__(self, learning_rate, momentum, use_nesterov=False): assert learning_rate is not None assert momentum is not None super(MomentumOptimizer, self).__init__() self.type = "momentum" self._learning_rate = learning_rate self._momentum = momentum + self._use_nesterov = bool(use_nesterov) def _initialize_tensors(self, block): assert isinstance(block, framework.Block) @@ -259,7 +261,8 @@ class MomentumOptimizer(Optimizer): "ParamOut": param_and_grad[0], "VelocityOut": velocity_acc }, - attrs={"mu": self._momentum}) + attrs={"mu": self._momentum, + "useNesterov": self._use_nesterov}) return momentum_op @@ -397,7 +400,7 @@ class AdamOptimizer(Optimizer): param_and_grad[0]) moment2 = self._get_accumulator(self._moment2_acc_str, param_and_grad[0]) - # create the momentum optimize op + # create the adam optimize op adam_op = block.append_op( type=self.type, inputs={ @@ -440,3 +443,108 @@ class AdamOptimizer(Optimizer): attrs={"scale": self._beta2}) return [scale_beta1, scale_beta2] + + +class AdamaxOptimizer(Optimizer): + """Implements the Adamax Optimizer + """ + _moment_acc_str = "moment" + _inf_norm_acc_str = "inf_norm" + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + assert learning_rate is not None + assert beta1 is not None + assert beta2 is not None + assert epsilon is not None + super(AdamaxOptimizer, self).__init__() + self.type = "adamax" + self._learning_rate = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + + def _initialize_tensors(self, block): + assert isinstance(block, framework.Block) + lr_shape = [1] + # create a variable for learning_rate + self._lr = block.create_var( + dtype="float32", shape=lr_shape, lod_level=0) + + # create an op to init the learning_rate + # FIXME: Fix when Initialization design has been implemented + # https://github.com/PaddlePaddle/Paddle/pull/4852 + block.append_op( + type="fill_constant", + outputs={"Out": self._lr}, + attrs={"shape": lr_shape, + "value": self._learning_rate}) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + global_block = block.program.global_block() + # Create beta1 power accumulator tensor + beta_shape = [1] + self._beta1_pow_acc = global_block.create_var( + dtype="float32", shape=beta_shape, lod_level=0) + + # Initialize beta1 power accumulator + # FIXME: Fix when Initialization design has been implemented + # https://github.com/PaddlePaddle/Paddle/pull/4852 + global_block.append_op( + type="fill_constant", + outputs={"Out": self._beta1_pow_acc}, + attrs={"shape": beta_shape, + "value": self._beta1}) + + # Create accumulator tensors for first moment and infinity norm + for p in parameters: + self._add_accumulator(block, self._moment_acc_str, p, 'float32') + self._add_accumulator(block, self._inf_norm_acc_str, p, 'float32') + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) + inf_norm = self._get_accumulator(self._inf_norm_acc_str, + param_and_grad[0]) + # create the adamax optimize op + adamax_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._lr, + "Moment": moment, + "InfNorm": inf_norm, + "Beta1Pow": self._beta1_pow_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "MomentOut": moment, + "InfNormOut": inf_norm + }, + attrs={ + "beta1": self._beta1, + "beta2": self._beta2, + "epsilon": self._epsilon + }) + + return adamax_op + + def _finish_update(self, block): + """Update Beta1 Power accumulator + """ + assert isinstance(block, framework.Block) + global_block = block.program.global_block() + scale_beta1 = global_block.append_op( + type="scale", + inputs={"X": self._beta1_pow_acc}, + outputs={"Out": self._beta1_pow_acc}, + attrs={"scale": self._beta1}) + + return [scale_beta1] diff --git a/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py b/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py new file mode 100644 index 0000000000000000000000000000000000000000..065a9133dca25fac988f9493c1527e0d8f9821dc --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py @@ -0,0 +1,21 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFillConstantBatchSizeLikeOp(OpTest): + def setUp(self): + self.op_type = "fill_constant_batch_size_like" + self.inputs = {'Input': np.random.random((219, 232)).astype("float32")} + self.attrs = {'value': 3.5, 'shape': [-1, 132, 777]} + + out = np.random.random((219, 132, 777)).astype("float32") + out.fill(3.5) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_optimizer.py b/python/paddle/v2/framework/tests/test_optimizer.py index eb5d49bcbafe46ddb5ce96c8565417cf9bedc668..6dfd94e8c8c96d87037faa028a3d2a537a90c9c7 100644 --- a/python/paddle/v2/framework/tests/test_optimizer.py +++ b/python/paddle/v2/framework/tests/test_optimizer.py @@ -36,7 +36,7 @@ class TestMomentumOptimizer(unittest.TestCase): def get_velocity_str(self): return self._velocity_acc_str - def test_momentum_optimizer(self): + def test_vanilla_momentum_optimizer(self): program = framework.Program() block = program.global_block() mul_x = block.create_parameter( @@ -60,6 +60,42 @@ class TestMomentumOptimizer(unittest.TestCase): self.assertEqual(len(opts), 1) sgd_op = opts[0] self.assertEqual(sgd_op.type, "momentum") + self.assertFalse(sgd_op.attr('useNesterov')) + + # Check accumulators + accumulators = momentum_optimizer.get_accumulators() + self.assertEqual(len(accumulators), 1) + self.assertTrue(momentum_optimizer.get_velocity_str() in accumulators) + velocity_acc = accumulators[momentum_optimizer.get_velocity_str()] + self.assertEqual(len(velocity_acc), 1) + self.assertTrue(mul_x.name in velocity_acc) + + def test_nesterov_momentum_optimizer(self): + program = framework.Program() + block = program.global_block() + mul_x = block.create_parameter( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + block.append_op( + type="mul", + inputs={"X": mul_x, + "Y": mul_y}, + outputs={"Out": mul_out}, + attrs={"x_num_col_dims": 1}) + momentum_optimizer = self.MockMomentum( + learning_rate=0.01, momentum=0.2, use_nesterov=True) + params_grads = append_backward_ops(mul_out) + self.assertEqual(len(params_grads), 1) + self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) + opts = momentum_optimizer.create_optimization_pass(params_grads, + mul_out) + self.assertEqual(len(opts), 1) + sgd_op = opts[0] + self.assertEqual(sgd_op.type, "momentum") + self.assertTrue(sgd_op.attr('useNesterov')) # Check accumulators accumulators = momentum_optimizer.get_accumulators() @@ -160,5 +196,54 @@ class TestAdamOptimizer(unittest.TestCase): self.assertTrue(mul_x.name in moment2_acc) +class TestAdamaxOptimizer(unittest.TestCase): + class MockAdamax(optimizer.AdamaxOptimizer): + def get_accumulators(self): + return self._accumulators + + def get_moment_str(self): + return self._moment_acc_str + + def get_inf_norm_str(self): + return self._inf_norm_acc_str + + def test_adamax_optimizer(self): + program = framework.Program() + block = program.global_block() + mul_x = block.create_parameter( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + block.append_op( + type="mul", + inputs={"X": mul_x, + "Y": mul_y}, + outputs={"Out": mul_out}, + attrs={"x_num_col_dims": 1}) + adamax_optimizer = self.MockAdamax( + learning_rate=0.01, beta1=0.9, beta2=0.999) + params_grads = append_backward_ops(mul_out) + self.assertEqual(len(params_grads), 1) + self.assertEqual(len(adamax_optimizer.get_accumulators()), 0) + opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out) + self.assertEqual(len(opts), 2) + adam_op = opts[0] + self.assertEqual(adam_op.type, "adamax") + + # Check accumulators + accumulators = adamax_optimizer.get_accumulators() + self.assertEqual(len(accumulators), 2) + self.assertTrue(adamax_optimizer.get_moment_str() in accumulators) + self.assertTrue(adamax_optimizer.get_inf_norm_str() in accumulators) + moment_acc = accumulators[adamax_optimizer.get_moment_str()] + inf_norm_acc = accumulators[adamax_optimizer.get_inf_norm_str()] + self.assertEqual(len(moment_acc), 1) + self.assertEqual(len(inf_norm_acc), 1) + self.assertTrue(mul_x.name in moment_acc) + self.assertTrue(mul_x.name in inf_norm_acc) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_save_restore_op.py b/python/paddle/v2/framework/tests/test_save_restore_op.py deleted file mode 100644 index 3a36d03f62a7ad50f656e5c3fdb8c87548a120e8..0000000000000000000000000000000000000000 --- a/python/paddle/v2/framework/tests/test_save_restore_op.py +++ /dev/null @@ -1,71 +0,0 @@ -import paddle.v2.framework.core as core -import paddle.v2.framework.framework as framework -import paddle.v2.framework.executor as executor - -import numpy as np -import unittest -import os -import sys -import shutil - -FOLDER_PATH = "./tmp_test_dir" - - -class TestSaveRestoreOp(unittest.TestCase): - def test_save_restore_op(self): - tensor_1_val = np.random.rand(3, 9).astype("float32") - tensor_2_val = np.random.randint(0, 20, size=(4, 2)).astype("int32") - place = core.CPUPlace() - - program = framework.Program() - block = program.global_block() - v_a = block.create_var( - dtype="float32", shape=[3, 9], lod_level=0, name="tensor_1") - v_b = block.create_var( - dtype="int32", shape=[4, 2], lod_level=0, name="tensor_2") - - t_1 = core.LoDTensor() - t_1.set(tensor_1_val, place) - t_2 = core.LoDTensor() - t_2.set(tensor_2_val, place) - block.append_op( - type="save", - inputs={"X": [v_a, v_b]}, - attrs={"folderPath": FOLDER_PATH}) - block.append_op( - type="fill_constant", - outputs={"Out": [v_a]}, - attrs={"shape": [2, 2], - "value": 0.0}) - block.append_op( - type="fill_constant", - outputs={"Out": [v_b]}, - attrs={"shape": [2, 2], - "value": 0.0}) - block.append_op( - type="restore", - outputs={"Out": [v_a, v_b]}, - attrs={"folderPath": FOLDER_PATH}) - - if os.path.exists(FOLDER_PATH): - shutil.rmtree(FOLDER_PATH) - os.makedirs(FOLDER_PATH) - - exe = executor.Executor(place) - out = exe.run(program, - feed={"tensor_1": t_1, - "tensor_2": t_2}, - fetch_list=[v_a, v_b]) - - self.assertTrue(os.path.isdir(FOLDER_PATH)) - self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_1__")) - self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_2__")) - - self.assertTrue(np.array_equal(np.array(out[0]), tensor_1_val)) - self.assertTrue(np.array_equal(np.array(out[1]), tensor_2_val)) - - shutil.rmtree(FOLDER_PATH) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_squared_l2_norm_op.py b/python/paddle/v2/framework/tests/test_squared_l2_norm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..5a52c6a66c781672a483324083b97a3c5894f508 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_squared_l2_norm_op.py @@ -0,0 +1,29 @@ +import numpy as np +import unittest +from numpy import linalg as LA +from op_test import OpTest + + +class TestL2LossOp(OpTest): + """Test squared_l2_norm + """ + + def setUp(self): + self.op_type = "squared_l2_norm" + self.max_relative_error = 0.05 + + X = np.random.uniform(-1, 1, (13, 19)).astype("float32") + X[np.abs(X) < self.max_relative_error] = 0.1 + self.inputs = {'X': X} + self.outputs = {'Out': np.square(LA.norm(X))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], 'Out', max_relative_error=self.max_relative_error) + + +if __name__ == "__main__": + unittest.main()