& no_grad_vars);
-```
-
-The implementation behind it can be divided into two parts: **Backward Operator Creating** and **Backward Network Building**.
-
-### Backward Operator Registry
-
-A backward network is built up from several backward operators. A backward operator takes the forward operator's inputs, outputs, and output gradients, and computes the gradients of the forward operator's inputs. For example, the backward operator of `mul` takes the forward `mul`'s inputs and output, plus the gradient of its output, and produces the gradients of its inputs (such as **W@GRAD** in Figure 1 below).
-
-|                        | forward operator | backward operator                |
-| ---------------------- | ---------------- | -------------------------------- |
-| **Operator::inputs_**  | Inputs           | Inputs, Outputs, OutputGradients |
-| **Operator::outputs_** | Outputs          | InputGradients                   |
-
- In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded in a global hash map (`OpInfoMap`). To follow the philosophy of a minimal core and to make operators pluggable, a registry mechanism is introduced.
-
-For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro:
-
-```cpp
-REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
-```
-
-`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
-
-`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name.
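-
-To make the lookup concrete, the following is a minimal sketch of how a forward operator's type can be mapped to its backward operator's type through such a registry. The `OpInfo` field and the accessor below are illustrative stand-ins for this document, not the exact framework API:
-
-```cpp
-#include <string>
-#include <unordered_map>
-
-// Illustrative stand-in for the per-operator metadata kept in OpInfoMap.
-struct OpInfo {
-  std::string grad_op_type_;  // e.g. "mul" maps to "mul_grad"
-};
-
-// Illustrative stand-in for the global OpInfoMap registry.
-std::unordered_map<std::string, OpInfo>& OpInfoMap() {
-  static std::unordered_map<std::string, OpInfo> map;
-  return map;
-}
-
-// What REGISTER_OP(mul, ..., mul_grad, ...) conceptually records:
-//   OpInfoMap()["mul"].grad_op_type_ = "mul_grad";
-std::string GradOpType(const std::string& forward_type) {
-  return OpInfoMap().at(forward_type).grad_op_type_;
-}
-```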
-
-### Backward Operator Creating
-
-Given a certain forward operator, we can get its corresponding backward operator by calling the following function:
-
-```cpp
-OperatorBase* BuildGradOp(const OperatorBase* fwd_op);
-```
-
-The function `BuildGradOp` executes the following steps in sequence (a condensed sketch follows the list):
-
-1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-
-2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` into the map `inputs`, except for those variables that are not necessary for gradient computation.
-
-3. Add the gradient variables of the forward inputs into the map `outputs`, and add the gradient variables of the forward outputs into the map `inputs`.
-
-4. Build the backward operator with `inputs`, `outputs`, and the forward operator's attributes.
-
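-A condensed, self-contained sketch of these four steps follows; `OpStub` and `GradVarName` are simplified stand-ins, and the real `BuildGradOp` additionally drops the variables that are unnecessary for gradient computation:
-
-```cpp
-#include <map>
-#include <string>
-#include <vector>
-
-using VarNameMap = std::map<std::string, std::vector<std::string>>;
-
-struct OpStub {  // illustrative stand-in for an operator description
-  std::string type_;
-  VarNameMap inputs_, outputs_;
-};
-
-std::string GradVarName(const std::string& name) { return name + "@GRAD"; }
-
-OpStub BuildGradOp(const OpStub& fwd) {
-  OpStub bwd;
-  bwd.type_ = fwd.type_ + "_grad";  // step 1: look up the backward type
-  bwd.inputs_ = fwd.inputs_;        // step 2: forward Inputs ...
-  for (const auto& out : fwd.outputs_) {
-    bwd.inputs_[out.first] = out.second;  // ... and forward Outputs
-  }
-  for (const auto& out : fwd.outputs_) {  // step 3: OutputGradients as inputs
-    for (const auto& name : out.second) {
-      bwd.inputs_[GradVarName(out.first)].push_back(GradVarName(name));
-    }
-  }
-  for (const auto& in : fwd.inputs_) {  // step 3: InputGradients as outputs
-    for (const auto& name : in.second) {
-      bwd.outputs_[GradVarName(in.first)].push_back(GradVarName(name));
-    }
-  }
-  return bwd;  // step 4: the real code also attaches fwd's attributes
-}
-```
-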
-### Backward Network Building
-
-A backward network is a series of backward operators. The main idea of building a backward network is to create the backward operators in reversed order and append them one by one. A few corner cases need special processing.
-
-1. Op
-
-   When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in the no-gradient set, return a special `NOP` instead.
-
-2. NetOp
-
-   In our design, the network itself is also a kind of operator (**NetOp**), so the operators contained in a big network may themselves be smaller networks. When the input forward network is a NetOp, we call the backward functions of its sub-NetOps/Operators recursively and, during the process, collect the `OutputGradients` names according to the forward NetOp (see the sketch at the end of this section).
-
-3. RnnOp
-
-   RnnOp is a nested stepnet operator. The backward module needs to recursively call `Backward` for every stepnet.
-
-4. Sharing Variables
-
-   As illustrated in Figure 1 and Figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.
-
-   Figure 1. Sharing variables in operators.
-
-   Sharing a variable between operators, or using the same input variable in multiple operators, leads to duplicate gradient variables. As illustrated in Figure 2, we need to rename the gradient names recursively and add a generic `Add` operator to prevent overwriting.
-
-   Figure 2. Replacing a shared variable's gradient with an `Add` operator.
-
-   Because the framework finds variables according to their names, we need to rename the duplicated output links. We add an integer suffix to represent each one's position in the clockwise direction.
-
-5. Part of the Gradient is Zero.
-
-   In the whole graph, there are cases where one operator's gradient is not needed, but the gradient of one of its inputs is a dependency link of another operator. In such a position we need to fill in a gradient matrix of the same shape; in our implementation, we insert a special `fillZeroLike` operator.
-
-
-Following the rules above, we collect the sub-graph's `OutputGradients`/`InputGradients` as the NetOp's and return it.
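-
-Putting these rules together, the following is a minimal sketch of the recursive traversal described above. `Op`, the NetOp behavior, and the two helpers are simplified stand-ins (the helpers are assumed to behave as sketched earlier), and corner cases 4 and 5 are reduced to comments:
-
-```cpp
-#include <memory>
-#include <string>
-#include <unordered_set>
-#include <utility>
-#include <vector>
-
-struct Op {  // illustrative stand-in for OperatorBase / NetOp
-  virtual ~Op() = default;
-  virtual bool IsNetOp() const { return false; }
-  std::vector<std::unique_ptr<Op>> ops_;  // sub-operators when this is a NetOp
-};
-
-// Assumed helpers: gradient creation as in the BuildGradOp sketch, and a
-// check against the no-gradient set.
-std::unique_ptr<Op> BuildGradOpFor(const Op& op);
-bool AllOutputsInNoGradSet(const Op& op,
-                           const std::unordered_set<std::string>& no_grad);
-
-std::unique_ptr<Op> Backward(const Op& forward,
-                             const std::unordered_set<std::string>& no_grad) {
-  if (!forward.IsNetOp()) {
-    // Corner case 1: a plain Op; a NOP when no output needs a gradient.
-    if (AllOutputsInNoGradSet(forward, no_grad)) return nullptr;  // "NOP"
-    return BuildGradOpFor(forward);
-  }
-  // Corner case 2: a NetOp; recurse over sub-operators in reversed order.
-  auto net = std::unique_ptr<Op>(new Op);
-  for (auto it = forward.ops_.rbegin(); it != forward.ops_.rend(); ++it) {
-    // Corner cases 4 and 5 live here in the real implementation: rename
-    // duplicated gradient outputs and insert an Add operator, or insert
-    // fillZeroLike for gradients that must exist but are all zero.
-    if (auto grad = Backward(**it, no_grad)) {
-      net->ops_.push_back(std::move(grad));
-    }
-  }
-  return net;
-}
-```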
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 0957646b5642cd9afce5d88b2c638679cb01f198..692406b1c37d0c02714eafb5cf9a28329ed873bc 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/backward.h"
diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h
index 7429de7ee39297c26360984809e2451100f7b3ff..4a8669c3a41fceaad26878a79eabfd0affce86fd 100644
--- a/paddle/framework/data_layout.h
+++ b/paddle/framework/data_layout.h
@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/platform/enforce.h"
+
+#include
+#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
-enum DataLayout {
+enum class DataLayout {
kNHWC = 0,
kNCHW = 1,
kAnyLayout = 2,
@@ -33,5 +37,23 @@ inline DataLayout StringToDataLayout(const std::string& str) {
}
}
+inline std::string DataLayoutToString(const DataLayout& data_layout) {
+ switch (data_layout) {
+ case DataLayout::kNHWC:
+ return "NHWC";
+ case DataLayout::kNCHW:
+ return "NCHW";
+ case DataLayout::kAnyLayout:
+ return "ANY_LAYOUT";
+ default:
+      PADDLE_THROW("unknown DataLayout %d", data_layout);
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& out, DataLayout l) {
+ out << DataLayoutToString(l);
+ return out;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9d6a8424426a68ae66cf93b803c35e33e30226f2
--- /dev/null
+++ b/paddle/framework/data_transform.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/data_transform.h"
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+
+DataTransformFnMap& DataTransformFnMap::Instance() {
+ static DataTransformFnMap data_transform_map;
+ return data_transform_map;
+}
+
+auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
+ DataLayout::kNHWC, LibraryType::kPlain);
+
+auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
+ DataLayout::kNHWC, LibraryType::kPlain);
+
+auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
+ DataLayout::kNHWC, LibraryType::kPlain);
+
+auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
+ DataLayout::kNCHW, LibraryType::kPlain);
+
+void TransDataType(const platform::DeviceContext* ctx,
+ const KernelTypePair& kernel_pair, const Variable& in,
+ Variable* out) {
+  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!");
+  PADDLE_ENFORCE(
+      platform::places_are_same_class(kernel_pair.first.place_,
+                                      kernel_pair.second.place_),
+      "TransDataType only supports DataType transform on the same place!");
+
+  auto src = in.Get<Tensor>();
+  auto* dst = out->GetMutable<Tensor>();
+
+ auto dims = src.dims();
+ dst->Resize(dims);
+ auto dst_type = kernel_pair.second.data_type_;
+ auto src_type = kernel_pair.first.data_type_;
+
+ switch (src_type) {
+    case proto::DataType::FP32:
+      framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
+      break;
+    case proto::DataType::FP64:
+      framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
+      break;
+    case proto::DataType::INT32:
+      framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
+      break;
+    case proto::DataType::INT64:
+      framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
+      break;
+    case proto::DataType::BOOL:
+      framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
+      break;
+    default:
+      PADDLE_THROW("Not support type %d", src_type);
+ }
+}
+
+void TransDataLayout(const platform::DeviceContext* ctx,
+ const KernelTypePair& kernel_pair, const Variable& in,
+ Variable* out) {
+  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!");
+  PADDLE_ENFORCE(
+      platform::places_are_same_class(kernel_pair.first.place_,
+                                      kernel_pair.second.place_),
+      "TransDataLayout only supports DataLayout transform on the same place!");
+
+  auto src = in.Get<Tensor>();
+  auto* dst = out->GetMutable<Tensor>();
+  PADDLE_ENFORCE(arity(src.dims()) == 4, "Input arity only supports 4!");
+
+ auto src_dim = src.dims();
+ dst->Resize(src_dim);
+ auto place = kernel_pair.second.place_;
+ CopyFrom(src, place, *ctx, dst);
+  const std::vector<int> axis = {0, 2, 3, 1};
+
+  std::vector<int64_t> dst_dim;
+  dst_dim.resize(axis.size());
+ for (size_t i = 0; i < axis.size(); i++) {
+ dst_dim[i] = src_dim[axis[i]];
+ }
+
+ dst->Resize(make_ddim(dst_dim));
+
+ auto src_type = kernel_pair.first.data_type_;
+ framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis));
+
+ dst->set_layout(kernel_pair.second.data_layout_);
+}
+
+} // namespace framework
+} // namespace paddle
+
+namespace f = paddle::framework;
+REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
+REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, f::TransDataLayout);
diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h
new file mode 100644
index 0000000000000000000000000000000000000000..9abb3c99bf30fcf9deab59dc7ee9c02e7c7c775b
--- /dev/null
+++ b/paddle/framework/data_transform.h
@@ -0,0 +1,171 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/framework/variable.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/macros.h"
+#include "paddle/platform/transform.h"
+
+namespace paddle {
+namespace framework {
+
+using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
+
+using DataTransformFn =
+    std::function<void(const platform::DeviceContext*, const KernelTypePair&,
+                       const Variable&, Variable*)>;
+
+struct KernelTypePairHash {
+ static void HashCombine(const OpKernelType& t, std::size_t* seed) {
+ OpKernelType::Hash kernel_type_hasher;
+ (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
+ }
+
+ size_t operator()(const KernelTypePair& kernel_pair) const {
+ std::size_t seed = 0;
+ HashCombine(kernel_pair.first, &seed);
+ HashCombine(kernel_pair.second, &seed);
+ return seed;
+ }
+};
+
+template <typename InType, typename OutType>
+struct CastDataTypeFunctor {
+  HOSTDEVICE inline OutType operator()(InType in) const {
+    return static_cast<OutType>(in);
+  }
+};
+
+template <typename InType>
+struct CastDataType {
+ CastDataType(const framework::Tensor& in, framework::Tensor* out,
+ const platform::DeviceContext* ctx)
+ : in_(in), out_(out), ctx_(ctx) {}
+ const framework::Tensor in_;
+ framework::Tensor* out_;
+ const platform::DeviceContext* ctx_;
+
+  template <typename OutType>
+ void operator()() {
+ auto place = ctx_->GetPlace();
+
+    auto* in_begin = in_.data<InType>();
+ auto numel = in_.numel();
+ auto* in_end = in_begin + numel;
+    auto* out_begin = out_->mutable_data<OutType>(place);
+ if (platform::is_cpu_place(place)) {
+      platform::Transform<platform::CPUDeviceContext> trans;
+      auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
+      trans(*context, in_begin, in_end, out_begin,
+            CastDataTypeFunctor<InType, OutType>());
+ } else {
+ // TODO(dzhwinter): enhance CopyFrom CPU<->GPU with different data type?
+ PADDLE_THROW("Unsupport CPU <-> GPU!");
+ }
+ }
+};
+
+struct CastDataLayout {
+ CastDataLayout(const framework::Tensor& in, framework::Tensor* out,
+ const platform::DeviceContext* ctx,
+                 const std::vector<int>& axis)
+ : in_(in), out_(out), ctx_(ctx), axis_(axis) {}
+ const framework::Tensor in_;
+ framework::Tensor* out_;
+ const platform::DeviceContext* ctx_;
+  const std::vector<int> axis_;
+
+  template <typename T>
+ void operator()() {
+ auto place = ctx_->GetPlace();
+ if (platform::is_cpu_place(place)) {
+      operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
+      auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
+ trans4(*context, in_, out_, axis_);
+ } else {
+ PADDLE_THROW("Unsupport CPU <-> GPU!");
+ }
+ }
+};
+
+using DataTransformMap =
+    std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
+
+class DataTransformFnMap {
+ public:
+ static DataTransformFnMap& Instance();
+
+ bool Has(const KernelTypePair& key_pair) const {
+ return map_.find(key_pair) != map_.end();
+ }
+
+ void Insert(const OpKernelType& left, const OpKernelType& right,
+ const DataTransformFn& data_tranform_fn) {
+ Insert(std::make_pair(left, right), data_tranform_fn);
+ }
+
+ void Insert(const KernelTypePair& kernel_type_pair,
+ const DataTransformFn& data_tranform_fn) {
+ PADDLE_ENFORCE(!Has(kernel_type_pair),
+ "KernelTypePair %s has been registered", "");
+ map_.insert({kernel_type_pair, data_tranform_fn});
+ }
+
+ const DataTransformFn& Get(const KernelTypePair& key_pair) const {
+ auto data_transformer = GetNullable(key_pair);
+ PADDLE_ENFORCE_NOT_NULL(data_transformer,
+ "DataTransformFn should not be NULL");
+ return *data_transformer;
+ }
+
+ const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
+ auto it = map_.find(key_pair);
+ if (it == map_.end()) {
+ return nullptr;
+ } else {
+ return &(it->second);
+ }
+ }
+
+ const DataTransformMap& Map() const { return map_; }
+
+ private:
+ DataTransformFnMap() = default;
+ DataTransformMap map_;
+ DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
+};
+
+// generate unique name with __LINE__
+// refs https://stackoverflow.com/questions/1597007
+#define TOKENPASTE(x, y) x##y
+#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
+#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
+ static int TOKENPASTE2(fn_, __LINE__)() { \
+ ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
+ return 0; \
+ } \
+ static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
+ TOKENPASTE2(fn_, __LINE__)()
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8665b6248faa2d218230449c45a10f022f3fbf4f
--- /dev/null
+++ b/paddle/framework/data_transform_test.cc
@@ -0,0 +1,156 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <array>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "paddle/framework/data_transform.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+using namespace platform;
+
+/**
+ * @brief cross validation of different kernel type transforms
+ * We use a four-bit map to represent the different combinations.
+ * If a field has multiple possible values, we only choose two of them.
+ * For DataType, we only test FP32 (float) and FP64 (double).
+ * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
+ *      1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
+ */
+
+std::array<proto::DataType, 2> kDataType = {
+    {proto::DataType::FP32, proto::DataType::FP64}};
+
+std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
+
+std::array<DataLayout, 2> kDataLayout = {{
+    DataLayout::kNHWC, DataLayout::kNCHW,
+}};
+
+std::array<LibraryType, 2> kLibraryType = {{
+    LibraryType::kPlain, LibraryType::kMKLDNN,
+}};
+
+OpKernelType GenFromBit(const std::vector<bool> bits) {
+  return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
+                      kLibraryType[bits[3]]);
+}
+
+int test_value = 0;
+
+auto kernel0 = GenFromBit({0, 0, 0, 0});
+auto kernel1 = GenFromBit({0, 0, 0, 1});
+auto kernel2 = GenFromBit({0, 0, 1, 0});
+auto kernel3 = GenFromBit({0, 0, 1, 1});
+
+void TransDataType_t(const platform::DeviceContext* ctx,
+ const KernelTypePair& p, const Variable& in,
+ Variable* out) {
+ test_value++;
+}
+
+void TransDataLayout_t(const platform::DeviceContext* ctx,
+ const KernelTypePair& p, const Variable& in,
+ Variable* out) {
+ test_value--;
+}
+
+void TransLibraryType_t(const platform::DeviceContext* ctx,
+ const KernelTypePair& p, const Variable& in,
+ Variable* out) {
+ test_value += 2;
+}
+
+} // namespace framework
+} // namespace paddle
+
+namespace frw = paddle::framework;
+
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t);
+REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t);
+
+TEST(DataTransform, Register) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+
+ auto& instance = DataTransformFnMap::Instance();
+ paddle::framework::Variable in;
+ paddle::framework::Variable out;
+
+ DeviceContext* ctx = new CPUDeviceContext();
+ auto pair0 = std::make_pair(frw::kernel0, frw::kernel1);
+ instance.Get(pair0)(ctx, pair0, in, &out);
+ ASSERT_EQ(test_value, 1);
+
+ auto pair1 = std::make_pair(frw::kernel1, frw::kernel2);
+ instance.Get(pair1)(ctx, pair1, in, &out);
+ ASSERT_EQ(test_value, 0);
+
+ auto pair3 = std::make_pair(frw::kernel0, frw::kernel2);
+ instance.Get(pair3)(ctx, pair3, in, &out);
+ ASSERT_EQ(test_value, 2);
+}
+
+TEST(DataTransform, Layout) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+
+ auto& instance = DataTransformFnMap::Instance();
+ Variable in;
+ Variable out;
+  Tensor* src = in.GetMutable<Tensor>();
+  src->mutable_data<double>(make_ddim({2, 3, 1, 2}), CPUPlace());
+ src->set_layout(DataLayout::kNHWC);
+
+ DeviceContext* ctx = new CPUDeviceContext();
+
+ {
+ auto kernel1 = GenFromBit({1, 0, 0, 0});
+ auto kernel2 = GenFromBit({1, 0, 1, 0});
+ auto pair0 = std::make_pair(kernel1, kernel2);
+ instance.Get(pair0)(ctx, pair0, in, &out);
+ }
+
+  Tensor dst = out.Get<Tensor>();
+ EXPECT_TRUE(dst.layout() != src->layout());
+}
+
+TEST(DataTransform, DataType) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+
+ auto& instance = DataTransformFnMap::Instance();
+ DeviceContext* ctx = new CPUDeviceContext();
+
+ Variable in;
+ Variable out;
+  Tensor* src = in.GetMutable<Tensor>();
+  float* ptr = src->mutable_data<float>(make_ddim({2, 3}), CPUPlace());
+ for (int i = 0; i < 6; ++i) {
+ ptr[i] = i / 3;
+ }
+
+ {
+ auto kernel1 = GenFromBit({0, 0, 0, 0});
+ auto kernel2 = GenFromBit({1, 0, 0, 0});
+ auto pair0 = std::make_pair(kernel1, kernel2);
+ instance.Get(pair0)(ctx, pair0, in, &out);
+ }
+  Tensor dst = out.Get<Tensor>();
+  EXPECT_TRUE(dst.data<double>() != nullptr);
+}
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index e94ee2ed52bc40f52caf783f971dd0b560534e08..6a372ac32e48131eed28e2d42125feb5b92a11c7 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc
index bd5ea09d7da700479aa387283d7bde77c64c1293..bc259d1f603fb34ac8546c388669d8c5c1250bd1 100644
--- a/paddle/framework/ddim_test.cc
+++ b/paddle/framework/ddim_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 7f5151c41d6046f21f7a9707e45de85ec50219ad..6d50e820b2b625f932768d2ca671d999071f1ca6 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 997773c1689efad4ce5a86c09ce58bd3a40185e0..bf1f0471ccbfccf13cb6f74c8088da7acd68ec0b 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
-#include <algorithm>
-#include <iostream>
-#include <memory>
 #include <set>
-#include <vector>
+#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
-#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
-#include "paddle/framework/scope.h"
+
+DEFINE_bool(check_nan_inf, false,
+ "Checking whether operator produce NAN/INF or not. It will be "
+ "extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
+static void CheckTensorNANOrInf(const std::string& name,
+ const framework::Tensor& tensor) {
+ if (tensor.memory_size() == 0) {
+ return;
+ }
+ if (tensor.type().hash_code() != typeid(float).hash_code() &&
+ tensor.type().hash_code() != typeid(double).hash_code()) {
+ return;
+ }
+ PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
+ PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
+}
+
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
@@ -101,8 +113,17 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, place_);
+ if (FLAGS_check_nan_inf) {
+ for (auto& vname : op->OutputVars(true)) {
+ auto* var = local_scope->FindVar(vname);
+ if (var == nullptr) continue;
+        if (var->IsType<framework::LoDTensor>()) {
+          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+ }
+ }
+ }
}
- if (create_local_scope) {
+ if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
}
}
diff --git a/paddle/framework/feed_fetch_type.h b/paddle/framework/feed_fetch_type.h
index bc4ae440fc708f696c18bb9d5ab3ba7dd59e21ab..9bc4a90c44828ecb7458d524f59609f01848cc5c 100644
--- a/paddle/framework/feed_fetch_type.h
+++ b/paddle/framework/feed_fetch_type.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h
index cf411fa710103350713342b43946697a8dd2aa46..2de5242831835b47893a5825e5532500ad5ec3f9 100644
--- a/paddle/framework/grad_op_desc_maker.h
+++ b/paddle/framework/grad_op_desc_maker.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
index 4deb4fa903dec04e9b76c5a620f1eb76c9f1db07..682cff168d4d31e0565fc987604f97a671566fbd 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
@@ -54,7 +54,7 @@ bool InitDevices(const std::vector &devices) {
#ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1);
- places.emplace_back(platform::GPUPlace(std::stoi(number)));
+ places.emplace_back(platform::CUDAPlace(std::stoi(number)));
#else
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
@@ -71,7 +71,7 @@ bool InitDevices(const std::vector &devices) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
- platform::DeviceContextPool::Create(places);
+ platform::DeviceContextPool::Init(places);
return true;
}
diff --git a/paddle/framework/init.h b/paddle/framework/init.h
index 1715cd81e6647158e269e39d4d91fbe065cd0008..33907f9eb00fb3469b53dcf8151557cc7a2d3791 100644
--- a/paddle/framework/init.h
+++ b/paddle/framework/init.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc
index cb1ba7ce8fdbf740846689356c94c4f2fabb95cb..f0788051d4855a175d2d7ea1f1a0805c776c462b 100644
--- a/paddle/framework/init_test.cc
+++ b/paddle/framework/init_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h
index 49b273656bf57f183209e3d0996358da28ec0e7a..7707799cae8c4edc304cd81725270a85f01fd28d 100644
--- a/paddle/framework/library_type.h
+++ b/paddle/framework/library_type.h
@@ -20,7 +20,48 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
-enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
+enum class LibraryType {
+ kPlain = 0,
+ kMKLDNN = 1,
+ kCUDNN = 2,
+};
+
+inline std::string LibraryTypeToString(const LibraryType& library_type) {
+ switch (library_type) {
+ case LibraryType::kPlain:
+ return "PLAIN";
+ case LibraryType::kMKLDNN:
+ return "MKLDNN";
+ case LibraryType::kCUDNN:
+ return "CUDNN";
+ default:
+      PADDLE_THROW("unknown LibraryType %d", static_cast<int>(library_type));
+ }
+}
+
+inline LibraryType StringToLibraryType(const char* ctype) {
+ std::string s(ctype);
+ if (s == std::string("PLAIN")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("MKLDNN")) {
+ return LibraryType::kMKLDNN;
+ } else if (s == std::string("CUDNN")) {
+ return LibraryType::kCUDNN;
+ // To be compatible with register macro.
+ // CPU, CUDA, PLAIN are same library type.
+ } else if (s == std::string("CPU")) {
+ return LibraryType::kPlain;
+ } else if (s == std::string("CUDA")) {
+ return LibraryType::kPlain;
+ } else {
+ PADDLE_THROW("Unknown LibraryType %s", s.c_str());
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& out, LibraryType l) {
+ out << LibraryTypeToString(l);
+ return out;
+}
} // namespace
} // framework
diff --git a/paddle/framework/lod_rank_table.cc b/paddle/framework/lod_rank_table.cc
index 17d524c09276fc0eb166925bd79bc0bdfcead195..704bce2a0eb60b974efd41a4edda0af2933da825 100644
--- a/paddle/framework/lod_rank_table.cc
+++ b/paddle/framework/lod_rank_table.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
diff --git a/paddle/framework/lod_rank_table.h b/paddle/framework/lod_rank_table.h
index d3007d3d7379a59b32465cbd55780c6268e0e4a8..df188709e91871ded0258fa5703ee16a5664f057 100644
--- a/paddle/framework/lod_rank_table.h
+++ b/paddle/framework/lod_rank_table.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 465f8c62b5fe2efd549f68bb3a9823d299ba5393..7b6dc09bdb5535488c8c4dbc71c9cd6a7998bd0b 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/data_type.h"
@@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
- // TODO(typhoonzero): serialize to ostream
- { // the 1st field, uint32_t version
+ { // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
     os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
-  { // the 2nd field, tensor description
-    // int32_t  size
-    // void*    protobuf message
-    proto::TensorDesc desc;
-    desc.set_data_type(framework::ToDataType(tensor.type()));
-    auto dims = framework::vectorize(tensor.dims());
-    auto *pb_dims = desc.mutable_dims();
-    pb_dims->Resize(static_cast<int>(dims.size()), 0);
-    std::copy(dims.begin(), dims.end(), pb_dims->begin());
-    int32_t size = desc.ByteSize();
-    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
-    auto out = desc.SerializeAsString();
-    os.write(out.data(), size);
-  }
-  { // the 3rd field, tensor data
-    uint64_t size = tensor.memory_size();
-    auto *data_ptr = tensor.data<void>();
-    PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
-                   "Index overflow when writing tensor");
-    if (platform::is_gpu_place(tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
-      constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB
-      std::unique_ptr<char[]> buf(new char[kBufSize]);
-      auto &gpu_dev_ctx =
-          static_cast<const platform::CUDADeviceContext &>(dev_ctx);
-      platform::CPUPlace cpu;
-      uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
-      while (size != 0) {
-        size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
-        memory::Copy(cpu, buf.get(),
-                     boost::get<platform::GPUPlace>(tensor.place()),
-                     reinterpret_cast<const void *>(data), size_to_write,
-                     gpu_dev_ctx.stream());
-        gpu_dev_ctx.Wait();
-        os.write(buf.get(), size_to_write);
-        data += size_to_write;
-        size -= size_to_write;
-      }
-#else
-      PADDLE_THROW("Unexpected branch");
-#endif
-    } else {
-      os.write(static_cast<const char *>(data_ptr),
-               static_cast<std::streamsize>(size));
-    }
-  }
- { // the 4th field, lod information
- // uint64_t lod_level
- // uint64_t lod_level_1 size in byte.
- // int* lod_level_1 data
- // ...
+ {
+    // the 2nd field, LoD information
+ // uint64_t lod_level
+ // uint64_t lod_level_1 size in byte.
+ // int* lod_level_1 data
+ // ...
auto lod = tensor.lod();
uint64_t size = lod.size();
     os.write(reinterpret_cast<const char *>(&size), sizeof(size));
@@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
                static_cast<std::streamsize>(size));
}
}
+  // the 3rd field, Tensor
+  SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
-  uint32_t version;
-  is.read(reinterpret_cast<char *>(&version), sizeof(version));
-  PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
-  proto::TensorDesc desc;
-  { // int32_t size
-    // proto buffer
-    int32_t size;
-    is.read(reinterpret_cast<char *>(&size), sizeof(size));
-    std::unique_ptr<char[]> buf(new char[size]);
-    is.read(reinterpret_cast<char *>(buf.get()), size);
-    PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
-                   "Cannot parse tensor desc");
-  }
-  { // read tensor
-    std::vector<int64_t> dims;
-    dims.reserve(static_cast<size_t>(desc.dims().size()));
-    std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
-    tensor->Resize(framework::make_ddim(dims));
-
-    void *buf;
-    platform::Place cpu = platform::CPUPlace();
-    switch (desc.data_type()) {
-      case proto::FP32:
-        buf = tensor->mutable_data<float>(cpu);
-        break;
-      case proto::FP64:
-        buf = tensor->mutable_data<double>(cpu);
-        break;
-      case proto::INT32:
-        buf = tensor->mutable_data<int>(cpu);
-        break;
-      case proto::INT64:
-        buf = tensor->mutable_data<int64_t>(cpu);
-        break;
-      default:
-        PADDLE_THROW("DataType %d not supported", desc.data_type());
-    }
-    is.read(static_cast<char *>(buf), tensor->memory_size());
-  }
- { // read lod
+ {
+    // the 1st field, uint32_t version for LoDTensor
+ uint32_t version;
+    is.read(reinterpret_cast<char *>(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+    // the 2nd field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
@@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
lod[i] = tmp;
}
}
+  // the 3rd field, Tensor
+  DeserializeFromStream(is, static_cast<Tensor *>(tensor));
}
} // namespace framework
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 0923c52a0ad2fe10cea760df20c99021984ad39d..147db3ab0877662d9e47ae7ee6df05638b5fcbd1 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
diff --git a/paddle/framework/lod_tensor_array.h b/paddle/framework/lod_tensor_array.h
index 13f0608d24be97d8bba149b74f1a4deb57deeb48..4a8e7f4fa540b1c2f19a6e3ec236a0dd5c0daf0b 100644
--- a/paddle/framework/lod_tensor_array.h
+++ b/paddle/framework/lod_tensor_array.h
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 02d84b68233f2fdfc66e1df2fc7ce20307cadd94..0747c8db531d6ae443d76591b945cce0c9bbea2b 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
EXPECT_NE(t1.data(), lod_tensor_.data());
}
+TEST_F(LoDTensorTester, SerializeAndDeserialize) {
+ LoDTensor dst_tensor;
+ platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
+ std::ostringstream oss;
+ SerializeToStream(oss, lod_tensor_, cpu_ctx);
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+  float* dst_ptr = dst_tensor.mutable_data<float>(platform::CPUPlace());
+ for (int i = 0; i < kLodTensorSize; ++i) {
+ EXPECT_EQ(dst_ptr[i], i);
+ }
+ EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod());
+}
+
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index 5b90fbfca7f6bec4f2c862d0ff18dfd7cf39e181..e8508ad2658ae850e4c98aa798b5db6d007e67d0 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -27,7 +27,7 @@ __global__ void test(size_t* a, int size) {
TEST(LoDTensor, LoDInGPU) {
paddle::framework::LoDTensor lod_tensor;
- paddle::platform::GPUPlace place(0);
+ paddle::platform::CUDAPlace place(0);
paddle::framework::LoD src_lod;
   src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index b361e64438251c1df827667fb825e7f5909fb09e..3e58e6442edfe006c8aed238f67b9524783601ee 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -88,6 +88,14 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
+void OpDesc::CopyFrom(const OpDesc &op_desc) {
+ desc_.set_type(op_desc.Type());
+ inputs_ = op_desc.inputs_;
+ outputs_ = op_desc.outputs_;
+ attrs_ = op_desc.attrs_;
+ need_update_ = true;
+}
+
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
@@ -252,7 +260,13 @@ struct SetAttrDescVisitor : public boost::static_visitor {
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
- void operator()(bool b) const { attr_->set_b(b); }
+
+ // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
+  template <class T,
+            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
+ void operator()(T b) const {
+ attr_->set_b(b);
+ }
   void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
@@ -266,9 +280,7 @@ struct SetAttrDescVisitor : public boost::static_visitor {
   void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
- void operator()(proto::BlockDesc *desc) const {
- attr_->set_block_idx(desc->idx());
- }
+ void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 93d4a88f3c390551ab41e42ec2f6f30f52e306db..4cf784a0d0d319d09caa27b4e2b589bd7ac4f324 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -35,6 +35,8 @@ class OpDesc {
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
+ void CopyFrom(const OpDesc &op_desc);
+
proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); }
diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc
index 81ba29797c5f478e5d6a91236f3e8de1e6b43e49..b520108109bb2f72b80f83559fa065a5ca58e9e1 100644
--- a/paddle/framework/op_info.cc
+++ b/paddle/framework/op_info.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/op_info.h"
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 7772d6e745c2207024863d3dd5cbef052358272e..d9b89f9cac9611fcecb18bef87940632df1e2234 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#pragma once
#include
diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h
index a1dea0d9d864881ef1f60b117dfaa02da3aa4275..b06002096fb109da806809f7b908d9768cf095ba 100644
--- a/paddle/framework/op_kernel_type.h
+++ b/paddle/framework/op_kernel_type.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h"
+#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
namespace paddle {
@@ -39,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
+
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
@@ -66,7 +68,23 @@ struct OpKernelType {
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
+
+ bool operator!=(const OpKernelType& o) const { return !(*this == o); }
};
+inline std::ostream& operator<<(std::ostream& os,
+ const OpKernelType& kernel_key) {
+ os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
+ << kernel_key.data_layout_ << "]:place[" << kernel_key.place_
+ << "]:library_type[" << kernel_key.library_type_ << "]";
+ return os;
+}
+
+inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
+ std::ostringstream stream;
+ stream << kernel_key;
+ return stream.str();
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..649afeee8a846b0579545f2edff77e9dbe3b4dd8
--- /dev/null
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -0,0 +1,49 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/op_kernel_type.h"
+#include <gtest/gtest.h>
+#include <iostream>
+
+TEST(OpKernelType, ToString) {
+ using OpKernelType = paddle::framework::OpKernelType;
+ using DataType = paddle::framework::proto::DataType;
+ using CPUPlace = paddle::platform::CPUPlace;
+ using DataLayout = paddle::framework::DataLayout;
+ using LibraryType = paddle::framework::LibraryType;
+
+ OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+
+ ASSERT_EQ(
+ paddle::framework::KernelTypeToString(op_kernel_type),
+ "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
+}
+
+TEST(OpKernelType, Hash) {
+ using OpKernelType = paddle::framework::OpKernelType;
+ using DataType = paddle::framework::proto::DataType;
+ using CPUPlace = paddle::platform::CPUPlace;
+ using CUDAPlace = paddle::platform::CUDAPlace;
+ using DataLayout = paddle::framework::DataLayout;
+ using LibraryType = paddle::framework::LibraryType;
+
+ OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+ OpKernelType op_kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+
+ OpKernelType::Hash hasher;
+ ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
+}
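
Beyond logging, the printable and hashable kernel key is what indexes kernels per (data type, place, layout, library) tuple. As a hedged sketch of such a table, using only members this patch introduces (the map name and demo values are hypothetical):

```cpp
#include <string>
#include <unordered_map>

#include "paddle/framework/op_kernel_type.h"

namespace fw = paddle::framework;

// A toy kernel table keyed the same way OperatorWithKernel::AllOpKernels()
// keys its per-op kernel maps: OpKernelType plus OpKernelType::Hash.
std::unordered_map<fw::OpKernelType, std::string, fw::OpKernelType::Hash>
    kernel_table;

void RegisterDemoKernel() {
  fw::OpKernelType key(fw::proto::DataType::FP32, paddle::platform::CPUPlace(),
                       fw::DataLayout::kNCHW, fw::LibraryType::kCUDNN);
  kernel_table[key] = "demo kernel";  // hypothetical payload
  // The same string form doubles as a variable-name suffix for transformed
  // inputs (see the operator.cc hunk later in this patch).
  std::string suffix = fw::KernelTypeToString(key);
}
```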
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 7f0155b61f44b676825b84667d5bebb798cae8a3..bdaa25918155caca4b64b0ed60aa3f6be03eb12f 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -61,17 +61,6 @@ struct OperatorRegistrar : public Registrar {
class OpRegistry {
public:
-  template <typename OpType, typename ProtoMakerType, typename GradOpType>
-  static void RegisterOp(const std::string& op_type,
-                         const std::string& grad_op_type) {
-    OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
-    reg.info.grad_op_type_ = grad_op_type;
-    // register gradient op
-    if (!grad_op_type.empty()) {
-      OperatorRegistrar<GradOpType> grad_reg(grad_op_type.c_str());
-    }
-  }
-
  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
@@ -90,30 +79,31 @@ struct OpKernelRegistrarFunctor {
  using KERNEL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
- void operator()(const char* op_type) const {
+ void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
- OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
+ OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
+ DataLayout::kAnyLayout, StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
    constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
    OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
        func;
- func(op_type);
+ func(op_type, library_type);
}
};
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
- void operator()(const char* op_type) const {}
+ void operator()(const char* op_type, const char* library_type) const {}
};
// Users can register many kernels in one place. The data types can differ.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar {
public:
- explicit OpKernelRegistrar(const char* op_type) {
+ explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
    OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
- func(op_type);
+ func(op_type, library_type);
}
};
@@ -192,14 +182,15 @@ class OpKernelRegistrar : public Registrar {
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
  static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__>   \
- __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
+ __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
+ #DEVICE_TYPE); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
}
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
- REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__)
+ REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
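
With the extra `library_type` argument threaded through, a registration site now stringifies both the op type and the device token, and the device token is mapped through `StringToLibraryType` into the kernel key. A sketch of a multi-dtype registration under this macro (op and kernel names are hypothetical, mirroring the test file below):

```cpp
// Hypothetical kernel templated on device context and element type, in the
// same shape as the OpKernelTest class used by op_registry_test.cc below.
// Each listed instantiation gets its own OpKernelType key; "CPU" is passed
// through as the library_type string.
REGISTER_OP_CPU_KERNEL(
    my_op, MyOpKernel<paddle::platform::CPUDeviceContext, float>,
    MyOpKernel<paddle::platform::CPUDeviceContext, double>);
```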
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 4cdf6e0865e0922b72bd184172f85a9c705dcd00..cef530c6e639f6e2188869fa57d114ec6b885aa8 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -1,3 +1,17 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
@@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar reg("cos");
}
+
+namespace paddle {
+namespace framework {
+
+class OpKernelTestMaker : public OpProtoAndCheckerMaker {
+ public:
+ OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddComment("NoGradOp, same input output. no Grad");
+ }
+};
+
+class OpWithKernelTest : public OperatorWithKernel {
+ public:
+ using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(InferShapeContext* ctx) const override {}
+
+ framework::OpKernelType GetActualKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+ }
+};
+
+template <typename DeviceContext, typename T>
+class OpKernelTest : public paddle::framework::OpKernel<T> {
+ public:
+ void Compute(const paddle::framework::ExecutionContext& ctx) const {}
+};
+
+} // namespace framework
+} // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
+ paddle::framework::OpWithKernelTest,
+ paddle::framework::OpKernelTestMaker);
+REGISTER_OP_CPU_KERNEL(
+ op_with_kernel,
+    paddle::framework::OpKernelTest<paddle::platform::CPUDeviceContext, float>);
+
+REGISTER_OP_CUDA_KERNEL(op_with_kernel,
+ paddle::framework::OpKernelTest<
+ paddle::platform::CUDADeviceContext, float>);
+
+TEST(OperatorRegistrar, CPU) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CPUPlace cpu_place;
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cpu_place);
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(OperatorRegistrar, CUDA) {
+ paddle::framework::proto::OpDesc op_desc;
+ paddle::platform::CUDAPlace cuda_place(0);
+ paddle::framework::Scope scope;
+
+ op_desc.set_type("op_with_kernel");
+ auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+
+ op->Run(scope, cuda_place);
+}
+#endif
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 06184f6ba968c438f6baa571d7a5c12a69109c84..fc7091f1c89f8b3f998f6d1b68f032b76bad2197 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@@ -242,13 +243,6 @@ std::vector<Tensor*> ExecutionContext::MultiOutput(
return res;
}
-std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
- os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
- << kernel_key.data_layout_ << "]:place[" << kernel_key.place_
- << "]:library_type[" << kernel_key.library_type_ << "]";
- return os;
-}
-
bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
@@ -390,12 +384,30 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
+const platform::DeviceContext* GetDeviceContext(
+ framework::KernelTypePair& kernel_pair) {
+ auto& actual_kernel_key = kernel_pair.first;
+ auto& expected_kernel_key = kernel_pair.second;
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+ if (platform::is_gpu_place(actual_kernel_key.place_) &&
+ platform::is_cpu_place(expected_kernel_key.place_)) {
+ return pool.Get(actual_kernel_key.place_);
+ } else if (platform::is_cpu_place(actual_kernel_key.place_) &&
+ platform::is_gpu_place(expected_kernel_key.place_)) {
+ return pool.Get(expected_kernel_key.place_);
+ } else {
+ PADDLE_THROW(
+ "Currently, model parallelism is only supported between CPU and CUDA");
+ }
+}
+
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
- platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
- auto dev_ctx = pool.Borrow(place);
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
@@ -409,19 +421,69 @@ void OperatorWithKernel::Run(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx);
- auto kernel_key = GetKernelType(ctx);
- auto kernel_iter = kernels.find(kernel_key);
+ auto actual_kernel_key = GetActualKernelType(ctx);
+ auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
+ auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) {
- PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
+ PADDLE_THROW("The operator %s does not support %s", type_,
+ expected_kernel_key);
+ }
+
+ if (actual_kernel_key == expected_kernel_key) {
+ PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+ "Currently, model parallelism is only supported between "
+ "CPU and other devices. For example, multi-GPU model "
+ "parallelism will failed.");
+ } else {
+ auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
+ const DataTransformFn* trans_fun =
+ DataTransformFnMap::Instance().GetNullable(kernel_pair);
+ if (trans_fun) {
+ auto input_vars = this->InputVars();
+ // TODO(qijun) filter the input vars that do not need to be transformed
+
+ // filter vars that has been transformed
+      std::vector<std::string> need_trans;
+ for (auto var_name : input_vars) {
+ auto var_name_trans =
+ var_name + framework::KernelTypeToString(expected_kernel_key);
+ if (!scope.FindVar(var_name_trans)) {
+          const_cast<Scope&>(scope).Var(var_name_trans);
+ need_trans.push_back(var_name);
+ }
+ }
+
+ if (!need_trans.empty()) {
+ auto trans_dev_ctx = GetDeviceContext(kernel_pair);
+
+ // Wait for transform starting
+ dev_ctx->Wait();
+
+ for (auto var_name : need_trans) {
+ (*trans_fun)(trans_dev_ctx, kernel_pair, *(scope.FindVar(var_name)),
+ scope.FindVar(var_name + framework::KernelTypeToString(
+ expected_kernel_key)));
+ }
+ // Wait for data transform finishing
+ trans_dev_ctx->Wait();
+ }
+ }
}
kernel_iter->second->Compute(ctx);
}
-OpKernelType OperatorWithKernel::GetKernelType(
+
+OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
+
+OpKernelType OperatorWithKernel::GetExpectedKernelType(
+ const OpKernelType& actual_kernel_type) const {
+ return actual_kernel_type;
+}
+
proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index aba34c5bcb81c85db21e9d82894fc0b937c3c060..d0a9b643d565d6651fd7ec0b515f088362852ba3 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -52,6 +52,11 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
+// define some kernel hints
+const std::string kUseCPU = "use_cpu";
+const std::string kUseCUDNN = "use_cudnn";
+const std::string kUseMKLDNN = "use_mkldnn";
+
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
@@ -84,6 +89,9 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
+ // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
+ virtual void Stop() {}
+
virtual bool IsNetOp() const { return false; }
virtual bool SupportGPU() const { return false; }
@@ -373,7 +381,9 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
- virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
+ virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
+ virtual OpKernelType GetExpectedKernelType(
+ const OpKernelType& actual_kernel_type) const;
private:
  // indicate kernel DataType by input data. By default all input data must be
@@ -381,8 +391,6 @@ class OperatorWithKernel : public OperatorBase {
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
-std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
-
extern bool OpSupportGPU(const std::string& op_type);
} // namespace framework
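
A derived operator overrides `GetActualKernelType` to describe the kernel its inputs naturally select, and may additionally override `GetExpectedKernelType` to redirect execution; the framework then inserts the data transform between the two keys at `Run()` time. A hedged sketch with a hypothetical op, mirroring the test ops in this patch:

```cpp
class DemoOp : public paddle::framework::OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(paddle::framework::InferShapeContext* ctx) const override {}

  paddle::framework::OpKernelType GetActualKernelType(
      const paddle::framework::ExecutionContext& ctx) const override {
    // The kernel the inputs naturally select: FP32 on whatever place
    // Run() was given.
    return paddle::framework::OpKernelType(
        paddle::framework::proto::DataType::FP32, ctx.GetPlace());
  }
  // GetExpectedKernelType is left at its default (identity), so no data
  // transform is triggered for this op.
};
```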
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index fbca45b59dc5446e93e79599f471d80a06ea3661..4d38a7ada91af834aa1a19b49e36d606ebe786ba 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
- OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
+ OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc
index a49886f7ea56bc57459202dba65e3f76a902cd70..59947c9f2189348226b7ff6c2b9315196bbf55fa 100644
--- a/paddle/framework/program_desc_test.cc
+++ b/paddle/framework/program_desc_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "gtest/gtest.h"
diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc
index bdd57659432ea4f9bdd05425a802110b0c202fb8..d76c5abca94cb87220ce73537a8657c3ec695f4d 100644
--- a/paddle/framework/prune_test.cc
+++ b/paddle/framework/prune_test.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/prune.h"
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 656736e23846c8de50553a608c54a0bdd3272cb1..0c01d605bcd95f5796fba1e5a3351a2640b2898a 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear();
}
-std::vector<std::string> Scope::GetAllNames(bool recursive) const {
-  std::vector<std::string> known_vars(vars_.size());
-
- if (recursive) {
- for (auto& kid : kids_) {
- auto kid_vars = kid->GetAllNames();
- for (auto& p : kid_vars) {
- known_vars.emplace_back(p);
- }
- }
- }
+std::vector<std::string> Scope::LocalVarNames() const {
+  std::vector<std::string> known_vars;
+ known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}
diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index 56e815db54b6385c4e4d87f456ed5d59113ca77b..10143326dfa201894c777b3e5e226d5ca5015eda 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -66,7 +66,7 @@ class Scope {
void DropKids();
  // enumerate all the variables this scope currently contains.
-  std::vector<std::string> GetAllNames(bool recursive = false) const;
+  std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name
void Rename(const std::string& origin_name,
diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc
index f738d5ba9ecda57ea25bb5f84057d1d0106eef66..0f5b86061dbdebde08badca7f984f4a2c8d7bc79 100644
--- a/paddle/framework/scope_test.cc
+++ b/paddle/framework/scope_test.cc
@@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v));
-  std::vector<std::string> ans = s.GetAllNames();
+  std::vector<std::string> ans = s.LocalVarNames();
std::string str;
for (auto& var : ans) {
str += var;
diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc
index c74459c9dd7006a24615b1d6df041583088fb25c..82adfa7123a3cf40d929021602c45fe7d2e34ffa 100644
--- a/paddle/framework/selected_rows.cc
+++ b/paddle/framework/selected_rows.cc
@@ -12,5 +12,58 @@ limitations under the License. */
#include "paddle/framework/selected_rows.h"
namespace paddle {
-namespace framework {} // namespace framework
+namespace framework {
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx) {
+ { // the 1st field, uint32_t version
+ constexpr uint32_t version = 0;
+ os.write(reinterpret_cast(&version), sizeof(version));
+ }
+ {
+ // the 2st field, rows information
+ auto& rows = selected_rows.rows();
+ uint64_t size = rows.size();
+ os.write(reinterpret_cast(&size), sizeof(size));
+ for (uint64_t i = 0; i < size; ++i) {
+ os.write(reinterpret_cast(&rows[i]), sizeof(rows[i]));
+ }
+ }
+ {
+ // the 3st field, the height of SelectedRows
+ int64_t height = selected_rows.height();
+ os.write(reinterpret_cast(&height), sizeof(height));
+ }
+ // the 4st field, Tensor data
+ SerializeToStream(os, selected_rows.value(), dev_ctx);
+}
+
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
+ auto tensor = *selected_rows->mutable_value();
+ {
+ // the 1st field, unit32_t version for SelectedRows
+ uint32_t version;
+ is.read(reinterpret_cast(&version), sizeof(version));
+ PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ }
+ {
+ // the 2st field, rows information
+ uint64_t size;
+ is.read(reinterpret_cast(&size), sizeof(size));
+ auto& rows = *selected_rows->mutable_rows();
+ rows.resize(size);
+ for (uint64_t i = 0; i < size; ++i) {
+ is.read(reinterpret_cast(&rows[i]), sizeof(int64_t));
+ }
+ }
+ {
+ // the 3st field, the height of the SelectedRows
+ int64_t height;
+ is.read(reinterpret_cast(&height), sizeof(int64_t));
+ selected_rows->set_height(height);
+ }
+ // the 4st field, tensor which contains the data
+ DeserializeFromStream(is, &tensor);
+}
+
+} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 0332b91323e3a4b4b80e02302ad3dcafe0986cde..699e392688e9889f050592172f8bfc45f855d0b1 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -59,5 +59,14 @@ class SelectedRows {
int64_t height_;
};
+/*
+ * Serialize/Deserialize SelectedRows to/from std::ostream/std::istream.
+ * You can pass an ofstream or ostringstream to serialize to a file or
+ * to an in-memory string. GPU tensors will be copied to CPU first.
+ */
+void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
+ const platform::DeviceContext& dev_ctx);
+void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
+
} // namespace framework
} // namespace paddle
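
For readers of the stream format, the layout written by `SerializeToStream` above is, in order: a version tag, the rows, the height, and the value tensor. A comment-only sketch of the version-0 byte layout (field names are descriptive, not from the source):

```cpp
// Version-0 stream layout for SelectedRows, as written above:
//   uint32_t version;            // currently always 0
//   uint64_t row_count;          // rows().size()
//   int64_t  rows[row_count];    // one index per selected row
//   int64_t  height;             // height()
//   ...                          // value() via the Tensor serializer
```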
diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc
index 4ee13a65d72e44693573397bb686b355effb2227..75487c4010391aa9e519d73058184fa936dabb84 100644
--- a/paddle/framework/selected_rows_test.cc
+++ b/paddle/framework/selected_rows_test.cc
@@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100}));
}
+TEST_F(SelectedRowsTester, SerializeAndDeserialize) {
+ SelectedRows dst_tensor;
+ platform::CPUDeviceContext cpu_ctx(place_);
+ std::ostringstream oss;
+
+ SerializeToStream(oss, *selected_rows_, cpu_ctx);
+
+ std::istringstream iss(oss.str());
+ DeserializeFromStream(iss, &dst_tensor);
+
+ ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
+ ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc
index 86dc01665bda5e7f988e60780c0600b049d737ef..e53cc0cdabc623ae358f1a3e21823a2f38ec3c62 100644
--- a/paddle/framework/shape_inference.cc
+++ b/paddle/framework/shape_inference.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/shape_inference.h"
#include "grad_op_desc_maker.h"
#include "paddle/framework/operator.h"
diff --git a/paddle/framework/tensor.cc b/paddle/framework/tensor.cc
index ea7b2a1f7b17d9abc2c2e14de5ecd1cf4a1a5027..f922e606249849e621e679f71d6dbe0f007c8464 100644
--- a/paddle/framework/tensor.cc
+++ b/paddle/framework/tensor.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include "paddle/framework/tensor.h"
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 6a0c5133c9a6bb326ca51755242e75b6eb9e5474..341a6949beeb2dfa64b23d2079bd8f48750a94f8 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -20,12 +20,12 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
-#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
@@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const;
+ inline DataLayout layout() const { return layout_; }
+
+ inline void set_layout(const DataLayout layout) { layout_ = layout; }
+
private:
friend class LoDTensor;
@@ -173,6 +177,19 @@ class Tensor {
DDim dims_;
+  /**
+   * @brief the layout of the memory block, default is NHWC.
+   *
+   * @note the memory allocation order, i.e. how weights/data are stored.
+   *       For example, a 4-D Tensor (rank = 4) has three commonly used
+   *       layouts: NCHW, NHWC and CHWN, where N, C, H and W stand for the
+   *       batch size, the number of feature maps, the height and the
+   *       width respectively.
+   */
+
+ DataLayout layout_ = DataLayout::kNHWC;
+
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
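
A short sketch of the new layout accessors; note the tensor_impl.h hunk below also makes `Slice` propagate the layout to the sliced view (function and variable names here are hypothetical):

```cpp
#include "paddle/framework/tensor.h"

void LayoutDemo() {
  paddle::framework::Tensor t;
  t.mutable_data<float>(paddle::framework::make_ddim({2, 3, 4, 5}),
                        paddle::platform::CPUPlace());
  t.set_layout(paddle::framework::DataLayout::kNCHW);

  // After the tensor_impl.h change, Slice copies layout_ to the view.
  paddle::framework::Tensor s = t.Slice(0, 1);
  bool same_layout = (s.layout() == t.layout());  // true
  (void)same_layout;
}
```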
diff --git a/paddle/framework/tensor.md b/paddle/framework/tensor.md
index 7a80816d8e4ffa3a9462f3d9b87eff0f048466aa..0a27ac9bb6b03649d42e12100fda9e80a56e7f56 100644
--- a/paddle/framework/tensor.md
+++ b/paddle/framework/tensor.md
@@ -71,7 +71,7 @@ private:
```
```c++
-typedef boost::variant<GPUPlace, CPUPlace> Place;
+typedef boost::variant<CUDAPlace, CPUPlace> Place;
typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> DDimVar;
typedef boost::variant<
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index aba1f9f09329f890ef190f8820b958c56f017e89..6c6f298edc187a87677089e54c4c9046821282df 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -125,11 +125,11 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
        boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
- PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+ PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
}
#else
-    holder_.reset(new PlaceholderImpl<platform::GPUPlace>(
-        boost::get<platform::GPUPlace>(place), size, type));
+    holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
+        boost::get<platform::CUDAPlace>(place), size, type));
}
#endif
offset_ = 0;
@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
+ dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index ceca64365a1a628642eb374a3e3bbdff490c955a..a1b4a03289eca4c8b9d8c23ede4221853cb31f79 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -15,12 +15,13 @@
#include <gtest/gtest.h>
#include <string>
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
TEST(Tensor, Dims) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor tt;
+ framework::Tensor tt;
tt.Resize({2, 3, 4});
- DDim dims = tt.dims();
+ framework::DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
@@ -28,12 +29,12 @@ TEST(Tensor, Dims) {
}
TEST(Tensor, DataAssert) {
- paddle::framework::Tensor src_tensor;
+ framework::Tensor src_tensor;
bool caught = false;
try {
    src_tensor.data<double>();
- } catch (paddle::platform::EnforceNotMet err) {
+ } catch (platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTensor holds no memory. Call "
@@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) {
   because Memory::Alloc() and Memory::Free() are not ready yet.
*/
TEST(Tensor, MutableData) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
-    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+                                        platform::CPUPlace());
     EXPECT_NE(p1, nullptr);
     // set src_tensor a new dim with large size
     // memory is supposed to be re-allocated
-    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+                                        platform::CPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
     // set src_tensor a new dim with same size
     // memory block is supposed to be unchanged
-    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+                                        platform::CPUPlace());
     EXPECT_EQ(p1, p2);
     // set src_tensor a new dim with smaller size
     // memory block is supposed to be unchanged
-    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+                                        platform::CPUPlace());
EXPECT_EQ(p1, p2);
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
+ framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
-    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+                                        platform::CUDAPlace());
     EXPECT_NE(p1, nullptr);
     // set src_tensor a new dim with large size
     // memory is supposed to be re-allocated
-    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
+                                        platform::CUDAPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
     // set src_tensor a new dim with same size
     // memory block is supposed to be unchanged
-    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+                                        platform::CUDAPlace());
     EXPECT_EQ(p1, p2);
     // set src_tensor a new dim with smaller size
     // memory block is supposed to be unchanged
-    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+                                        platform::CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
}
TEST(Tensor, ShareDataWith) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
- Tensor dst_tensor;
+ framework::Tensor src_tensor;
+ framework::Tensor dst_tensor;
    // Try to share data from an uninitialized tensor
bool caught = false;
try {
@@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) {
}
ASSERT_TRUE(caught);
-  src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
+  src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+                               platform::CPUPlace());
   dst_tensor.ShareDataWith(src_tensor);
   ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
- Tensor dst_tensor;
-    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
+    framework::Tensor src_tensor;
+    framework::Tensor dst_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+                                 platform::CUDAPlace());
     dst_tensor.ShareDataWith(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
@@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) {
}
TEST(Tensor, Slice) {
- using namespace paddle::framework;
- using namespace paddle::platform;
{
- Tensor src_tensor;
-    src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
- Tensor slice_tensor = src_tensor.Slice(1, 3);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
+ platform::CPUPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
@@ -153,11 +159,12 @@ TEST(Tensor, Slice) {
     uintptr_t src_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.data<int>());
     uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
-        src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+        src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
     uintptr_t slice_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
-    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
-        slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+    uintptr_t slice_mutable_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<int>(
+            slice_tensor.dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
@@ -165,22 +172,25 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
- Tensor src_tensor;
-    src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
- Tensor slice_tensor = src_tensor.Slice(2, 6);
- DDim slice_dims = slice_tensor.dims();
+ framework::Tensor src_tensor;
+    src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
+ platform::CUDAPlace());
+ framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
+ framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
     uintptr_t src_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.data<double>());
-    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
-        src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
+    uintptr_t src_mutable_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+            src_tensor.dims(), platform::CUDAPlace()));
     uintptr_t slice_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
-    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
-        slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
+    uintptr_t slice_mutable_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
+            slice_tensor.dims(), platform::CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@@ -189,14 +199,19 @@ TEST(Tensor, Slice) {
}
TEST(Tensor, ReshapeToMatrix) {
- using namespace paddle::framework;
- using namespace paddle::platform;
- Tensor src;
-  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+  framework::Tensor src;
+  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
- Tensor res = ReshapeToMatrix(src, 2);
+ framework::Tensor res = framework::ReshapeToMatrix(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
+
+TEST(Tensor, Layout) {
+ framework::Tensor src;
+ ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
+ src.set_layout(framework::DataLayout::kAnyLayout);
+ ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
+}
diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7efc649d0bcda67c663d148e83bcbb6789b0f371
--- /dev/null
+++ b/paddle/framework/tensor_util.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/tensor_util.h"
+
+namespace paddle {
+namespace framework {
+template <typename Predicate, typename DevCtx>
+struct AnyDTypeVisitor {
+ Predicate predicate_;
+ const Tensor& tensor_;
+ const DevCtx& ctx_;
+ Tensor* out_;
+
+ AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
+ Tensor* out)
+ : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
+
+  template <typename T>
+  void operator()() const {
+    auto t = EigenVector<T>::Flatten(tensor_);
+ auto o = EigenScalar