& no_grad_vars);
-```
-
-The implementation behind it can be divided into two parts: **Backward Operator Creating** and **Backward Network Building**.
-
-### Backward Operator Registry
-
-A backward network is built up with several backward operators. A backward operator takes the forward operator's inputs, outputs, and output gradients, and from them computes the gradients of the forward operator's inputs.
-
-|                        | forward operator | backward operator                |
-| ---------------------- | ---------------- | -------------------------------- |
-| **Operator::inputs_**  | Inputs           | Inputs, Outputs, OutputGradients |
-| **Operator::outputs_** | Outputs          | InputGradients                   |
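-
-For example, for a `mul` operator whose forward pass computes `Out = X * W` (the variable names here are illustrative; gradient variables follow the `name@GRAD` convention used in the figures below), the mapping is:
-
-```cpp
-// forward:  mul       inputs_: {X, W}                  outputs_: {Out}
-// backward: mul_grad  inputs_: {X, W, Out, Out@GRAD}   outputs_: {X@GRAD, W@GRAD}
-```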
-
-In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded in a global hash map (`OpInfoMap`). To follow the philosophy of a minimum core and to make operators pluggable, a registry mechanism is introduced.
-
-For example, given `mul_op`, we can register its information and its corresponding backward operator with the following macro:
-
-```cpp
-REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
-```
-
-`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class, respectively.
-
-`mul_grad` is the type of the backward operator, and `MulOpGrad` is its class name.
-
-### Backward Operator Creating
-
-Given a certain forward operator, we can get its corresponding backward operator by calling `BuildGradOp`, whose signature is:
-
-```cpp
-OperatorBase* BuildGradOp(const OperatorBase* fwd_op);
-```
-
-The function `BuildGradOp` sequentially executes the following steps:
-
-1. Get the `type_` of the given forward operator, and then get the corresponding backward operator's type by looking it up in the `OpInfoMap`.
-
-2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` into the map `inputs`, except for those entries that are not necessary for gradient computing.
-
-3. Add the forward inputs' gradient variables into the map `outputs`, and add the forward outputs' gradient variables into the map `inputs`.
-
-4. Build the backward operator with `inputs`, `outputs`, and the forward operator's attributes. A sketch of these four steps follows this list.
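-
-The following is a minimal, self-contained sketch of these four steps. It uses simplified stand-in types (`SimpleOp` and a plain name map) rather than the real `OperatorBase`/`OpInfoMap` classes, and it approximates the registry lookup with a `_grad` suffix:
-
-```cpp
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-using VarMap = std::unordered_map<std::string, std::vector<std::string>>;
-
-struct SimpleOp {
-  std::string type_;
-  VarMap inputs_, outputs_;
-};
-
-std::string GradVarName(const std::string& n) { return n + "@GRAD"; }
-
-SimpleOp BuildGradOp(const SimpleOp& fwd) {
-  SimpleOp bwd;
-  // 1. get the backward operator's type (a real lookup goes through OpInfoMap)
-  bwd.type_ = fwd.type_ + "_grad";
-  // 2. copy forward inputs and outputs into the backward `inputs` map
-  //    (pruning of entries not needed for gradients is omitted here)
-  bwd.inputs_ = fwd.inputs_;
-  for (const auto& kv : fwd.outputs_) bwd.inputs_[kv.first] = kv.second;
-  // 3. forward outputs' gradients become backward inputs; forward inputs'
-  //    gradients become backward outputs
-  for (const auto& kv : fwd.outputs_)
-    for (const auto& name : kv.second)
-      bwd.inputs_[GradVarName(kv.first)].push_back(GradVarName(name));
-  for (const auto& kv : fwd.inputs_)
-    for (const auto& name : kv.second)
-      bwd.outputs_[GradVarName(kv.first)].push_back(GradVarName(name));
-  // 4. a real implementation now constructs the operator from the type,
-  //    the two maps, and the forward operator's attributes
-  return bwd;
-}
-```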
-
-### Backward Network Building
-
-A backward network is a series of backward operators. The main idea of building a backward network is to create the backward operators in reverse order and append them one by one. There are some corner cases that need special processing.
-
-1. Op
-
- When the input forward network is an Op, return its gradient operator immediately. If all of its outputs are in the no-gradient set, return a special `NOP` instead.
-
-2. NetOp
-
- In our design, the network itself is also a kind of operator (**NetOp**). So the operators contained in a big network may themselves be small networks. When the input forward network is a NetOp, we need to call the backward function of every sub NetOp/operator recursively. During this process, we need to collect the `OutputGradients` names according to the forward NetOp.
-
-3. RnnOp
-
- RnnOp is a nested stepnet operator. The backward module needs to recursively call `Backward` for every stepnet.
-
-4. Sharing Variables
-
- As illustrated in figure 1 and figure 2, two operators share the same variable name **W@GRAD**, so the later write will overwrite their shared input variable.
-
-
-
-
- Figure 1. Sharing variables in operators.
-
-
-
- A variable shared between operators, or the same input variable used by multiple operators, can lead to duplicated gradient variables. As illustrated in figure 2, we need to rename the gradient variables recursively and add a generic `add` operator to prevent overwriting; see the sketch after this list.
-
-
-
-
- Figure 2. Replace sharing variable's gradient with `Add` operator.
-
-
-
- Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to each to represent its position in the clockwise direction.
-
-5. Part of the Gradient is Zero
-
- In the whole graph, there are cases where one operator's gradient is not needed, but the gradient of its input is a dependency link of another operator. In such a position we need to fill in a gradient matrix of the same shape; in our implementation, we insert a special `fillZeroLike` operator.
-
-
-Following the rules above, we collect the sub-graph's `OutputGradients`/`InputGradients` as the NetOp's own and return it.
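-
-Below is a minimal sketch of the duplicate-gradient handling from case 4, reusing the `SimpleOp`/`VarMap` stand-ins from the earlier sketch; the helper name and the `@<index>` suffix form are illustrative, not the exact implementation:
-
-```cpp
-// Rename duplicated writes of `grad_name` (e.g. W@GRAD -> W@GRAD@0,
-// W@GRAD@1, ...) and append one generic `add` op that sums them back.
-void DedupSharedGradient(std::vector<SimpleOp>* net,
-                         const std::string& grad_name) {
-  // first pass: count how many operators write grad_name
-  int writers = 0;
-  for (auto& op : *net)
-    for (auto& kv : op.outputs_)
-      for (auto& name : kv.second)
-        if (name == grad_name) ++writers;
-  if (writers <= 1) return;  // no sharing, nothing to do
-
-  // second pass: rename each write with an integer suffix
-  std::vector<std::string> renamed;
-  for (auto& op : *net)
-    for (auto& kv : op.outputs_)
-      for (auto& name : kv.second)
-        if (name == grad_name) {
-          name = grad_name + "@" + std::to_string(renamed.size());
-          renamed.push_back(name);
-        }
-
-  // one generic add operator restores grad_name = sum(renamed copies)
-  SimpleOp add_op;
-  add_op.type_ = "add";
-  add_op.inputs_["X"] = renamed;
-  add_op.outputs_["Out"] = {grad_name};
-  net->push_back(add_op);
-}
-```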
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 0957646b5642cd9afce5d88b2c638679cb01f198..692406b1c37d0c02714eafb5cf9a28329ed873bc 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/backward.h"
diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h
index 7d7a444cf02ba0da88178a34c98f5ef7a95c3852..4a8669c3a41fceaad26878a79eabfd0affce86fd 100644
--- a/paddle/framework/data_layout.h
+++ b/paddle/framework/data_layout.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/platform/enforce.h"
#include
#include "paddle/platform/enforce.h"
@@ -20,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
-enum DataLayout {
+enum class DataLayout {
kNHWC = 0,
kNCHW = 1,
kAnyLayout = 2,
@@ -38,11 +39,11 @@ inline DataLayout StringToDataLayout(const std::string& str) {
inline std::string DataLayoutToString(const DataLayout& data_layout) {
switch (data_layout) {
- case kNHWC:
+ case DataLayout::kNHWC:
return "NHWC";
- case kNCHW:
+ case DataLayout::kNCHW:
return "NCHW";
- case kAnyLayout:
+ case DataLayout::kAnyLayout:
return "ANY_LAYOUT";
default:
PADDLE_THROW("unknown DataLayou %d", data_layout);
diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc
new file mode 100644
index 0000000000000000000000000000000000000000..376268888e70b0a70060c81384f79f8bf5d6dcc5
--- /dev/null
+++ b/paddle/framework/data_transform.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/data_transform.h"
+#include "paddle/framework/lod_tensor.h"
+
+namespace paddle {
+namespace framework {
+
+DataTransformFnMap& DataTransformFnMap::Instance() {
+ static DataTransformFnMap data_transform_map;
+ return data_transform_map;
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd6d301c12e0611c5b01c3ff58869dbeb96b268e
--- /dev/null
+++ b/paddle/framework/data_transform.h
@@ -0,0 +1,108 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <functional>
+#include <unordered_map>
+#include <utility>
+
+#include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/framework/variable.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/macros.h"
+
+namespace paddle {
+namespace framework {
+
+using DataTransformFn = std::function<void(const platform::DeviceContext* ctx,
+                                           const Variable& in, Variable* out)>;
+using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
+
+struct KernelTypePairHash {
+ static void HashCombine(const OpKernelType& t, std::size_t* seed) {
+ OpKernelType::Hash kernel_type_hasher;
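+ // boost-style hash_combine: 0x9e3779b9 is the 32-bit golden-ratio
+ // constant, and the shifts spread each hash across the accumulated seed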
+ (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
+ }
+
+ size_t operator()(const KernelTypePair& kernel_pair) const {
+ std::size_t seed = 0;
+ HashCombine(kernel_pair.first, &seed);
+ HashCombine(kernel_pair.second, &seed);
+ return seed;
+ }
+};
+
+using DataTransformMap =
+ std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
+
+class DataTransformFnMap {
+ public:
+ static DataTransformFnMap& Instance();
+
+ bool Has(const KernelTypePair& key_pair) const {
+ return map_.find(key_pair) != map_.end();
+ }
+
+ void Insert(const OpKernelType& left, const OpKernelType& right,
+ const DataTransformFn& data_tranform_fn) {
+ Insert(std::make_pair(left, right), data_tranform_fn);
+ }
+
+ void Insert(const KernelTypePair& kernel_type_pair,
+ const DataTransformFn& data_tranform_fn) {
+ PADDLE_ENFORCE(!Has(kernel_type_pair),
+ "KernelTypePair %s has been registered", "");
+ map_.insert({kernel_type_pair, data_tranform_fn});
+ }
+
+ const DataTransformFn& Get(const KernelTypePair& key_pair) const {
+ auto data_transformer = GetNullable(key_pair);
+ PADDLE_ENFORCE_NOT_NULL(data_transformer,
+ "DataTransformFn should not be NULL");
+ return *data_transformer;
+ }
+
+ const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
+ auto it = map_.find(key_pair);
+ if (it == map_.end()) {
+ return nullptr;
+ } else {
+ return &(it->second);
+ }
+ }
+
+ const DataTransformMap& Map() const { return map_; }
+
+ private:
+ DataTransformFnMap() = default;
+ DataTransformMap map_;
+ DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
+};
+
+// generate unique name with __LINE__
+// refs https://stackoverflow.com/questions/1597007
+#define TOKENPASTE(x, y) x##y
+#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
+#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
+ static int TOKENPASTE2(fn_, __LINE__)() { \
+ ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
+ return 0; \
+ } \
+ static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
+ TOKENPASTE2(fn_, __LINE__)()
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5f05e881fa16eead1dc690f85375706bf3cd3e6d
--- /dev/null
+++ b/paddle/framework/data_transform_test.cc
@@ -0,0 +1,99 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <array>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "paddle/framework/data_transform.h"
+
+namespace paddle {
+namespace framework {
+using namespace platform;
+
+/**
+ * @brief cross validation of different kernel type transform
+ * We use a four-bit map to represent the different combinations.
+ * If a field has multiple possible values, only two of them are chosen.
+ * For DataType, only FP32 (float) and FP64 (double) are tested.
+ * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
+ * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
+ */
+
+std::array<proto::DataType, 2> kDataType = {
+ {proto::DataType::FP32, proto::DataType::FP64}};
+
+std::array<platform::Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
+
+std::array<DataLayout, 2> kDataLayout = {
+ {DataLayout::kNHWC, DataLayout::kNCHW}};
+
+std::array<LibraryType, 2> kLibraryType = {
+ {LibraryType::kPlain, LibraryType::kMKLDNN}};
+
+OpKernelType GenFromBit(const std::vector<int> bits) {
+ return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
+ kLibraryType[bits[3]]);
+}
+
+int test_value = 0;
+
+auto kernel0 = GenFromBit({0, 0, 0, 0});
+auto kernel1 = GenFromBit({0, 0, 0, 1});
+auto kernel2 = GenFromBit({0, 0, 1, 0});
+auto kernel3 = GenFromBit({0, 0, 1, 1});
+
+void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value++;
+}
+
+void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value--;
+}
+
+void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in,
+ Variable* out) {
+ test_value += 2;
+}
+
+} // namespace framework
+} // namespace paddle
+
+namespace frw = paddle::framework;
+
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t);
+
+TEST(DataTransform, Register) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+
+ auto& instance = DataTransformFnMap::Instance();
+ ASSERT_EQ(instance.Map().size(), 3UL);
+ DeviceContext* ctx = nullptr;
+ paddle::framework::Variable in;
+ paddle::framework::Variable out;
+
+ instance.Get(std::make_pair(frw::kernel0, frw::kernel1))(ctx, in, &out);
+ ASSERT_EQ(test_value, 1);
+
+ instance.Get(std::make_pair(frw::kernel1, frw::kernel2))(ctx, in, &out);
+ ASSERT_EQ(test_value, 0);
+
+ instance.Get(std::make_pair(frw::kernel0, frw::kernel2))(ctx, in, &out);
+ ASSERT_EQ(test_value, 2);
+}
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index e94ee2ed52bc40f52caf783f971dd0b560534e08..6a372ac32e48131eed28e2d42125feb5b92a11c7 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc
index bd5ea09d7da700479aa387283d7bde77c64c1293..bc259d1f603fb34ac8546c388669d8c5c1250bd1 100644
--- a/paddle/framework/ddim_test.cc
+++ b/paddle/framework/ddim_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 7f5151c41d6046f21f7a9707e45de85ec50219ad..6d50e820b2b625f932768d2ca671d999071f1ca6 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 997773c1689efad4ce5a86c09ce58bd3a40185e0..bf1f0471ccbfccf13cb6f74c8088da7acd68ec0b 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
-#include
-#include
-#include
#include
-#include
+#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
-#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
-#include "paddle/framework/scope.h"
+
+DEFINE_bool(check_nan_inf, false,
+ "Checking whether operator produce NAN/INF or not. It will be "
+ "extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
+static void CheckTensorNANOrInf(const std::string& name,
+ const framework::Tensor& tensor) {
+ if (tensor.memory_size() == 0) {
+ return;
+ }
+ if (tensor.type().hash_code() != typeid(float).hash_code() &&
+ tensor.type().hash_code() != typeid(double).hash_code()) {
+ return;
+ }
+ PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
+ PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
+}
+
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
@@ -101,8 +113,17 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, place_);
+ if (FLAGS_check_nan_inf) {
+ for (auto& vname : op->OutputVars(true)) {
+ auto* var = local_scope->FindVar(vname);
+ if (var == nullptr) continue;
+ if (var->IsType<framework::LoDTensor>()) {
+ CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+ }
+ }
+ }
}
- if (create_local_scope) {
+ if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
}
}
diff --git a/paddle/framework/feed_fetch_type.h b/paddle/framework/feed_fetch_type.h
index bc4ae440fc708f696c18bb9d5ab3ba7dd59e21ab..9bc4a90c44828ecb7458d524f59609f01848cc5c 100644
--- a/paddle/framework/feed_fetch_type.h
+++ b/paddle/framework/feed_fetch_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h
index cf411fa710103350713342b43946697a8dd2aa46..2de5242831835b47893a5825e5532500ad5ec3f9 100644
--- a/paddle/framework/grad_op_desc_maker.h
+++ b/paddle/framework/grad_op_desc_maker.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
index 3ff2da344627ed3ada3955ec5ee2c886402554f4..682cff168d4d31e0565fc987604f97a671566fbd 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
- platform::DeviceContextPool::Create(places);
+ platform::DeviceContextPool::Init(places);
return true;
}
diff --git a/paddle/framework/init.h b/paddle/framework/init.h
index 1715cd81e6647158e269e39d4d91fbe065cd0008..33907f9eb00fb3469b53dcf8151557cc7a2d3791 100644
--- a/paddle/framework/init.h
+++ b/paddle/framework/init.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc
index cb1ba7ce8fdbf740846689356c94c4f2fabb95cb..f0788051d4855a175d2d7ea1f1a0805c776c462b 100644
--- a/paddle/framework/init_test.cc
+++ b/paddle/framework/init_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h
index aa66cf00f3be14c4ac9496326190428688fc3496..7707799cae8c4edc304cd81725270a85f01fd28d 100644
--- a/paddle/framework/library_type.h
+++ b/paddle/framework/library_type.h
@@ -20,18 +20,41 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
-enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
+enum class LibraryType {
+ kPlain = 0,
+ kMKLDNN = 1,
+ kCUDNN = 2,
+};
inline std::string LibraryTypeToString(const LibraryType& library_type) {
switch (library_type) {
- case kPlain:
+ case LibraryType::kPlain:
return "PLAIN";
- case kMKLDNN:
+ case LibraryType::kMKLDNN:
return "MKLDNN";
- case kCUDNN:
+ case LibraryType::kCUDNN:
return "CUDNN";
default:
- PADDLE_THROW("unknown LibraryType %d", library_type);
+ PADDLE_THROW("unknown LibraryType %d", static_cast(library_type));
+ }
+}
+
+inline LibraryType StringToLibraryType(const char* ctype) {
+ std::string s(ctype);
+ if (s == std::string("PLAIN")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("MKLDNN")) {
+ return LibraryType::kMKLDNN;
+ } else if (s == std::string("CUDNN")) {
+ return LibraryType::kCUDNN;
+ // To be compatible with register macro.
+ // CPU, CUDA, PLAIN are same library type.
+ } else if (s == std::string("CPU")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("CUDA")) {
+ return LibraryType::kPlain;
+ } else {
+ PADDLE_THROW("Unknown LibraryType %s", s.c_str());
}
}
diff --git a/paddle/framework/lod_rank_table.cc b/paddle/framework/lod_rank_table.cc
index 17d524c09276fc0eb166925bd79bc0bdfcead195..704bce2a0eb60b974efd41a4edda0af2933da825 100644
--- a/paddle/framework/lod_rank_table.cc
+++ b/paddle/framework/lod_rank_table.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
diff --git a/paddle/framework/lod_rank_table.h b/paddle/framework/lod_rank_table.h
index d3007d3d7379a59b32465cbd55780c6268e0e4a8..df188709e91871ded0258fa5703ee16a5664f057 100644
--- a/paddle/framework/lod_rank_table.h
+++ b/paddle/framework/lod_rank_table.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index d766d3c4163b6b7c6fdc772acb4b7e7b315f8783..7b6dc09bdb5535488c8c4dbc71c9cd6a7998bd0b 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/data_type.h"
@@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
- // TODO(typhoonzero): serialize to ostream
- { // the 1st field, uint32_t version
+ { // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
- { // the 2nd field, tensor description
- // int32_t size
- // void* protobuf message
- proto::TensorDesc desc;
- desc.set_data_type(framework::ToDataType(tensor.type()));
- auto dims = framework::vectorize(tensor.dims());
- auto *pb_dims = desc.mutable_dims();
- pb_dims->Resize(static_cast<int>(dims.size()), 0);
- std::copy(dims.begin(), dims.end(), pb_dims->begin());
- int32_t size = desc.ByteSize();
- os.write(reinterpret_cast<const char *>(&size), sizeof(size));
- auto out = desc.SerializeAsString();
- os.write(out.data(), size);
- }
- { // the 3rd field, tensor data
- uint64_t size = tensor.memory_size();
- auto *data_ptr = tensor.data();
- PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
- "Index overflow when writing tensor");
- if (platform::is_gpu_place(tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
- constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
- std::unique_ptr<char[]> buf(new char[kBufSize]);
- auto &gpu_dev_ctx =
- static_cast<const platform::CUDADeviceContext &>(dev_ctx);
- platform::CPUPlace cpu;
- uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
- while (size != 0) {
- size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
- memory::Copy(cpu, buf.get(),
- boost::get<platform::CUDAPlace>(tensor.place()),
- reinterpret_cast<const void *>(data), size_to_write,
- gpu_dev_ctx.stream());
- gpu_dev_ctx.Wait();
- os.write(buf.get(), size_to_write);
- data += size_to_write;
- size -= size_to_write;
- }
-#else
- PADDLE_THROW("Unexpected branch");
-#endif
- } else {
- os.write(static_cast<const char *>(data_ptr),
- static_cast<std::streamsize>(size));
- }
- }
- { // the 4th field, lod information
- // uint64_t lod_level
- // uint64_t lod_level_1 size in byte.
- // int* lod_level_1 data
- // ...
+ {
+ // the 2nd field, LoD information
+ // uint64_t lod_level
+ // uint64_t lod_level_1 size in byte.
+ // int* lod_level_1 data
+ // ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
@@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
static_cast<std::streamsize>(size));
}
}
+ // the 3rd field, Tensor
+ SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
- uint32_t version;
- is.read(reinterpret_cast(&version), sizeof(version));
- PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
- proto::TensorDesc desc;
- { // int32_t size
- // proto buffer
- int32_t size;
- is.read(reinterpret_cast<char *>(&size), sizeof(size));
- std::unique_ptr<char[]> buf(new char[size]);
- is.read(reinterpret_cast<char *>(buf.get()), size);
- PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
- "Cannot parse tensor desc");
- }
- { // read tensor
- std::vector dims;
- dims.reserve(static_cast<size_t>(desc.dims().size()));
- std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
- tensor->Resize(framework::make_ddim(dims));
-
- void *buf;
- platform::Place cpu = platform::CPUPlace();
- switch (desc.data_type()) {
- case proto::FP32:
- buf = tensor->mutable_data<float>(cpu);
- break;
- case proto::FP64:
- buf = tensor->mutable_data<double>(cpu);
- break;
- case proto::INT32:
- buf = tensor->mutable_data<int>(cpu);
- break;
- case proto::INT64:
- buf = tensor->mutable_data<int64_t>(cpu);
- break;
- default:
- PADDLE_THROW("DataType %d not supported", desc.data_type());
- }
- is.read(static_cast<char *>(buf), tensor->memory_size());
- }
- { // read lod
+ {
+ // the 1st field, uint32_t version for LoDTensor
+ uint32_t version;
+ is.read(reinterpret_cast<char *>(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+ // the 2nd field, LoD information
uint64_t lod_level;
+ is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
@@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
lod[i] = tmp;
}
}
+ // the 3rd field, Tensor
+ DeserializeFromStream(is, static_cast<Tensor *>(tensor));
}
} // namespace framework
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 0923c52a0ad2fe10cea760df20c99021984ad39d..147db3ab0877662d9e47ae7ee6df05638b5fcbd1 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/lod_tensor_array.h b/paddle/framework/lod_tensor_array.h
index 13f0608d24be97d8bba149b74f1a4deb57deeb48..4a8e7f4fa540b1c2f19a6e3ec236a0dd5c0daf0b 100644
--- a/paddle/framework/lod_tensor_array.h
+++ b/paddle/framework/lod_tensor_array.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 02d84b68233f2fdfc66e1df2fc7ce20307cadd94..0747c8db531d6ae443d76591b945cce0c9bbea2b 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
}
+TEST_F(LoDTensorTester, SerializeAndDeserialize) {
+ LoDTensor dst_tensor;
+ platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
+ std::ostringstream oss;
+ SerializeToStream(oss, lod_tensor_, cpu_ctx);
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+ float* dst_ptr = dst_tensor.mutable_data<float>(platform::CPUPlace());
+ for (int i = 0; i < kLodTensorSize; ++i) {
+ EXPECT_EQ(dst_ptr[i], i);
+ }
+ EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod());
+}
+
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index b361e64438251c1df827667fb825e7f5909fb09e..781bbb4c19f1c610df485c3061ca8b510e727019 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -88,6 +88,14 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
+void OpDesc::CopyFrom(const OpDesc &op_desc) {
+ desc_.set_type(op_desc.Type());
+ inputs_ = op_desc.inputs_;
+ outputs_ = op_desc.outputs_;
+ attrs_ = op_desc.attrs_;
+ need_update_ = true;
+}
+
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 93d4a88f3c390551ab41e42ec2f6f30f52e306db..4cf784a0d0d319d09caa27b4e2b589bd7ac4f324 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -35,6 +35,8 @@ class OpDesc {
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
+ void CopyFrom(const OpDesc &op_desc);
+
proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); }
diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc
index 81ba29797c5f478e5d6a91236f3e8de1e6b43e49..b520108109bb2f72b80f83559fa065a5ca58e9e1 100644
--- a/paddle/framework/op_info.cc
+++ b/paddle/framework/op_info.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/op_info.h"
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 7772d6e745c2207024863d3dd5cbef052358272e..d9b89f9cac9611fcecb18bef87940632df1e2234 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h
index e9c45b958cd0a65bca62099324e951f298d9ecb1..b06002096fb109da806809f7b908d9768cf095ba 100644
--- a/paddle/framework/op_kernel_type.h
+++ b/paddle/framework/op_kernel_type.h
@@ -40,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
+
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
@@ -67,6 +68,8 @@ struct OpKernelType {
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
+
+ bool operator!=(const OpKernelType& o) const { return !(*this == o); }
};
inline std::ostream& operator<<(std::ostream& os,
@@ -77,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}
+inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
+ std::ostringstream stream;
+ stream << kernel_key;
+ return stream.str();
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
index 8753d7cc378662ce116e447dc6a340a07e5dd2ca..649afeee8a846b0579545f2edff77e9dbe3b4dd8 100644
--- a/paddle/framework/op_kernel_type_test.cc
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
- std::ostringstream stream;
- stream << op_kernel_type;
ASSERT_EQ(
- stream.str(),
+ paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
}
@@ -48,4 +46,4 @@ TEST(OpKernelType, Hash) {
OpKernelType::Hash hasher;
ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
-}
\ No newline at end of file
+}
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 9bb2a3b5c2931d03152cc3262c0ad8da17b8aacb..bdaa25918155caca4b64b0ed60aa3f6be03eb12f 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
- void operator()(const char* op_type) const {
+ void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
- OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
+ OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
+ DataLayout::kAnyLayout, StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
- func(op_type);
+ func(op_type, library_type);
}
};
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelTypes...> {
- void operator()(const char* op_type) const {}
+ void operator()(const char* op_type, const char* library_type) const {}
};
// User can register many kernel in one place. The data type could be different.
template <typename PlaceType, typename... KernelTypes>
class OpKernelRegistrar : public Registrar {
public:
- explicit OpKernelRegistrar(const char* op_type) {
+ explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelTypes...> func;
- func(op_type);
+ func(op_type, library_type);
}
};
@@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar {
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
- __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
+ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
+ #DEVICE_TYPE); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 4cdf6e0865e0922b72bd184172f85a9c705dcd00..cef530c6e639f6e2188869fa57d114ec6b885aa8 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -1,3 +1,17 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
#include "paddle/framework/op_registry.h"
#include
@@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar reg("cos");
}
+
+namespace paddle {
+namespace framework {
+
+class OpKernelTestMaker : public OpProtoAndCheckerMaker {
+ public:
+ OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddComment("NoGradOp, same input output. no Grad");
+ }
+};
+
+class OpWithKernelTest : public OperatorWithKernel {
+ public:
+ using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(InferShapeContext* ctx) const override {}
+
+ framework::OpKernelType GetActualKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+ }
+};
+
+template <typename DeviceContext, typename T>
+class OpKernelTest : public paddle::framework::OpKernel<T> {
+ public:
+ void Compute(const paddle::framework::ExecutionContext& ctx) const {}
+};
+
+} // namespace framework
+} // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
+ paddle::framework::OpWithKernelTest,
+ paddle::framework::OpKernelTestMaker);
+REGISTER_OP_CPU_KERNEL(
+ op_with_kernel,
+ paddle::framework::OpKernelTest<paddle::platform::CPUDeviceContext, float>);
+
+REGISTER_OP_CUDA_KERNEL(op_with_kernel,
+ paddle::framework::OpKernelTest<
+ paddle::platform::CUDADeviceContext, float>);
+
+TEST(OperatorRegistrar, CPU) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CPUPlace cpu_place;
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cpu_place);
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(OperatorRegistrar, CUDA) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CUDAPlace cuda_place(0);
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cuda_place);
+}
+#endif
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 66840a2e037e7ca0fd1eacc64421865b170b47f8..a3ce96c409675ad52a811586c736ca22b5c7e99e 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@@ -383,12 +384,30 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
+const platform::DeviceContext* GetDeviceContext(
+ framework::KernelTypePair& kernel_pair) {
+ auto& actual_kernel_key = kernel_pair.first;
+ auto& expected_kernel_key = kernel_pair.second;
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+ if (platform::is_gpu_place(actual_kernel_key.place_) &&
+ platform::is_cpu_place(expected_kernel_key.place_)) {
+ return pool.Get(actual_kernel_key.place_);
+ } else if (platform::is_cpu_place(actual_kernel_key.place_) &&
+ platform::is_gpu_place(expected_kernel_key.place_)) {
+ return pool.Get(expected_kernel_key.place_);
+ } else {
+ PADDLE_THROW(
+ "Currently, model parallelism is only supported between CPU and CUDA");
+ }
+}
+
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
- platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
- auto dev_ctx = pool.Borrow(place);
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
@@ -411,6 +430,47 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key);
}
+ if (actual_kernel_key == expected_kernel_key) {
+ PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+ "Currently, model parallelism is only supported between "
+ "CPU and other devices. For example, multi-GPU model "
+ "parallelism will failed.");
+ } else {
+ auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
+ const DataTransformFn* trans_fun =
+ DataTransformFnMap::Instance().GetNullable(kernel_pair);
+ if (trans_fun) {
+ auto input_vars = this->InputVars();
+ // TODO(qijun) filter the input vars that do not need to be transformed
+
+ // filter vars that has been transformed
+ std::vector<std::string> need_trans;
+ for (auto var_name : input_vars) {
+ auto var_name_trans =
+ var_name + framework::KernelTypeToString(expected_kernel_key);
+ if (!scope.FindVar(var_name_trans)) {
+ const_cast<Scope&>(scope).Var(var_name_trans);
+ need_trans.push_back(var_name);
+ }
+ }
+
+ if (!need_trans.empty()) {
+ auto trans_dev_ctx = GetDeviceContext(kernel_pair);
+
+ // Wait for transform starting
+ dev_ctx->Wait();
+
+ for (auto var_name : need_trans) {
+ (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
+ scope.FindVar(var_name + framework::KernelTypeToString(
+ expected_kernel_key)));
+ }
+ // Wait for data transform finishing
+ trans_dev_ctx->Wait();
+ }
+ }
+ }
+
kernel_iter->second->Compute(ctx);
}
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 55eed57e6665515aec36dab4be8028dc75dbf7f3..d0a9b643d565d6651fd7ec0b515f088362852ba3 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -89,6 +89,9 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
+ // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
+ virtual void Stop() {}
+
virtual bool IsNetOp() const { return false; }
virtual bool SupportGPU() const { return false; }
diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc
index a49886f7ea56bc57459202dba65e3f76a902cd70..59947c9f2189348226b7ff6c2b9315196bbf55fa 100644
--- a/paddle/framework/program_desc_test.cc
+++ b/paddle/framework/program_desc_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "gtest/gtest.h"
diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc
index bdd57659432ea4f9bdd05425a802110b0c202fb8..d76c5abca94cb87220ce73537a8657c3ec695f4d 100644
--- a/paddle/framework/prune_test.cc
+++ b/paddle/framework/prune_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/prune.h"
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 656736e23846c8de50553a608c54a0bdd3272cb1..0c01d605bcd95f5796fba1e5a3351a2640b2898a 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear();
}
-std::vector<std::string> Scope::GetAllNames(bool recursive) const {
- std::vector<std::string> known_vars(vars_.size());
-
- if (recursive) {
- for (auto& kid : kids_) {
- auto kid_vars = kid->GetAllNames();
- for (auto& p : kid_vars) {
- known_vars.emplace_back(p);
- }
- }
- }
+std::vector<std::string> Scope::LocalVarNames() const {
+ std::vector<std::string> known_vars;
+ known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}
diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index 56e815db54b6385c4e4d87f456ed5d59113ca77b..10143326dfa201894c777b3e5e226d5ca5015eda 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -66,7 +66,7 @@ class Scope {
void DropKids();
// enumerate all the variables the current scope contains.
- std::vector<std::string> GetAllNames(bool recursive = false) const;
+ std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name
void Rename(const std::string& origin_name,
diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc
index f738d5ba9ecda57ea25bb5f84057d1d0106eef66..0f5b86061dbdebde08badca7f984f4a2c8d7bc79 100644
--- a/paddle/framework/scope_test.cc
+++ b/paddle/framework/scope_test.cc
@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v));
- std::vector<std::string> ans = s.GetAllNames();
+ std::vector<std::string> ans = s.LocalVarNames();
std::string str;
for (auto& var : ans) {
str += var;
diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc
index c74459c9dd7006a24615b1d6df041583088fb25c..82adfa7123a3cf40d929021602c45fe7d2e34ffa 100644
--- a/paddle/framework/selected_rows.cc
+++ b/paddle/framework/selected_rows.cc
@@ -12,5 +12,58 @@ limitations under the License. */
#include "paddle/framework/selected_rows.h"
namespace paddle {
-namespace framework {} // namespace framework
+namespace framework {
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx) {
+ { // the 1st field, uint32_t version
+ constexpr uint32_t version = 0;
+ os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+ }
+ {
+ // the 2nd field, rows information
+ auto& rows = selected_rows.rows();
+ uint64_t size = rows.size();
+ os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+ for (uint64_t i = 0; i < size; ++i) {
+ os.write(reinterpret_cast<const char*>(&rows[i]), sizeof(rows[i]));
+ }
+ }
+ {
+ // the 3rd field, the height of SelectedRows
+ int64_t height = selected_rows.height();
+ os.write(reinterpret_cast<const char*>(&height), sizeof(height));
+ }
+ // the 4th field, Tensor data
+ SerializeToStream(os, selected_rows.value(), dev_ctx);
+}
+
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
+ auto* tensor = selected_rows->mutable_value();
+ {
+ // the 1st field, uint32_t version for SelectedRows
+ uint32_t version;
+ is.read(reinterpret_cast<char*>(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+ // the 2nd field, rows information
+ uint64_t size;
+ is.read(reinterpret_cast<char*>(&size), sizeof(size));
+ auto& rows = *selected_rows->mutable_rows();
+ rows.resize(size);
+ for (uint64_t i = 0; i < size; ++i) {
+ is.read(reinterpret_cast<char*>(&rows[i]), sizeof(int64_t));
+ }
+ }
+ {
+ // the 3rd field, the height of the SelectedRows
+ int64_t height;
+ is.read(reinterpret_cast<char*>(&height), sizeof(int64_t));
+ selected_rows->set_height(height);
+ }
+ // the 4th field, the tensor which contains the data
+ DeserializeFromStream(is, tensor);
+}
+
+} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 0332b91323e3a4b4b80e02302ad3dcafe0986cde..699e392688e9889f050592172f8bfc45f855d0b1 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -59,5 +59,14 @@ class SelectedRows {
int64_t height_;
};
+/*
+ * Serialize/Deserialize SelectedRows to std::ostream.
+ * You can pass an ofstream or an ostringstream to serialize to a file
+ * or to an in-memory string. A GPU tensor will be copied to CPU first.
+ */
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx);
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
+
} // namespace framework
} // namespace paddle
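
Since `SerializeToStream` accepts any `std::ostream`, persisting a `SelectedRows` to disk is just a matter of passing an `ofstream`. A minimal round trip might look like the sketch below; the file name and the assumption that the data lives on CPU are illustrative, not part of the API.

```cpp
#include <fstream>

#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"

void SaveAndLoad(const paddle::framework::SelectedRows& rows) {
  paddle::platform::CPUDeviceContext cpu_ctx{paddle::platform::CPUPlace()};
  {
    std::ofstream fout("rows.bin", std::ios::binary);
    paddle::framework::SerializeToStream(fout, rows, cpu_ctx);
  }  // fout flushed and closed here
  paddle::framework::SelectedRows loaded;
  std::ifstream fin("rows.bin", std::ios::binary);
  paddle::framework::DeserializeFromStream(fin, &loaded);
}
```
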
diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc
index 4ee13a65d72e44693573397bb686b355effb2227..75487c4010391aa9e519d73058184fa936dabb84 100644
--- a/paddle/framework/selected_rows_test.cc
+++ b/paddle/framework/selected_rows_test.cc
@@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100}));
}
+TEST_F(SelectedRowsTester, SerializeAndDeserialize) {
+ SelectedRows dst_tensor;
+ platform::CPUDeviceContext cpu_ctx(place_);
+ std::ostringstream oss;
+
+ SerializeToStream(oss, *selected_rows_, cpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+
+ ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
+ ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc
index 86dc01665bda5e7f988e60780c0600b049d737ef..e53cc0cdabc623ae358f1a3e21823a2f38ec3c62 100644
--- a/paddle/framework/shape_inference.cc
+++ b/paddle/framework/shape_inference.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/shape_inference.h"
#include "grad_op_desc_maker.h"
#include "paddle/framework/operator.h"
diff --git a/paddle/framework/tensor.cc b/paddle/framework/tensor.cc
index ea7b2a1f7b17d9abc2c2e14de5ecd1cf4a1a5027..f922e606249849e621e679f71d6dbe0f007c8464 100644
--- a/paddle/framework/tensor.cc
+++ b/paddle/framework/tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/tensor.h"
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 6a0c5133c9a6bb326ca51755242e75b6eb9e5474..341a6949beeb2dfa64b23d2079bd8f48750a94f8 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -20,12 +20,12 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
-#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
@@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const;
+ inline DataLayout layout() const { return layout_; }
+
+ inline void set_layout(const DataLayout layout) { layout_ = layout; }
+
private:
friend class LoDTensor;
@@ -173,6 +177,19 @@ class Tensor {
DDim dims_;
+ /**
+ * @brief the layout of memory block, default is NHWC.
+ *
+ * @note the memory allocation order; it describes how weight/data is stored.
+ * For example, in a 4-D Tensor (rank = 4), there are three commonly
+ * used layouts: NCHW, NHWC, and CHWN.
+ * N, C, H, W stand for the batch size, the number of feature maps,
+ * the height, and the width, respectively.
+ */
+
+ DataLayout layout_ = DataLayout::kNHWC;
+
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 3d93b7808bc96143b5261e1ec41ad98b15c74975..6c6f298edc187a87677089e54c4c9046821282df 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
+ dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index f347981f2e11fdc4770df8d5e7277cd55744ab3f..a1b4a03289eca4c8b9d8c23ede4221853cb31f79 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -15,12 +15,13 @@
#include
#include
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
TEST(Tensor, Dims) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor tt;
+ framework::Tensor tt;
tt.Resize({2, 3, 4});
- DDim dims = tt.dims();
+ framework::DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
@@ -28,12 +29,12 @@ TEST(Tensor, Dims) {
}
TEST(Tensor, DataAssert) {
- paddle::framework::Tensor src_tensor;
+ framework::Tensor src_tensor;
bool caught = false;
try {
src_tensor.data<float>();
- } catch (paddle::platform::EnforceNotMet err) {
+ } catch (platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTensor holds no memory. Call "
@@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) {
because Memory::Alloc() and Memory::Free() have not been ready.
*/
TEST(Tensor, MutableData) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
- p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+ platform::CPUPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor to a new dim with larger size
// memory is supposed to be re-allocated
- p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+ platform::CPUPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor to a new dim with the same size
// memory block is supposed to be unchanged
- p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+ platform::CPUPlace());
EXPECT_EQ(p1, p2);
// set src_tensor to a new dim with smaller size
// memory block is supposed to be unchanged
- p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+ platform::CPUPlace());
EXPECT_EQ(p1, p2);
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
- p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+ platform::CUDAPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor to a new dim with larger size
// memory is supposed to be re-allocated
- p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+ platform::CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor to a new dim with the same size
// memory block is supposed to be unchanged
- p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
+ p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+ platform::CUDAPlace());
EXPECT_EQ(p1, p2);
// set src_tensor to a new dim with smaller size
// memory block is supposed to be unchanged
- p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
+ p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+ platform::CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
}
TEST(Tensor, ShareDataWith) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
- Tensor dst_tensor;
+ framework::Tensor src_tensor;
+ framework::Tensor dst_tensor;
// Try to share data from an uninitialized tensor
bool caught = false;
try {
@@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) {
}
ASSERT_TRUE(caught);
- src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
+ src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+ platform::CPUPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
- Tensor dst_tensor;
- src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CUDAPlace());
+ framework::Tensor src_tensor;
+ framework::Tensor dst_tensor;
+ src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+ platform::CUDAPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
@@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) {
}
TEST(Tensor, Slice) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
- src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
- Tensor slice_tensor = src_tensor.Slice(1, 3);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+ src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
+ platform::CPUPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
@@ -153,11 +159,12 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
- src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+ src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
- uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
- slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+ uintptr_t slice_mutable_data_address =
+ reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<int>(
+ slice_tensor.dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
@@ -165,22 +172,25 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
- src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
- Tensor slice_tensor = src_tensor.Slice(2, 6);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+ src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
+ platform::CUDAPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
- uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
- src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
+ uintptr_t src_mutable_data_address =
+ reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+ src_tensor.dims(), platform::CUDAPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
- uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
- slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
+ uintptr_t slice_mutable_data_address =
+ reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
+ slice_tensor.dims(), platform::CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@@ -189,14 +199,19 @@ TEST(Tensor, Slice) {
}
TEST(Tensor, ReshapeToMatrix) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor src;
- int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ framework::Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
- Tensor res = ReshapeToMatrix(src, 2);
+ framework::Tensor res = framework::ReshapeToMatrix(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
+
+TEST(Tensor, Layout) {
+ framework::Tensor src;
+ ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
+ src.set_layout(framework::DataLayout::kAnyLayout);
+ ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
+}
diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7efc649d0bcda67c663d148e83bcbb6789b0f371
--- /dev/null
+++ b/paddle/framework/tensor_util.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/tensor_util.h"
+
+namespace paddle {
+namespace framework {
+template <typename Predicate, typename DevCtx>
+struct AnyDTypeVisitor {
+ Predicate predicate_;
+ const Tensor& tensor_;
+ const DevCtx& ctx_;
+ Tensor* out_;
+
+ AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
+ Tensor* out)
+ : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
+
+ template <typename T>
+ void operator()() const {
+ auto t = EigenVector<T>::Flatten(tensor_);
+ auto o = EigenScalar<bool>::From(*out_);
+ // Return true if predicate_(t) holds for any element of t.
+ o.device(*ctx_.eigen_device()) = predicate_(t).any();
+ }
+};
+
+template <typename Predicate, typename DevCtx>
+inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
+ const DevCtx& ctx, framework::Tensor* out) {
+ VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
+ predicate, tensor, ctx, out));
+}
+
+template <typename Predicate>
+struct AnyVisitor : public boost::static_visitor<bool> {
+ const framework::Tensor& tensor_;
+ Predicate predicate_;
+
+ AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
+ : tensor_(tensor), predicate_(std::move(predicate)) {}
+
+ template <typename Place>
+ bool operator()(const Place& place) const {
+ framework::Tensor out;
+ out.Resize({1});
+ out.mutable_data<bool>(place);
+ auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+ AnyImpl(predicate_, tensor_, *ctx, &out);
+ return this->GetResult(out, place);
+ }
+
+ bool GetResult(const framework::Tensor& out,
+ const platform::CUDAPlace& gpu) const {
+ platform::CPUPlace cpu;
+ framework::Tensor tmp;
+ tmp.Resize({1});
+ tmp.mutable_data<bool>(cpu);
+ auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
+ gpuctx->Wait();
+ CopyFrom(out, cpu, *gpuctx, &tmp);
+ gpuctx->Wait();
+ return GetResult(tmp, cpu);
+ }
+
+ bool GetResult(const framework::Tensor& out,
+ const platform::CPUPlace& cpu) const {
+ return *out.data<bool>();
+ }
+};
+
+template <typename Predicate>
+inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
+ AnyVisitor<Predicate> visitor(tensor, predicate);
+ auto place = tensor.place();
+ return platform::VisitPlace(place, visitor);
+}
+
+struct HasNANPredicate {
+ template <typename T>
+ auto operator()(const T& eigen_vec) const
+ -> decltype(std::declval<const T>().isnan()) {
+ // Cast eigen_vec to a vector of bool; an element is true if it is NaN.
+ return eigen_vec.isnan();
+ }
+};
+
+bool HasNAN(const framework::Tensor& tensor) {
+ HasNANPredicate predicate;
+ return Any(tensor, predicate);
+}
+
+struct HasInfPredicate {
+ template <typename T>
+ auto operator()(const T& eigen_vec) const
+ -> decltype(std::declval<const T>().isinf()) {
+ // Cast eigen_vec to a vector of bool; an element is true if it is Inf.
+ return eigen_vec.isinf();
+ }
+};
+
+bool HasInf(const framework::Tensor& tensor) {
+ HasInfPredicate predicate;
+ return Any(tensor, predicate);
+}
+
+} // namespace framework
+} // namespace paddle
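
`HasNAN` and `HasInf` reduce a whole tensor to a single boolean on the device the data lives on, so a typical use is a cheap numerical sanity check after an update step. A minimal sketch follows; `TensorIsFinite` is an illustrative helper, not part of the framework.

```cpp
#include "paddle/framework/tensor_util.h"

// Returns true when the tensor contains neither NaN nor Inf. Works on CPU
// tensors directly; for GPU tensors the visitor copies the one-element
// result back to the host.
bool TensorIsFinite(const paddle::framework::Tensor& t) {
  return !paddle::framework::HasNAN(t) && !paddle::framework::HasInf(t);
}
```
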
diff --git a/paddle/framework/tensor_util.cu b/paddle/framework/tensor_util.cu
new file mode 120000
index 0000000000000000000000000000000000000000..b00e6e59d93328bf3142597ea4de0dc225501e56
--- /dev/null
+++ b/paddle/framework/tensor_util.cu
@@ -0,0 +1 @@
+./tensor_util.cc
\ No newline at end of file
diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h
index ebfb0e553877b776afe13d8a4e7a7ffa8710405e..6a21f8db1e3966fd23eee0da2346b2d61f9321fb 100644
--- a/paddle/framework/tensor_util.h
+++ b/paddle/framework/tensor_util.h
@@ -1,19 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
+#include "paddle/framework/data_type.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/framework.pb.h"
#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
@@ -33,6 +37,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size();
dst->Resize(src.dims());
+ dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data();
@@ -89,6 +94,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
+ dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data();
@@ -203,5 +209,109 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) {
src_ptr, size);
}
+// Returns true if a tensor contains NAN, i.e., Not A Number.
+bool HasNAN(const framework::Tensor& tensor);
+
+// Returns true if a tensor contains Inf, i.e., Infinity.
+bool HasInf(const framework::Tensor& tensor);
+
+inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
+ const platform::DeviceContext& dev_ctx) {
+ // TODO(typhoonzero): serialize to ostream
+ { // the 1st field, uint32_t version
+ constexpr uint32_t version = 0;
+ os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+ }
+ { // the 2nd field, tensor description
+ // int32_t size
+ // void* protobuf message
+ proto::TensorDesc desc;
+ desc.set_data_type(framework::ToDataType(tensor.type()));
+ auto dims = framework::vectorize(tensor.dims());
+ auto* pb_dims = desc.mutable_dims();
+ pb_dims->Resize(static_cast<int>(dims.size()), 0);
+ std::copy(dims.begin(), dims.end(), pb_dims->begin());
+ int32_t size = desc.ByteSize();
+ os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+ auto out = desc.SerializeAsString();
+ os.write(out.data(), size);
+ }
+ { // the 3rd field, tensor data
+ uint64_t size = tensor.memory_size();
+ auto* data_ptr = tensor.data<void>();
+ PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
+ "Index overflow when writing tensor");
+ if (platform::is_gpu_place(tensor.place())) {
+#ifdef PADDLE_WITH_CUDA
+ constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
+ std::unique_ptr<char[]> buf(new char[kBufSize]);
+ auto& gpu_dev_ctx =
+ static_cast<const platform::CUDADeviceContext&>(dev_ctx);
+ platform::CPUPlace cpu;
+ uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
+ while (size != 0) {
+ size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
+ memory::Copy(cpu, buf.get(),
+ boost::get<platform::CUDAPlace>(tensor.place()),
+ reinterpret_cast<const void*>(data), size_to_write,
+ gpu_dev_ctx.stream());
+ gpu_dev_ctx.Wait();
+ os.write(buf.get(), size_to_write);
+ data += size_to_write;
+ size -= size_to_write;
+ }
+#else
+ PADDLE_THROW("Unexpected branch");
+#endif
+ } else {
+ os.write(static_cast<const char*>(data_ptr),
+ static_cast<std::streamsize>(size));
+ }
+ }
+}
+
+inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
+ uint32_t version;
+ is.read(reinterpret_cast<char*>(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ proto::TensorDesc desc;
+ { // int32_t size
+ // proto buffer
+ int32_t size;
+ is.read(reinterpret_cast<char*>(&size), sizeof(size));
+ std::unique_ptr<char[]> buf(new char[size]);
+ is.read(reinterpret_cast<char*>(buf.get()), size);
+ PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
+ "Cannot parse tensor desc");
+ }
+ { // read tensor
+ std::vector<int64_t> dims;
+ dims.reserve(static_cast<size_t>(desc.dims().size()));
+ std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
+ tensor->Resize(framework::make_ddim(dims));
+
+ void* buf;
+ platform::Place cpu = platform::CPUPlace();
+ // TODO(Yancey1989): use VisitDataType instead of a DataType switch
+ switch (desc.data_type()) {
+ case proto::FP32:
+ buf = tensor->mutable_data<float>(cpu);
+ break;
+ case proto::FP64:
+ buf = tensor->mutable_data<double>(cpu);
+ break;
+ case proto::INT32:
+ buf = tensor->mutable_data<int>(cpu);
+ break;
+ case proto::INT64:
+ buf = tensor->mutable_data<int64_t>(cpu);
+ break;
+ default:
+ PADDLE_THROW("DataType %d not supported", desc.data_type());
+ }
+ is.read(static_cast<char*>(buf), tensor->memory_size());
+ }
+}
+
} // namespace framework
} // namespace paddle
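
From the code above, the wire format of a serialized tensor is: a `uint32_t` version, an `int32_t` length of the `TensorDesc` protobuf, the protobuf bytes, and finally the raw tensor payload. The sketch below walks just the header of such a stream; the file name and the lack of error handling are illustrative.

```cpp
#include <cstdint>
#include <fstream>
#include <vector>

// Reads the header fields written by SerializeToStream; the remaining
// bytes in the stream are the raw tensor data.
void InspectTensorHeader(const char* path) {
  std::ifstream in(path, std::ios::binary);
  uint32_t version = 0;
  in.read(reinterpret_cast<char*>(&version), sizeof(version));  // always 0
  int32_t desc_size = 0;
  in.read(reinterpret_cast<char*>(&desc_size), sizeof(desc_size));
  std::vector<char> desc_buf(desc_size);
  in.read(desc_buf.data(), desc_size);  // serialized proto::TensorDesc
}
```
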
diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc
index 6fc243aaf6ea68d716e7bb6a73289197fde56247..0dc5166fcabf77b48b8681ab1f050e2bc88f44ab 100644
--- a/paddle/framework/tensor_util_test.cc
+++ b/paddle/framework/tensor_util_test.cc
@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include
+#include
#include
namespace paddle {
@@ -28,6 +29,7 @@ TEST(CopyFrom, Tensor) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
+ src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, &dst_tensor);
@@ -38,14 +40,18 @@ TEST(CopyFrom, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
Tensor slice_tensor = src_tensor.Slice(1, 2);
- CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor);
+ CopyFrom(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
@@ -91,6 +97,8 @@ TEST(CopyFrom, Tensor) {
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
}
#endif
}
@@ -223,5 +231,78 @@ TEST(CopyToVector, Tensor) {
#endif
}
+TEST(HasNAN, CPU) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ float* buf = src.mutable_data<float>({3}, CPUPlace());
+ buf[0] = 0.0;
+ buf[1] = NAN;
+ buf[2] = 0.0;
+
+ ASSERT_TRUE(HasNAN(src));
+}
+
+TEST(HasInf, CPU) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ double* buf = src.mutable_data<double>({3}, CPUPlace());
+ buf[0] = 1.0;
+ buf[1] = INFINITY;
+ buf[2] = 0.0;
+ ASSERT_TRUE(HasInf(src));
+}
+
+TEST(Tensor, SerializeAndDeserialize) {
+ framework::Tensor src_tensor;
+ int array[6] = {1, 2, 3, 4, 5, 6};
+ src_tensor.Resize({2, 3});
+ int* src_ptr = src_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ src_ptr[i] = array[i];
+ }
+ {
+ framework::Tensor dst_tensor;
+ auto place = new platform::CPUPlace();
+ platform::CPUDeviceContext cpu_ctx(*place);
+ std::ostringstream oss;
+ SerializeToStream(oss, src_tensor, cpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+ int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_EQ(dst_ptr[i], array[i]);
+ }
+ delete place;
+ }
+#ifdef PADDLE_WITH_CUDA
+ {
+ Tensor gpu_tensor;
+ gpu_tensor.Resize({2, 3});
+ Tensor dst_tensor;
+
+ auto gpu_place = new platform::CUDAPlace();
+ platform::CUDADeviceContext gpu_ctx(*gpu_place);
+
+ CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
+
+ std::ostringstream oss;
+ SerializeToStream(oss, gpu_tensor, gpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+
+ int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_EQ(dst_ptr[i], array[i]);
+ }
+
+ delete gpu_place;
+ }
+#endif
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebd35fdf6c2a1388fec23057070f723c8ef9da9c
--- /dev/null
+++ b/paddle/framework/tensor_util_test.cu
@@ -0,0 +1,57 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "gtest/gtest.h"
+#include "paddle/framework/tensor_util.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace framework {
+
+static __global__ void FillNAN(float* buf) {
+ buf[0] = 0.0;
+ buf[1] = 0.1;
+ buf[2] = NAN;
+}
+static __global__ void FillInf(float* buf) {
+ buf[0] = 0.0;
+ buf[1] = INFINITY;
+ buf[2] = 0.5;
+}
+
+TEST(HasNAN, GPU) {
+ Tensor tensor;
+ platform::CUDAPlace gpu(0);
+ auto& pool = platform::DeviceContextPool::Instance();
+ auto* cuda_ctx = pool.GetByPlace(gpu);
+ float* buf = tensor.mutable_data<float>({3}, gpu);
+ FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+ cuda_ctx->Wait();
+ ASSERT_TRUE(HasNAN(tensor));
+}
+
+TEST(HasInf, GPU) {
+ Tensor tensor;
+ platform::CUDAPlace gpu(0);
+ auto& pool = platform::DeviceContextPool::Instance();
+ auto* cuda_ctx = pool.GetByPlace(gpu);
+ float* buf = tensor.mutable_data<float>({3}, gpu);
+ FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
+ cuda_ctx->Wait();
+ ASSERT_TRUE(HasInf(tensor));
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/threadpool.cc b/paddle/framework/threadpool.cc
new file mode 100644
index 0000000000000000000000000000000000000000..109a7e7dc440d91e8223f2c0924f489f54a06f64
--- /dev/null
+++ b/paddle/framework/threadpool.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/threadpool.h"
+
+namespace paddle {
+namespace framework {
+
+std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
+std::once_flag ThreadPool::init_flag;
+
+} // namespace framework
+} // namespace paddle
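
This new file only defines the two static members; the lazy, thread-safe construction they support lives in `threadpool.h`. Below is a self-contained sketch of that singleton pattern; the `GetInstance` name and the private constructor are assumptions for illustration, not the actual interface.

```cpp
#include <memory>
#include <mutex>

class ThreadPool {
 public:
  // The first caller constructs the pool; std::call_once makes this
  // thread-safe even under concurrent first use.
  static ThreadPool* GetInstance() {
    std::call_once(init_flag, [] { threadpool.reset(new ThreadPool()); });
    return threadpool.get();
  }

 private:
  ThreadPool() = default;
  static std::unique_ptr<ThreadPool> threadpool;
  static std::once_flag init_flag;
};

std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
```
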
diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h
index 9a1ece3ae8452399205e71fbaa26977710fcdcac..bcd8190755083ec30687675602a1c95a9c15c69e 100644
--- a/paddle/framework/threadpool.h
+++ b/paddle/framework/threadpool.h
@@ -13,24 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+
#include
-#include
#include
-#include
+#include
#include