diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index af4079875a50ffe6eb627492f834fb601bbee716..ed5f6310f4e1212844948dc8c2555e527b4d10e8 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -32,10 +32,12 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(scope SRCS scope.cc DEPS glog threadpool) cc_test(scope_test SRCS scope_test.cc DEPS scope) -cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor) +cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor) +cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor) +cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function) -cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) +cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor + framework_proto selected_rows data_device_transform data_type_transform data_layout_transform) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc @@ -80,5 +82,5 @@ cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) -nv_test(device_data_transform_test SRCS device_data_transform_test.cu +nv_test(data_device_transform_test SRCS data_device_transform_test.cu DEPS operator op_registry init math_function) diff --git a/paddle/framework/device_data_transform.cc b/paddle/framework/data_device_transform.cc similarity index 96% rename from paddle/framework/device_data_transform.cc rename to paddle/framework/data_device_transform.cc index cd5104cc6f287315ed9d22aa2ec6414f7204d214..b3fd48ae12c368ac7d83c4f3b6e2fb1939932ac0 100644 --- a/paddle/framework/device_data_transform.cc +++ b/paddle/framework/data_device_transform.cc @@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/device_data_transform.h" +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { diff --git a/paddle/framework/device_data_transform.h b/paddle/framework/data_device_transform.h similarity index 100% rename from paddle/framework/device_data_transform.h rename to paddle/framework/data_device_transform.h diff --git a/paddle/framework/device_data_transform_test.cu b/paddle/framework/data_device_transform_test.cu similarity index 100% rename from paddle/framework/device_data_transform_test.cu rename to paddle/framework/data_device_transform_test.cu diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h index 4a8669c3a41fceaad26878a79eabfd0affce86fd..3ab976ecac4dfb0571ebf5dc93f726939da01116 100644 --- a/paddle/framework/data_layout.h +++ b/paddle/framework/data_layout.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/platform/enforce.h" #include #include "paddle/platform/enforce.h" diff --git a/paddle/framework/data_layout_transform.cc b/paddle/framework/data_layout_transform.cc new file mode 100644 index 0000000000000000000000000000000000000000..96794cae97d460e86fe83ac1395e1dfc7e371e3b --- /dev/null +++ b/paddle/framework/data_layout_transform.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_layout_transform.h" + +#include "paddle/framework/tensor.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace framework { + +struct CastDataLayout { + CastDataLayout(const platform::DeviceContext* ctx, + const std::vector& axis, const framework::Tensor& in, + framework::Tensor* out) + : in_(in), out_(out), ctx_(ctx), axis_(axis) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + const std::vector axis_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + if (platform::is_cpu_place(place)) { + operators::math::Transpose trans4; + auto* context = static_cast(ctx_); + trans4(*context, in_, out_, axis_); + } else { + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataLayout only support DataLayout transform on same place!"); + PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, + "TransDataLayout only support Datatype are same!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); + + auto src_dim = src.dims(); + std::vector dst_dim; + + dst_dim.resize(axis.size()); + for (size_t i = 0; i < axis.size(); i++) { + dst_dim[i] = src_dim[axis[i]]; + } + + dst->Resize(make_ddim(dst_dim)); + auto place = kernel_pair.second.place_; + dst->mutable_data(place, src.type()); + + auto src_type = kernel_pair.first.data_type_; + framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); + + dst->set_layout(kernel_pair.second.data_layout_); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_layout_transform.h b/paddle/framework/data_layout_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..befae1f63616a4c21d998c6b784b8ef288d00617 --- /dev/null +++ b/paddle/framework/data_layout_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index fed958db1584c4fda5394d59a2ef8936045a9ce9..e56edb95396ef8de44da95ce795161d7cf1debc6 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -11,22 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" -#include "paddle/framework/lod_tensor.h" -#include "paddle/framework/selected_rows.h" -#include "paddle/platform/device_context.h" + +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { -DataTransformFnMap& DataTransformFnMap::Instance() { - static DataTransformFnMap data_transform_map; - return data_transform_map; -} - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor) { @@ -58,134 +50,5 @@ void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, } } -auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNCHW, LibraryType::kPlain); - -// TODO(dzhwinter): Only for testing multiple op kernel. -// Dummy transform function for library_type -// should be removed. -auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kPlain); - -auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kCUDNN); - -void DummyTrans(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - auto src = in.Get(); - auto* dst = out->GetMutable(); - *dst = src; -} - -void TransDataType(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - - auto dims = src.dims(); - dst->Resize(dims); - auto dst_type = kernel_pair.second.data_type_; - auto src_type = kernel_pair.first.data_type_; - - switch (src_type) { - case proto::DataType::FP32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::FP64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::BOOL: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - default: - PADDLE_THROW("Not support type %d", src_type); - } -} - -void TransDataLayout(const std::vector& axis, - const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataLayout only support DataLayout transform on same place!"); - PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, - "TransDataLayout only support Datatype are same!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); - - auto src_dim = src.dims(); - std::vector dst_dim; - - dst_dim.resize(axis.size()); - for (size_t i = 0; i < axis.size(); i++) { - dst_dim[i] = src_dim[axis[i]]; - } - - dst->Resize(make_ddim(dst_dim)); - auto place = kernel_pair.second.place_; - dst->mutable_data(place, src.type()); - - auto src_type = kernel_pair.first.data_type_; - framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); - - dst->set_layout(kernel_pair.second.data_layout_); -} - } // namespace framework } // namespace paddle - -namespace f = paddle::framework; - -namespace { -std::vector NHWC2NCHW = {0, 3, 1, 2}; -std::vector NCHW2NHWC = {0, 2, 3, 1}; -} - -REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType); -REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, - std::bind(f::TransDataLayout, NHWC2NCHW, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); -REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC, - std::bind(f::TransDataLayout, NCHW2NHWC, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index e4e5c30a96a3c985ae2ecd494b723c8afeceb12f..ee95c7e8564d0392c8f25fce161d0f722c04761a 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -30,26 +30,6 @@ limitations under the License. */ namespace paddle { namespace framework { -using KernelTypePair = std::pair; - -using DataTransformFn = - std::function; - -struct KernelTypePairHash { - static void HashCombine(const OpKernelType& t, std::size_t* seed) { - OpKernelType::Hash kernel_type_hasher; - (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - } - - size_t operator()(const KernelTypePair& kernel_pair) const { - std::size_t seed = 0; - HashCombine(kernel_pair.first, &seed); - HashCombine(kernel_pair.second, &seed); - return seed; - } -}; - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor); @@ -57,125 +37,5 @@ Tensor* DataTransform(const OpKernelType& expected_kernel_type, void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, Variable& out_var); -template -struct CastDataTypeFunctor { - HOSTDEVICE inline OutType operator()(InType in) const { - return static_cast(in); - } -}; - -template -struct CastDataType { - CastDataType(const framework::Tensor& in, framework::Tensor* out, - const platform::DeviceContext* ctx) - : in_(in), out_(out), ctx_(ctx) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - auto* in_begin = in_.data(); - auto numel = in_.numel(); - auto* in_end = in_begin + numel; - auto* out_begin = out_->mutable_data(place); - - if (platform::is_cpu_place(place)) { - platform::Transform trans; - auto* context = static_cast(ctx_); - trans(*context, in_begin, in_end, out_begin, - CastDataTypeFunctor()); - } else { - // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -struct CastDataLayout { - CastDataLayout(const platform::DeviceContext* ctx, - const std::vector& axis, const framework::Tensor& in, - framework::Tensor* out) - : in_(in), out_(out), ctx_(ctx), axis_(axis) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - const std::vector axis_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - if (platform::is_cpu_place(place)) { - operators::math::Transpose trans4; - auto* context = static_cast(ctx_); - trans4(*context, in_, out_, axis_); - } else { - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -using DataTransformMap = - std::unordered_map; - -class DataTransformFnMap { - public: - static DataTransformFnMap& Instance(); - - bool Has(const KernelTypePair& key_pair) const { - return map_.find(key_pair) != map_.end(); - } - - void Insert(const OpKernelType& left, const OpKernelType& right, - const DataTransformFn& data_tranform_fn) { - Insert(std::make_pair(left, right), data_tranform_fn); - } - - void Insert(const KernelTypePair& kernel_type_pair, - const DataTransformFn& data_tranform_fn) { - PADDLE_ENFORCE(!Has(kernel_type_pair), - "KernelTypePair %s has been registered", ""); - map_.insert({kernel_type_pair, data_tranform_fn}); - } - - const DataTransformFn& Get(const KernelTypePair& key_pair) const { - auto data_transformer = GetNullable(key_pair); - PADDLE_ENFORCE_NOT_NULL(data_transformer, - "DataTransformFn should not be NULL"); - return *data_transformer; - } - - const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const { - auto it = map_.find(key_pair); - if (it == map_.end()) { - return nullptr; - } else { - return &(it->second); - } - } - - const DataTransformMap& Map() const { return map_; } - - private: - DataTransformFnMap() = default; - DataTransformMap map_; - DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); -}; - -// generate unique name with __LINE__ -// refs https://stackoverflow.com/questions/1597007 -#define TOKENPASTE(x, y) x##y -#define TOKENPASTE2(x, y) TOKENPASTE(x, y) -#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \ - static int TOKENPASTE2(fn_, __LINE__)() { \ - ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ - return 0; \ - } \ - static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \ - TOKENPASTE2(fn_, __LINE__)() - } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc deleted file mode 100644 index edd305fd17ae202926b83fbec10089719baa2e16..0000000000000000000000000000000000000000 --- a/paddle/framework/data_transform_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include - -#include - -#include "paddle/framework/data_transform.h" -#include "paddle/platform/device_context.h" - -namespace paddle { -namespace framework { -using namespace platform; - -/** - * @brief cross validation of different kernel type transform - * We use four bit map represent different combination. - * If the field has multiple possible value, only choose two of them. - * For DataType, only test the FP32(float), FP64(double). - * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain - * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN - */ - -std::array kDataType = { - {proto::DataType::FP32, proto::DataType::FP64}}; - -std::array kPlace = {{CPUPlace(), CUDAPlace(0)}}; - -std::array kDataLayout = {{ - DataLayout::kNHWC, DataLayout::kNCHW, -}}; - -std::array kLibraryType = {{ - LibraryType::kPlain, LibraryType::kMKLDNN, -}}; - -OpKernelType GenFromBit(const std::vector bits) { - return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], - kLibraryType[bits[3]]); -} - -int test_value = 0; - -auto kernel0 = GenFromBit({0, 0, 0, 0}); -auto kernel1 = GenFromBit({0, 0, 0, 1}); -auto kernel2 = GenFromBit({0, 0, 1, 0}); -auto kernel3 = GenFromBit({0, 0, 1, 1}); - -void TransDataType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value++; -} - -void TransDataLayout_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value--; -} - -void TransLibraryType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value += 2; -} - -} // namespace framework -} // namespace paddle - -namespace frw = paddle::framework; - -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t); - -TEST(DataTransform, Register) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - paddle::framework::Variable in; - paddle::framework::Variable out; - - DeviceContext* ctx = new CPUDeviceContext(); - auto pair0 = std::make_pair(frw::kernel0, frw::kernel1); - instance.Get(pair0)(ctx, pair0, in, &out); - ASSERT_EQ(test_value, 1); - - auto pair1 = std::make_pair(frw::kernel1, frw::kernel2); - instance.Get(pair1)(ctx, pair1, in, &out); - ASSERT_EQ(test_value, 0); - - auto pair3 = std::make_pair(frw::kernel0, frw::kernel2); - instance.Get(pair3)(ctx, pair3, in, &out); - ASSERT_EQ(test_value, 2); -} - -TEST(DataTransform, DataLayout) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - src->mutable_data(make_ddim({2, 3, 1, 2}), CPUPlace()); - src->set_layout(DataLayout::kNHWC); - - DeviceContext* ctx = new CPUDeviceContext(); - - { - auto kernel1 = GenFromBit({1, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 1, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - - Tensor dst = out.Get(); - - EXPECT_TRUE(dst.layout() == DataLayout::kNCHW); - EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1})); - - { - auto kernel1 = GenFromBit({1, 0, 1, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, out, &in); - } - - EXPECT_TRUE(src->layout() == DataLayout::kNHWC); - EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2})); -} - -TEST(DataTransform, DataType) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - DeviceContext* ctx = new CPUDeviceContext(); - - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - float* ptr = src->mutable_data(make_ddim({2, 3}), CPUPlace()); - for (int i = 0; i < 6; ++i) { - ptr[i] = i / 3; - } - - { - auto kernel1 = GenFromBit({0, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - Tensor dst = out.Get(); - EXPECT_TRUE(dst.data() != nullptr); -} diff --git a/paddle/framework/data_type_transform.cc b/paddle/framework/data_type_transform.cc new file mode 100644 index 0000000000000000000000000000000000000000..63373232e910d44eb0996f9280f9c166ad092030 --- /dev/null +++ b/paddle/framework/data_type_transform.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_type_transform.h" + +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace framework { + +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor& in, framework::Tensor* out, + const platform::DeviceContext* ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + auto* in_begin = in_.data(); + auto numel = in_.numel(); + auto* in_end = in_begin + numel; + auto* out_begin = out_->mutable_data(place); + + if (platform::is_cpu_place(place)) { + platform::Transform trans; + auto* context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + } else { + // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataType Only Support DataType transform on same place!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + + auto dims = src.dims(); + dst->Resize(dims); + auto dst_type = kernel_pair.second.data_type_; + auto src_type = kernel_pair.first.data_type_; + + switch (src_type) { + case proto::DataType::FP32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::FP64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::BOOL: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + default: + PADDLE_THROW("Not support type %d", src_type); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_type_transform.h b/paddle/framework/data_type_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..8ec90742256c2308a242d993838e46e51a6fc167 --- /dev/null +++ b/paddle/framework/data_type_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 506fde440533e83f093f26484f925416b89c75a0..7ae94c646537e0d7c4687b949a1b06cd3a7f3404 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -44,9 +44,19 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { } std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { - PADDLE_ENFORCE(platform::is_cpu_place(t.place())); PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code()); + if (!platform::is_cpu_place(t.place())) { + LoDTensor tt; + framework::Copy(t, platform::CPUPlace(), &tt); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(t.place()); + dev_ctx.Wait(); + + os << tt; + return os; + } + os << "dim: " << t.dims() << "\n"; os << "lod: " << t.lod() << "\n"; @@ -211,38 +221,23 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, DeserializeFromStream(is, static_cast(tensor), dev_ctx); } +// TODO(tonyyang-svail): make this function support LoD std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); - // PADDLE_ENFORCE(lod().empty() || (lod().size() == 1 && lod()[0].empty()) - // , "Disable parallel lod for now"); PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now"); PADDLE_ENFORCE(dims()[0] % places.size() == 0, "Batch size should be divided by places size"); std::vector lods; for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) { - size_t begin = place_idx * dims()[0] / places.size(); - size_t end = (place_idx + 1) * dims()[0] / places.size(); - auto src = Slice(static_cast(begin), static_cast(end)); + int begin = place_idx * dims()[0] / places.size(); + int end = (place_idx + 1) * dims()[0] / places.size(); - LoDTensor dst; - dst.Resize(src.dims()); + auto src = Slice(begin, end); auto &dst_place = places[place_idx]; - auto dst_ptr = dst.mutable_data(dst_place, src.type()); - - // TODO(tonyyang-svail): - // change the following to framework::Copy - auto src_place = src.place(); - auto src_ptr = src.data(); - auto size = src.numel() * SizeOfType(src.type()); - if (platform::is_cpu_place(src_place) && - platform::is_cpu_place(dst_place)) { - memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size); - } else { - PADDLE_THROW("Not Implemented"); - } + LoDTensor dst; + framework::Copy(src, dst_place, &dst); lods.emplace_back(dst); } @@ -250,28 +245,30 @@ std::vector LoDTensor::SplitLoDTensor( return lods; } +// TODO(tonyyang-svail): make this function support LoD void LoDTensor::MergeLoDTensor( - const std::vector &lod_tensors, platform::Place place) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + const std::vector &lod_tensors, + platform::Place dst_place) { PADDLE_ENFORCE(!lod_tensors.empty()); - framework::DDim new_dim = lod_tensors[0]->dims(); std::type_index new_type = lod_tensors[0]->type(); + auto new_layout = lod_tensors[0]->layout(); for (auto *lod : lod_tensors) { PADDLE_ENFORCE(new_dim == lod->dims()); PADDLE_ENFORCE(new_type == lod->type()); - PADDLE_ENFORCE(platform::is_cpu_place(lod->place())); + PADDLE_ENFORCE(new_layout == lod->layout()); } new_dim[0] *= lod_tensors.size(); Resize(new_dim); + set_layout(new_layout); - auto *dst_ptr = reinterpret_cast(mutable_data(place, new_type)); + mutable_data(dst_place, new_type); + int begin = 0; for (auto *src : lod_tensors) { - auto size = src->numel() * SizeOfType(src->type()); - memory::Copy(boost::get(place), dst_ptr, - boost::get(src->place()), - src->data(), size); - dst_ptr += size; + int end = begin + src->dims()[0]; + auto dst = Slice(begin, end); + framework::Copy(*src, dst_place, &dst); + begin = end; } } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 35ebe48ba682f135b7f85edb3b2999db7c29e51a..ef2c55cc3799ba2fac54f3c9370505b63ef22ad3 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" #include "paddle/framework/executor.h" #include "paddle/framework/operator.h" #include "paddle/framework/shape_inference.h" diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 7c56ccf17f94e29d06f529629c47f61b93d2bd22..f541d2ba693a169d074c070dd794a2dd4e52aabf 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -31,9 +31,10 @@ namespace framework { * * @note Copy supports CPU <-> GPU, GPU <-> GPU. */ - inline void Copy(const Tensor& src, const platform::Place& dst_place, const platform::DeviceContext& ctx, Tensor* dst) { + VLOG(3) << "Copy " << src.dims() << " from " << src.place() << " to " + << dst_place; src.check_memory_size(); dst->Resize(src.dims()); @@ -88,26 +89,25 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place, } /** - * @brief Copy supports CPU <-> CPU + * @brief Wrapper on + * Copy(const Tensor& src, const platform::Place& dst_place, + * const platform::DeviceContext& ctx, Tensor* dst); + * + * @param[in] src The external tensor. + * @param[in] dst_place The dst place. + * + * @note Copy supports CPU <-> GPU, GPU <-> GPU. */ inline void Copy(const Tensor& src, const platform::Place& dst_place, Tensor* dst) { - src.check_memory_size(); - dst->Resize(src.dims()); - dst->set_layout(src.layout()); - - auto src_place = src.place(); - auto src_ptr = src.data(); - - auto dst_ptr = dst->mutable_data(dst_place, src.type()); - - auto size = src.numel() * SizeOfType(src.type()); - - PADDLE_ENFORCE(platform::is_cpu_place(src_place) && - platform::is_cpu_place(dst_place)); - - memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + const platform::DeviceContext* dev_ctx; + if (platform::is_gpu_place(src.place())) { + dev_ctx = pool.Get(src.place()); + } else { + dev_ctx = pool.Get(dst_place); + } + Copy(src, dst_place, *dev_ctx, dst); } /** diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index aeab18d7214f8d9dd79bc3d2e0322490445b3b49..62ab6593ef23c195e3caa2336574796ecaf35bc8 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -74,7 +74,7 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().tensor(); default: - PADDLE_THROW("The type of var '", this->Name(), "' is unsupported."); + PADDLE_THROW("The type of var %s is unsupported.", this->Name()); } } diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 6fbf3c7fdec2f537769adb660c67c5a597beb609..2d0fff608c3932ede57ddcbb6a85cda255b77246 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -132,6 +132,8 @@ void MKLDNNLayer::reshapeInput(int& batchsize, if (w != 0) { width = w; } + height = height != 0 ? height : 1; + width = width != 0 ? width : 1; } void MKLDNNLayer::reshapeOutput(size_t height, size_t width) { diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index e48b9b5a91f7f17cb3f31e9140f1428ba8954a20..3ba39f18b6ae4b38c17b3b72a359fd514e6dad5b 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -98,6 +98,8 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), + ih_(0), + iw_(0), condition_(0), needResetBwd_(true), outputOnlyMKLDNN_(false), diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e383f07fa9b53a3def10f6405a0d36f48f52ff08 --- /dev/null +++ b/paddle/operators/edit_distance_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/edit_distance_op.h" + +namespace paddle { +namespace operators { + +class EditDistanceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto hyp_dims = ctx->GetInputDim("Hyps"); + auto ref_dims = ctx->GetInputDim("Refs"); + PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, + "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1, + "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(framework::proto::DataType::FP32, + ctx.device_context()); + } +}; + +class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Hyps", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for hypothesis strings."); + AddInput("Refs", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for reference strings."); + AddAttr("normalized", + "(bool, default false) Indicated whether to normalize " + "the edit distance by the length of reference string.") + .SetDefault(false); + AddOutput("Out", + "(2-D Tensor with shape [`batch_size` x 1]) " + "The output edit distances of EditDistance operator."); + AddComment(R"DOC( + +EditDistance operator computes the edit distances between a batch of hypothesis +strings and their references. + +Edit distance, also called Levenshtein distance, measures how dissimilar two strings +are by counting the minimum number of operations to transform one string into anthor. +Here the operations include insertion, deletion, and substitution. For example, +given hypothesis string A = "kitten" and reference B = "sitting", the edit distance +is 3 for A will be transformed into B at least after two substitutions and one +insertion: + + "kitten" -> "sitten" -> "sittin" -> "sitting" + +Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total +number denoted by `batch_size`, and the separation is specified by the LoD information. +And the `batch_size` reference strings are arranged in order in the same way in the +LoDTensor Input(Refs). + +Output(Out) contains the `batch_size` results and each stands for the edit stance +for a pair of strings respectively. If Attr(normalized) is true, the edit distance +will be divided by the length of reference string. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + edit_distance, ops::EditDistanceKernel); diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..cf5ebc5c38fd006d10de790e45e9bff3409bd20c --- /dev/null +++ b/paddle/operators/edit_distance_op.cu @@ -0,0 +1,149 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillFirstRow(T* dist, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < N + 1) { + dist[idx] = idx; + } +} + +template +__global__ void FillFirstColumn(T* dist, const int M, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < M + 1) { + dist[idx * (N + 1)] = idx; + } +} + +template +__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, + const int N, const int start) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = N; + int index = start + idx * offset; + int row = index / (N + 1); + int col = index % (N + 1); + if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { + int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; + int dels = dist[(row - 1) * (N + 1) + col] + 1; + int ins = dist[row * (N + 1) + col - 1] + 1; + int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; + dist[index] = min(dels, min(ins, subs)); + } +} + +template +__global__ void SetOutput(T* out, const T* dist, const int M, const int N, + bool normalized) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx == 0) { + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; + } +} + +template +class EditDistanceGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); + + auto normalized = ctx.Attr("normalized"); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + + auto num_strs = hyp_lod.size() - 1; + out_t->Resize({static_cast(num_strs), 1}); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; num++) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + if (m == 0 || n == 0) { + distance = std::max(m, n); + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance = distance / n; + } + memory::Copy(boost::get(ctx.GetPlace()), out + num, + platform::CPUPlace(), &distance, sizeof(T), stream); + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + + FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); + + FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { + int z_m = slice < m + 1 ? 0 : slice - m; + int z_n = slice < n + 1 ? 0 : slice - n; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + // the start index at which computes from + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, + m, n, start); + } + SetOutput<<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + edit_distance, + ops::EditDistanceGPUKernel); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h new file mode 100644 index 0000000000000000000000000000000000000000..537e70281a5a750db480468a8f8e3c0465de6c5a --- /dev/null +++ b/paddle/operators/edit_distance_op.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class EditDistanceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); + + auto normalized = ctx.Attr("normalized"); + + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + auto num_strs = hyp_lod.size() - 1; + + out_t->Resize({static_cast(num_strs), 1}); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; ++num) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + + if (m == 0) { + distance = n; + } else if (n == 0) { + distance = m; + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + for (int64_t i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; + } + for (int64_t j = 0; j < n + 1; ++j) { + dist[j] = j; + } + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { + int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; + int dels = dist[(i - 1) * (n + 1) + j] + 1; + int ins = dist[i * (n + 1) + (j - 1)] + 1; + int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); + } + } + distance = dist[m * (n + 1) + n]; + } + + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance = distance / n; + } + out[num] = distance; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc index 291bbbcb3a736a0b0b7ac9cff211896ef1e7a49c..24fafb23074c358669715f1840246c3520f948c7 100644 --- a/paddle/operators/get_places_op.cc +++ b/paddle/operators/get_places_op.cc @@ -39,17 +39,19 @@ class GetPlacesOp : public framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void Run(const framework::Scope &scope, const platform::Place &place) const override { - std::string device_type = Attr("device_type"); + bool is_gpu; + if (Attr("device_type") == "AUTO") { + is_gpu = platform::is_gpu_place(place); + } else { + is_gpu = Attr("device_type") == "CUDA"; + } auto device_count = static_cast(Attr("device_count")); if (device_count == 0) { - if (device_type == "CUDA") { - device_count = CUDADevCount(); - } else if (device_type == "CPU") { - device_count = std::thread::hardware_concurrency(); - } + device_count = + is_gpu ? CUDADevCount() : std::thread::hardware_concurrency(); } PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count", - device_type); + is_gpu ? "GPU" : "CPU"); auto out_var_name = Output("Out"); auto &places = @@ -57,14 +59,14 @@ class GetPlacesOp : public framework::OperatorBase { "Output variable %s cannot be found", out_var_name) .GetMutable()); places.reserve(device_count); - if (device_type == "CUDA") { + if (is_gpu) { PADDLE_ENFORCE_LE(device_count, CUDADevCount(), "Only %d CUDA devices found, cannot set to %d", CUDADevCount(), device_count); for (size_t i = 0; i < device_count; ++i) { - places.emplace_back(platform::CUDAPlace(i)); + places.emplace_back(platform::CUDAPlace(static_cast(i))); } - } else if (device_type == "CPU") { + } else { for (size_t i = 0; i < device_count; ++i) { places.emplace_back(platform::CPUPlace()); } @@ -77,10 +79,10 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker { GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "vector of Place"); - AddAttr("device_count", "device count").SetDefault(1); - AddAttr("device_type", - R"(device type must be in ["CPU", "CUDA"])") - .InEnum({"CPU", "CUDA"}); + AddAttr("device_count", "device count").SetDefault(0); + AddAttr("device_type", "device type") + .InEnum({"CUDA", "CPU", "AUTO"}) + .SetDefault("AUTO"); AddComment(R"DOC( Returns a list of places based on flags. The list will be used for parallel execution. @@ -111,4 +113,5 @@ class GetPlacesInferShape : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker, - ops::GetPlacesInferVarType, ops::GetPlacesInferShape); + ops::GetPlacesInferVarType, ops::GetPlacesInferShape, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index a6bc70f4c89fb24cef5aefcb69b97fbaa9dc9d9c..e1bec0421e76143bef669a4f6fa373cdf01226b2 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -39,6 +39,7 @@ void SplitTensorAndMoveTensorToScopes( const std::vector &sub_scopes, const std::vector &places, const std::vector &names) { + PADDLE_ENFORCE_EQ(sub_scopes.size(), places.size()); for (auto &argu : names) { auto *var = scope.FindVar(argu); const auto &tensor = var->Get(); @@ -54,6 +55,15 @@ void SplitTensorAndMoveTensorToScopes( } } +void WaitOnPlaces(const std::vector places) { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + + for (auto &place : places) { + auto &dev_ctx = *pool.Get(place); + dev_ctx.Wait(); + } +} + class ParallelDoOp : public framework::OperatorBase { public: ParallelDoOp(const std::string &type, @@ -71,10 +81,7 @@ class ParallelDoOp : public framework::OperatorBase { auto *block = Attr(kParallelBlock); auto *program = block->Program(); - // TODO(tonyyang-svail): get places from input - std::vector places; - places.emplace_back(platform::CPUPlace()); - places.emplace_back(platform::CPUPlace()); + auto &places = scope.FindVar(Input(kPlaces))->Get(); auto &sub_scopes = *scope.FindVar(Output(kParallelScopes)) ->GetMutable>(); @@ -82,8 +89,22 @@ class ParallelDoOp : public framework::OperatorBase { sub_scopes.push_back(&scope.NewScope()); } + // split input SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places, Inputs(kInputs)); + // copy parameter + for (auto ¶m : Inputs(kParameters)) { + PADDLE_ENFORCE(scope.FindVar(param)->IsType(), + "Only support parameter type as LoDTensor"); + auto &src = scope.FindVar(param)->Get(); + for (size_t i = 0; i < places.size(); ++i) { + auto &place = places[i]; + auto *sub_scope = sub_scopes[i]; + auto *dst = sub_scope->Var(param)->GetMutable(); + framework::Copy(src, place, dst); + } + } + WaitOnPlaces(places); std::vector> workers; workers.reserve(places.size()); @@ -93,12 +114,6 @@ class ParallelDoOp : public framework::OperatorBase { auto &place = places[place_idx]; auto *cur_scope = sub_scopes[place_idx]; - // copy parameter - // some version of boost lacks != for boost::variant - if (!(dev_ctx.GetPlace() == place)) { - PADDLE_THROW("Not Implemented"); - } - workers.emplace_back(framework::Async([program, cur_scope, place, block] { framework::Executor executor(place); executor.Run(*program, cur_scope, block->ID(), @@ -108,6 +123,7 @@ class ParallelDoOp : public framework::OperatorBase { for (auto &worker : workers) { worker.wait(); } + WaitOnPlaces(places); // merge output for (auto &o_name : Outputs(kOutputs)) { @@ -121,6 +137,7 @@ class ParallelDoOp : public framework::OperatorBase { scope.FindVar(o_name)->GetMutable(); lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace()); } + WaitOnPlaces(places); } }; @@ -161,15 +178,14 @@ class ParallelDoGradOp : public OperatorBase { auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) ->Get>(); - // TODO(tonyyang-svail): get places from input - std::vector places; - places.emplace_back(platform::CPUPlace()); - places.emplace_back(platform::CPUPlace()); + auto &places = scope.FindVar(Input(kPlaces))->Get(); // feed output@grad SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places, Inputs(framework::GradVarName(kOutputs))); + WaitOnPlaces(places); + // for debugging for (auto &s : Inputs(framework::GradVarName(kOutputs))) { VLOG(3) << s; VLOG(3) << scope.FindVar(s)->Get(); @@ -196,10 +212,11 @@ class ParallelDoGradOp : public OperatorBase { for (auto &worker : workers) { worker.wait(); } + WaitOnPlaces(places); // merge grad for (auto &s : Outputs(framework::GradVarName(kParameters))) { - VLOG(3) << s; + VLOG(3) << "merge grad " << s; auto &t = sub_scopes[0]->FindVar(s)->Get(); VLOG(3) << t; @@ -216,7 +233,8 @@ class ParallelDoGradOp : public OperatorBase { auto sum_op = framework::OpRegistry::CreateOp( "sum", {{"X", {s, s_buf}}}, {{"Out", {s}}}, framework::AttributeMap{}); - sum_op->Run(*sub_scopes[0], place); + sum_op->Run(*sub_scopes[0], places[0]); + WaitOnPlaces(places); } VLOG(3) << t; @@ -236,8 +254,10 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { for (auto &input_param : this->InputNames()) { VLOG(3) << input_param; grad->SetInput(input_param, this->Input(input_param)); - grad->SetOutput(framework::GradVarName(input_param), - this->InputGrad(input_param, false)); + if (input_param != kPlaces) { + grad->SetOutput(framework::GradVarName(input_param), + this->InputGrad(input_param, false)); + } } for (auto &output_param : this->OutputNames()) { diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index 821754a0a632e15643eaeff4133174eb75c9700f..3f5b2a9b84350c7dee5cb461ba6207e20e95c11b 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/lod_rank_table.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/operators/array_operator.h" #include "paddle/operators/math/math_function.h" @@ -46,8 +47,21 @@ class ShrinkRNNMemoryOp : public ArrayOp { auto *out_var = scope.FindVar(Output("Out")); PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); auto &out_tensor = *out_var->GetMutable(); + + size_t height = dst_num_rows; + + // do shrink for the top level LoD + if (x_tensor.lod().size() > 0 && + x_tensor.lod()[0].size() > static_cast(dst_num_rows)) { + auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0, + dst_num_rows, 0); + height = lod_offset.second.second; + auto out_lod = out_tensor.mutable_lod(); + framework::AppendLoD(out_lod, lod_offset.first); + } + if (dst_num_rows != 0) { - out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows)); + out_tensor.ShareDataWith(x_tensor.Slice(0, height)); } } }; @@ -64,11 +78,11 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(LoDTensor) The shrinked RNN step memory."); AddComment( R"DOC( - In dynamic RNN, we are able to handle sequences of different lengths. - Because of the multiple lengths, the size of each step input can be + In dynamic RNN, we are able to handle sequences of different lengths. + Because of the multiple lengths, the size of each step input can be different, which may lead to a mismatching between the input of - the current step and the memory generated by the previous one. This - operator shrinks memory according to the size of the next step input, + the current step and the memory generated by the previous one. This + operator shrinks memory according to the size of the next step input, to make sure that they can match each other. )DOC"); } @@ -132,6 +146,7 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase { PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X"))); context->SetOutputDim(framework::GradVarName("X"), context->GetInputDim("X")); + context->ShareLoD("X", framework::GradVarName("X")); } }; diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index 8a352b0078d701f797f7202c85bd0e08201ac9b8..bb47ad614ed85923ce5d9704760ec6c5b5ae59ee 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -92,6 +92,9 @@ function threads_config() { if [ -z "$OPENBLAS_NUM_THREADS" ]; then export OPENBLAS_NUM_THREADS=$threads fi + if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then + export OPENBLAS_MAIN_FREE=1 + fi fi } diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 0de417df2cb942ce46ff8ac3acc61ae4999ed634..df710c33d0c0ca16d358dac1eb42327e9cd4c7ae 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -62,12 +62,15 @@ __all__ = [ cp.begin_parse() -def set_omp_mkl_env_vars(trainer_count): +def set_env_vars(trainer_count): '''Auto set CPU environment if have not set before. - export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. - export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For MKL: + export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. + export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For OpenBLAS: + export OPENBLAS_NUM_THREADS, OPENBLAS_MAIN_FREE according to trainer_count. ''' - import platform + import platform, paddle if not platform.system() in ['Linux', 'Darwin']: return @@ -103,16 +106,22 @@ def set_omp_mkl_env_vars(trainer_count): num_cores = num_physical_cores() num_processors = num_logical_processors() - if num_processors > num_cores: # Hyper Threading is enabled - set_env("OMP_DYNAMIC", "true") - set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") - else: - set_env("OMP_DYNAMIC", "false") - set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") + if paddle.version.mkl() == 'ON': + if num_processors > num_cores: # Hyper Threading is enabled + set_env("OMP_DYNAMIC", "true") + set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + else: + set_env("OMP_DYNAMIC", "false") + set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") threads = num_processors / trainer_count threads = '1' if threads < 1 else str(threads) - set_env("OMP_NUM_THREADS", threads) - set_env("MKL_NUM_THREADS", threads) + if paddle.version.mkl() == 'ON': + set_env("OMP_NUM_THREADS", threads) + set_env("MKL_NUM_THREADS", threads) + else: + set_env("OPENBLAS_NUM_THREADS", threads) + if threads > 1: + set_env("OPENBLAS_MAIN_FREE", '1') def init(**kwargs): @@ -129,7 +138,7 @@ def init(**kwargs): for key in args_dict.keys(): args.append('--%s=%s' % (key, str(args_dict[key]))) - set_omp_mkl_env_vars(kwargs.get('trainer_count', 1)) + set_env_vars(kwargs.get('trainer_count', 1)) if 'use_gpu' in kwargs: cp.g_command_config_args['use_gpu'] = kwargs['use_gpu'] diff --git a/python/paddle/v2/fluid/layers/device.py b/python/paddle/v2/fluid/layers/device.py index c2355ed802000f4659147b103aee61e023bd847c..775d40e5b5ef0cbb0b62bdc0678f2368b7b1a59a 100644 --- a/python/paddle/v2/fluid/layers/device.py +++ b/python/paddle/v2/fluid/layers/device.py @@ -4,19 +4,22 @@ All util layers. from ..layer_helper import LayerHelper from ..framework import unique_name +from ..registry import autodoc __all__ = ['get_places'] -def get_places(device_count=0, device_type="CPU"): +@autodoc +def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) out_places = helper.create_variable(name=unique_name(helper.name + ".out")) + attrs = dict() + if device_count is not None: + attrs['device_count'] = int(device_count) + if device_type is not None: + attrs['device_type'] = str(device_type) + helper.append_op( - type='get_places', - outputs={"Out": [out_places]}, - attrs={ - "device_type": device_type, - 'device_count': device_count, - }) + type='get_places', outputs={"Out": [out_places]}, attrs=attrs) return out_places diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py new file mode 100644 index 0000000000000000000000000000000000000000..38e87728b387bb70a8921a2fe73a4e69701aabe9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -0,0 +1,94 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def Levenshtein(hyp, ref): + """ Compute the Levenshtein distance between two strings. + + :param hyp: hypothesis string in index + :type hyp: list + :param ref: reference string in index + :type ref: list + """ + m = len(hyp) + n = len(ref) + if m == 0: + return n + if n == 0: + return m + + dist = np.zeros((m + 1, n + 1)).astype("float32") + for i in range(0, m + 1): + dist[i][0] = i + for j in range(0, n + 1): + dist[0][j] = j + + for i in range(1, m + 1): + for j in range(1, n + 1): + cost = 0 if hyp[i - 1] == ref[j - 1] else 1 + deletion = dist[i - 1][j] + 1 + insertion = dist[i][j - 1] + 1 + substitution = dist[i - 1][j - 1] + cost + dist[i][j] = min(deletion, insertion, substitution) + return dist[m][n] + + +class TestEditDistanceOp(OpTest): + def setUp(self): + self.op_type = "edit_distance" + normalized = False + x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 5] + x2_lod = [0, 3, 4] + + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref + self.attrs = {'normalized': normalized} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +class TestEditDistanceOpNormalized(OpTest): + def setUp(self): + self.op_type = "edit_distance" + normalized = True + x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 3, 6] + x2_lod = [0, 2, 3, 5] + + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref + self.attrs = {'normalized': normalized} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py index 2788f4e519b31b45250fbb923b2309e8bb1f6fa1..59ed041e7fa1dd68c0f8d610f2575886442d1b4d 100644 --- a/python/paddle/v2/fluid/tests/test_parallel_op.py +++ b/python/paddle/v2/fluid/tests/test_parallel_op.py @@ -18,7 +18,7 @@ class ParallelOpTest(unittest.TestCase): append_batch_size=False, stop_gradient=False) - places = fluid.default_main_program().global_block().create_var() + places = layers.get_places(device_count=4) pd = layers.ParallelDo(places=places) with pd.do(): diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py index be1588fc2d09fa58882425eb3d080ef1560ebc79..a14721b9aacfa7437623024af41555fd26990499 100644 --- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py +++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py @@ -3,43 +3,86 @@ import paddle.v2.fluid.core as core from paddle.v2.fluid.executor import Executor import paddle.v2.fluid.layers as layers from paddle.v2.fluid.backward import append_backward -from paddle.v2.fluid.framework import default_main_program -import numpy +from paddle.v2.fluid.framework import default_main_program, switch_main_program +from paddle.v2.fluid.framework import Program +import numpy as np -main_program = default_main_program() - -class TestShrinkRNNMemory(unittest.TestCase): - def test_shrink_rnn_memory(self): +class TestShrinkRNNMemoryBase(unittest.TestCase): + def setUp(self): + self.main_program = Program() + switch_main_program(self.main_program) x = layers.data('x', shape=[100], dtype='float32') x.stop_gradient = False - table = layers.lod_rank_table(x=x) + rank_table_tensor = layers.data( + 'rank_table_tensor', shape=[1], dtype='float32', lod_level=1) + table = layers.lod_rank_table(x=rank_table_tensor) i = layers.zeros(dtype='int64', shape=[1]) - mem1 = layers.shrink_memory(x=x, i=i, table=table) + self.mem1 = layers.shrink_memory(x=x, i=i, table=table) i = layers.increment(x=i) i.stop_gradient = True - mem2 = layers.shrink_memory(x=mem1, i=i, table=table) + self.mem2 = layers.shrink_memory(x=self.mem1, i=i, table=table) i = layers.increment(x=i) i.stop_gradient = True - mem3 = layers.shrink_memory(x=mem2, i=i, table=table) + self.mem3 = layers.shrink_memory(x=self.mem2, i=i, table=table) + mem3_mean = layers.mean(x=self.mem3) + append_backward(loss=mem3_mean) + self.x_grad = self.main_program.global_block().var('x@GRAD') + + def sum_lodtensor(self, tensor): + sum_res = 0.0 + for i in xrange(np.product(tensor.get_dims())): + sum_res += tensor.get_float_element(i) + return sum_res + +class TestShrinkRNNMemoryReferLoD(TestShrinkRNNMemoryBase): + def test_refer_lod(self): cpu = core.CPUPlace() - tensor = core.LoDTensor() - tensor.set_lod([[0, 2, 5, 6]]) - tensor_np = numpy.random.random(size=(3, 100)).astype('float32') - tensor.set(tensor_np, cpu) + x_tensor = core.LoDTensor() + x_tensor.set_lod([[0, 2, 5, 6]]) + tensor_np = np.random.random(size=(6, 100)).astype('float32') + x_tensor.set(tensor_np, cpu) + + rank_table_tensor = core.LoDTensor() + rank_table_tensor.set_lod([[0, 1, 3, 6]]) + rank_table_tensor.set(np.random.random(size=(6, 1)).astype('float32'), + cpu) + exe = Executor(cpu) - outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]) - self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0])) - self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1])) - self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2])) + outs = exe.run( + feed={'x': x_tensor, + 'rank_table_tensor': rank_table_tensor}, + fetch_list=[self.mem1, self.mem2, self.mem3, self.x_grad], + return_numpy=False) + self.assertTrue(np.allclose(tensor_np[0:6], outs[0])) + self.assertTrue(np.allclose(tensor_np[0:5], outs[1])) + self.assertTrue(np.allclose(tensor_np[0:2], outs[2])) + self.assertAlmostEqual(1.0, self.sum_lodtensor(outs[3]), delta=0.01) - mem3_mean = layers.mean(x=mem3) - append_backward(loss=mem3_mean) - x_grad = exe.run( - feed={'x': tensor}, - fetch_list=[main_program.global_block().var('x@GRAD')])[0] - self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1) + +class TestShrinkRNNMemoryNoLoD(TestShrinkRNNMemoryBase): + def test_no_lod(self): + cpu = core.CPUPlace() + x_tensor = core.LoDTensor() + tensor_np = np.random.random(size=(3, 100)).astype('float32') + x_tensor.set(tensor_np, cpu) + + rank_table_tensor = core.LoDTensor() + rank_table_tensor.set_lod([[0, 1, 3, 6]]) + rank_table_tensor.set(np.random.random(size=(6, 1)).astype('float32'), + cpu) + + exe = Executor(cpu) + outs = exe.run( + feed={'x': x_tensor, + 'rank_table_tensor': rank_table_tensor}, + fetch_list=[self.mem1, self.mem2, self.mem3, self.x_grad], + return_numpy=False) + self.assertTrue(np.allclose(tensor_np[0:3], outs[0])) + self.assertTrue(np.allclose(tensor_np[0:2], outs[1])) + self.assertTrue(np.allclose(tensor_np[0:1], outs[2])) + self.assertAlmostEqual(1.0, self.sum_lodtensor(outs[3]), delta=0.01) if __name__ == '__main__': diff --git a/python/setup.py.in b/python/setup.py.in index 66ccfe808763d0e157f866ce08868e3fdebdea79..65ec58ecf98e693ecf02922129f6eef13cbe5303 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -31,6 +31,7 @@ patch = '%(patch)d' rc = '%(rc)d' istaged = %(istaged)s commit = '%(commit)s' +with_mkl = '%(with_mkl)s' def show(): if istaged: @@ -41,6 +42,9 @@ def show(): print 'rc:', rc else: print 'commit:', commit + +def mkl(): + return with_mkl ''' commit = git_commit() with open(filename, 'w') as f: @@ -51,7 +55,8 @@ def show(): 'rc': RC, 'version': '${PADDLE_VERSION}', 'commit': commit, - 'istaged': ISTAGED}) + 'istaged': ISTAGED, + 'with_mkl': '@WITH_MKL@'}) write_version_py(filename='@PADDLE_SOURCE_DIR@/python/paddle/version.py')