diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index af4079875a50ffe6eb627492f834fb601bbee716..ed5f6310f4e1212844948dc8c2555e527b4d10e8 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -32,10 +32,12 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(scope SRCS scope.cc DEPS glog threadpool) cc_test(scope_test SRCS scope_test.cc DEPS scope) -cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor) +cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor) +cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor) +cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function) -cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) +cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor + framework_proto selected_rows data_device_transform data_type_transform data_layout_transform) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc @@ -80,5 +82,5 @@ cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) -nv_test(device_data_transform_test SRCS device_data_transform_test.cu +nv_test(data_device_transform_test SRCS data_device_transform_test.cu DEPS operator op_registry init math_function) diff --git a/paddle/framework/device_data_transform.cc b/paddle/framework/data_device_transform.cc similarity index 96% rename from paddle/framework/device_data_transform.cc rename to paddle/framework/data_device_transform.cc index cd5104cc6f287315ed9d22aa2ec6414f7204d214..b3fd48ae12c368ac7d83c4f3b6e2fb1939932ac0 100644 --- a/paddle/framework/device_data_transform.cc +++ b/paddle/framework/data_device_transform.cc @@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/device_data_transform.h" +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { diff --git a/paddle/framework/device_data_transform.h b/paddle/framework/data_device_transform.h similarity index 100% rename from paddle/framework/device_data_transform.h rename to paddle/framework/data_device_transform.h diff --git a/paddle/framework/device_data_transform_test.cu b/paddle/framework/data_device_transform_test.cu similarity index 100% rename from paddle/framework/device_data_transform_test.cu rename to paddle/framework/data_device_transform_test.cu diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h index 4a8669c3a41fceaad26878a79eabfd0affce86fd..3ab976ecac4dfb0571ebf5dc93f726939da01116 100644 --- a/paddle/framework/data_layout.h +++ b/paddle/framework/data_layout.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/platform/enforce.h" #include #include "paddle/platform/enforce.h" diff --git a/paddle/framework/data_layout_transform.cc b/paddle/framework/data_layout_transform.cc new file mode 100644 index 0000000000000000000000000000000000000000..96794cae97d460e86fe83ac1395e1dfc7e371e3b --- /dev/null +++ b/paddle/framework/data_layout_transform.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_layout_transform.h" + +#include "paddle/framework/tensor.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace framework { + +struct CastDataLayout { + CastDataLayout(const platform::DeviceContext* ctx, + const std::vector& axis, const framework::Tensor& in, + framework::Tensor* out) + : in_(in), out_(out), ctx_(ctx), axis_(axis) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + const std::vector axis_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + if (platform::is_cpu_place(place)) { + operators::math::Transpose trans4; + auto* context = static_cast(ctx_); + trans4(*context, in_, out_, axis_); + } else { + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataLayout only support DataLayout transform on same place!"); + PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, + "TransDataLayout only support Datatype are same!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); + + auto src_dim = src.dims(); + std::vector dst_dim; + + dst_dim.resize(axis.size()); + for (size_t i = 0; i < axis.size(); i++) { + dst_dim[i] = src_dim[axis[i]]; + } + + dst->Resize(make_ddim(dst_dim)); + auto place = kernel_pair.second.place_; + dst->mutable_data(place, src.type()); + + auto src_type = kernel_pair.first.data_type_; + framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); + + dst->set_layout(kernel_pair.second.data_layout_); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_layout_transform.h b/paddle/framework/data_layout_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..befae1f63616a4c21d998c6b784b8ef288d00617 --- /dev/null +++ b/paddle/framework/data_layout_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index fed958db1584c4fda5394d59a2ef8936045a9ce9..e56edb95396ef8de44da95ce795161d7cf1debc6 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -11,22 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" -#include "paddle/framework/lod_tensor.h" -#include "paddle/framework/selected_rows.h" -#include "paddle/platform/device_context.h" + +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { -DataTransformFnMap& DataTransformFnMap::Instance() { - static DataTransformFnMap data_transform_map; - return data_transform_map; -} - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor) { @@ -58,134 +50,5 @@ void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, } } -auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNCHW, LibraryType::kPlain); - -// TODO(dzhwinter): Only for testing multiple op kernel. -// Dummy transform function for library_type -// should be removed. -auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kPlain); - -auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kCUDNN); - -void DummyTrans(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - auto src = in.Get(); - auto* dst = out->GetMutable(); - *dst = src; -} - -void TransDataType(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - - auto dims = src.dims(); - dst->Resize(dims); - auto dst_type = kernel_pair.second.data_type_; - auto src_type = kernel_pair.first.data_type_; - - switch (src_type) { - case proto::DataType::FP32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::FP64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::BOOL: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - default: - PADDLE_THROW("Not support type %d", src_type); - } -} - -void TransDataLayout(const std::vector& axis, - const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataLayout only support DataLayout transform on same place!"); - PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, - "TransDataLayout only support Datatype are same!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); - - auto src_dim = src.dims(); - std::vector dst_dim; - - dst_dim.resize(axis.size()); - for (size_t i = 0; i < axis.size(); i++) { - dst_dim[i] = src_dim[axis[i]]; - } - - dst->Resize(make_ddim(dst_dim)); - auto place = kernel_pair.second.place_; - dst->mutable_data(place, src.type()); - - auto src_type = kernel_pair.first.data_type_; - framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); - - dst->set_layout(kernel_pair.second.data_layout_); -} - } // namespace framework } // namespace paddle - -namespace f = paddle::framework; - -namespace { -std::vector NHWC2NCHW = {0, 3, 1, 2}; -std::vector NCHW2NHWC = {0, 2, 3, 1}; -} - -REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType); -REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, - std::bind(f::TransDataLayout, NHWC2NCHW, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); -REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC, - std::bind(f::TransDataLayout, NCHW2NHWC, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index e4e5c30a96a3c985ae2ecd494b723c8afeceb12f..ee95c7e8564d0392c8f25fce161d0f722c04761a 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -30,26 +30,6 @@ limitations under the License. */ namespace paddle { namespace framework { -using KernelTypePair = std::pair; - -using DataTransformFn = - std::function; - -struct KernelTypePairHash { - static void HashCombine(const OpKernelType& t, std::size_t* seed) { - OpKernelType::Hash kernel_type_hasher; - (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - } - - size_t operator()(const KernelTypePair& kernel_pair) const { - std::size_t seed = 0; - HashCombine(kernel_pair.first, &seed); - HashCombine(kernel_pair.second, &seed); - return seed; - } -}; - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor); @@ -57,125 +37,5 @@ Tensor* DataTransform(const OpKernelType& expected_kernel_type, void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, Variable& out_var); -template -struct CastDataTypeFunctor { - HOSTDEVICE inline OutType operator()(InType in) const { - return static_cast(in); - } -}; - -template -struct CastDataType { - CastDataType(const framework::Tensor& in, framework::Tensor* out, - const platform::DeviceContext* ctx) - : in_(in), out_(out), ctx_(ctx) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - auto* in_begin = in_.data(); - auto numel = in_.numel(); - auto* in_end = in_begin + numel; - auto* out_begin = out_->mutable_data(place); - - if (platform::is_cpu_place(place)) { - platform::Transform trans; - auto* context = static_cast(ctx_); - trans(*context, in_begin, in_end, out_begin, - CastDataTypeFunctor()); - } else { - // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -struct CastDataLayout { - CastDataLayout(const platform::DeviceContext* ctx, - const std::vector& axis, const framework::Tensor& in, - framework::Tensor* out) - : in_(in), out_(out), ctx_(ctx), axis_(axis) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - const std::vector axis_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - if (platform::is_cpu_place(place)) { - operators::math::Transpose trans4; - auto* context = static_cast(ctx_); - trans4(*context, in_, out_, axis_); - } else { - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -using DataTransformMap = - std::unordered_map; - -class DataTransformFnMap { - public: - static DataTransformFnMap& Instance(); - - bool Has(const KernelTypePair& key_pair) const { - return map_.find(key_pair) != map_.end(); - } - - void Insert(const OpKernelType& left, const OpKernelType& right, - const DataTransformFn& data_tranform_fn) { - Insert(std::make_pair(left, right), data_tranform_fn); - } - - void Insert(const KernelTypePair& kernel_type_pair, - const DataTransformFn& data_tranform_fn) { - PADDLE_ENFORCE(!Has(kernel_type_pair), - "KernelTypePair %s has been registered", ""); - map_.insert({kernel_type_pair, data_tranform_fn}); - } - - const DataTransformFn& Get(const KernelTypePair& key_pair) const { - auto data_transformer = GetNullable(key_pair); - PADDLE_ENFORCE_NOT_NULL(data_transformer, - "DataTransformFn should not be NULL"); - return *data_transformer; - } - - const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const { - auto it = map_.find(key_pair); - if (it == map_.end()) { - return nullptr; - } else { - return &(it->second); - } - } - - const DataTransformMap& Map() const { return map_; } - - private: - DataTransformFnMap() = default; - DataTransformMap map_; - DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); -}; - -// generate unique name with __LINE__ -// refs https://stackoverflow.com/questions/1597007 -#define TOKENPASTE(x, y) x##y -#define TOKENPASTE2(x, y) TOKENPASTE(x, y) -#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \ - static int TOKENPASTE2(fn_, __LINE__)() { \ - ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ - return 0; \ - } \ - static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \ - TOKENPASTE2(fn_, __LINE__)() - } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc deleted file mode 100644 index edd305fd17ae202926b83fbec10089719baa2e16..0000000000000000000000000000000000000000 --- a/paddle/framework/data_transform_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include - -#include - -#include "paddle/framework/data_transform.h" -#include "paddle/platform/device_context.h" - -namespace paddle { -namespace framework { -using namespace platform; - -/** - * @brief cross validation of different kernel type transform - * We use four bit map represent different combination. - * If the field has multiple possible value, only choose two of them. - * For DataType, only test the FP32(float), FP64(double). - * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain - * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN - */ - -std::array kDataType = { - {proto::DataType::FP32, proto::DataType::FP64}}; - -std::array kPlace = {{CPUPlace(), CUDAPlace(0)}}; - -std::array kDataLayout = {{ - DataLayout::kNHWC, DataLayout::kNCHW, -}}; - -std::array kLibraryType = {{ - LibraryType::kPlain, LibraryType::kMKLDNN, -}}; - -OpKernelType GenFromBit(const std::vector bits) { - return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], - kLibraryType[bits[3]]); -} - -int test_value = 0; - -auto kernel0 = GenFromBit({0, 0, 0, 0}); -auto kernel1 = GenFromBit({0, 0, 0, 1}); -auto kernel2 = GenFromBit({0, 0, 1, 0}); -auto kernel3 = GenFromBit({0, 0, 1, 1}); - -void TransDataType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value++; -} - -void TransDataLayout_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value--; -} - -void TransLibraryType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value += 2; -} - -} // namespace framework -} // namespace paddle - -namespace frw = paddle::framework; - -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t); - -TEST(DataTransform, Register) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - paddle::framework::Variable in; - paddle::framework::Variable out; - - DeviceContext* ctx = new CPUDeviceContext(); - auto pair0 = std::make_pair(frw::kernel0, frw::kernel1); - instance.Get(pair0)(ctx, pair0, in, &out); - ASSERT_EQ(test_value, 1); - - auto pair1 = std::make_pair(frw::kernel1, frw::kernel2); - instance.Get(pair1)(ctx, pair1, in, &out); - ASSERT_EQ(test_value, 0); - - auto pair3 = std::make_pair(frw::kernel0, frw::kernel2); - instance.Get(pair3)(ctx, pair3, in, &out); - ASSERT_EQ(test_value, 2); -} - -TEST(DataTransform, DataLayout) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - src->mutable_data(make_ddim({2, 3, 1, 2}), CPUPlace()); - src->set_layout(DataLayout::kNHWC); - - DeviceContext* ctx = new CPUDeviceContext(); - - { - auto kernel1 = GenFromBit({1, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 1, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - - Tensor dst = out.Get(); - - EXPECT_TRUE(dst.layout() == DataLayout::kNCHW); - EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1})); - - { - auto kernel1 = GenFromBit({1, 0, 1, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, out, &in); - } - - EXPECT_TRUE(src->layout() == DataLayout::kNHWC); - EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2})); -} - -TEST(DataTransform, DataType) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - DeviceContext* ctx = new CPUDeviceContext(); - - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - float* ptr = src->mutable_data(make_ddim({2, 3}), CPUPlace()); - for (int i = 0; i < 6; ++i) { - ptr[i] = i / 3; - } - - { - auto kernel1 = GenFromBit({0, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - Tensor dst = out.Get(); - EXPECT_TRUE(dst.data() != nullptr); -} diff --git a/paddle/framework/data_type_transform.cc b/paddle/framework/data_type_transform.cc new file mode 100644 index 0000000000000000000000000000000000000000..63373232e910d44eb0996f9280f9c166ad092030 --- /dev/null +++ b/paddle/framework/data_type_transform.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_type_transform.h" + +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace framework { + +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor& in, framework::Tensor* out, + const platform::DeviceContext* ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + auto* in_begin = in_.data(); + auto numel = in_.numel(); + auto* in_end = in_begin + numel; + auto* out_begin = out_->mutable_data(place); + + if (platform::is_cpu_place(place)) { + platform::Transform trans; + auto* context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + } else { + // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataType Only Support DataType transform on same place!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + + auto dims = src.dims(); + dst->Resize(dims); + auto dst_type = kernel_pair.second.data_type_; + auto src_type = kernel_pair.first.data_type_; + + switch (src_type) { + case proto::DataType::FP32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::FP64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::BOOL: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + default: + PADDLE_THROW("Not support type %d", src_type); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_type_transform.h b/paddle/framework/data_type_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..8ec90742256c2308a242d993838e46e51a6fc167 --- /dev/null +++ b/paddle/framework/data_type_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 35ebe48ba682f135b7f85edb3b2999db7c29e51a..ef2c55cc3799ba2fac54f3c9370505b63ef22ad3 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" #include "paddle/framework/executor.h" #include "paddle/framework/operator.h" #include "paddle/framework/shape_inference.h" diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 6fbf3c7fdec2f537769adb660c67c5a597beb609..2d0fff608c3932ede57ddcbb6a85cda255b77246 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -132,6 +132,8 @@ void MKLDNNLayer::reshapeInput(int& batchsize, if (w != 0) { width = w; } + height = height != 0 ? height : 1; + width = width != 0 ? width : 1; } void MKLDNNLayer::reshapeOutput(size_t height, size_t width) { diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index e48b9b5a91f7f17cb3f31e9140f1428ba8954a20..3ba39f18b6ae4b38c17b3b72a359fd514e6dad5b 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -98,6 +98,8 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), + ih_(0), + iw_(0), condition_(0), needResetBwd_(true), outputOnlyMKLDNN_(false), diff --git a/paddle/operators/detail/CMakeLists.txt b/paddle/operators/detail/CMakeLists.txt index f6bdc63cc2cfae526fe911ee4d989675452d5c5d..571a75c9dcd903672d460f192bf28ddbeaea7c78 100644 --- a/paddle/operators/detail/CMakeLists.txt +++ b/paddle/operators/detail/CMakeLists.txt @@ -1 +1 @@ -grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) +grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a4db2d7e686ce84abef620f890be8f3aa82cb73 --- /dev/null +++ b/paddle/operators/detail/grpc_client.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "grpc_client.h" +namespace paddle { +namespace operators { +namespace detail { + +bool RPCClient::AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { + sendrecv::VariableMessage req; + auto* var = scope.FindVar(var_name); + SerializeToMessage(var_name, var, ctx, &req); + + // varhandle + VarHandle var_h; + var_h.ep = ep; + var_h.scope = &scope; + var_h.name = var_name; + var_h.ctx = &ctx; + + // stub context + auto ch = GetChannel(ep); + SendProcessor* s = new SendProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = NULL; + + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + + req_count_++; + + return true; +} + +void ProcGetResponse(const VarHandle& var_h, + const sendrecv::VariableMessage& ret_msg) { + auto* outvar = var_h.scope->FindVar(var_h.name); + + std::istringstream iss(ret_msg.serialized()); + DeserializeFromMessage(ret_msg, *var_h.ctx, outvar); +} + +bool RPCClient::AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { + sendrecv::VariableMessage req; + req.set_varname(var_name); + + auto* var = scope.FindVar(var_name); + SerializeToMessage(var_name, var, ctx, &req); + + // varhandle + VarHandle var_h; + var_h.ep = ep; + var_h.scope = &scope; + var_h.name = var_name; + var_h.ctx = &ctx; + + // stub context + auto ch = GetChannel(ep); + GetProcessor* s = new GetProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = ProcGetResponse; + + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + + req_count_++; + + return true; +} + +bool RPCClient::wait() { + bool ok = true; + + while (true) { + if (req_count_ <= 0) { + break; + } + + if (!Proceed()) { + LOG(ERROR) << "Get meets CompletionQueue error"; + return false; + } + } + + return ok; +} + +bool RPCClient::Proceed() { + void* tag = NULL; + bool ok = false; + + // request counts. + if (!cq_.Next(&tag, &ok)) { + return false; + } + req_count_--; + + GPR_ASSERT(ok); + PADDLE_ENFORCE(tag); + + // TODO(gongwb): add more retries. + ClientBase* c = static_cast(tag); + if (!c->status_.ok()) { + delete c; + return true; + } + + c->Process(); + delete c; + return true; +} + +std::shared_ptr RPCClient::GetChannel(const std::string& ep) { + auto it = channels_.find(ep); + if (it != channels_.end()) { + return it->second; + } + + auto ch = std::shared_ptr( + grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())); + + channels_[ep] = ch; + return ch; +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h new file mode 100644 index 0000000000000000000000000000000000000000..d27b5ced9ece67f9b9da3b7f87ec231477603580 --- /dev/null +++ b/paddle/operators/detail/grpc_client.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/operators/detail/sendrecvop_utils.h" +#include "paddle/operators/detail/simple_block_queue.h" + +namespace paddle { +namespace operators { +namespace detail { + +struct VarHandle { + std::string ep; + const platform::DeviceContext* ctx; + const framework::Scope* scope; + std::string name; + + std::string String() const { + std::ostringstream s; + s << "name:[" << name << "] ep:[" << ep << "]"; + return s.str(); + } +}; + +void ProcGetResponse(const VarHandle& var_h, + const sendrecv::VariableMessage& msg); + +class ClientBase { + public: + explicit ClientBase(std::shared_ptr ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + context_ = NULL; + } + + virtual ~ClientBase() {} + + virtual void Prepare(const VarHandle& var_info, int64_t time_out) { + context_.reset(new grpc::ClientContext()); + var_h_ = var_info; + + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); + + context_->set_deadline(deadline); + } + + virtual void Process() = 0; + + std::unique_ptr stub_; + std::unique_ptr context_; + grpc::Status status_; + VarHandle var_h_; +}; + +typedef std::function + RequestSendCallBack; + +class SendProcessor : public ClientBase { + public: + explicit SendProcessor(std::shared_ptr ch) : ClientBase(ch) {} + + virtual ~SendProcessor() {} + + virtual void Process() { + if (response_call_back_) { + response_call_back_(var_h_, reply_); + } + } + + sendrecv::VoidMessage reply_; + RequestSendCallBack response_call_back_ = NULL; +}; + +typedef std::function + RequestGetCallBack; + +class GetProcessor : public ClientBase { + public: + explicit GetProcessor(std::shared_ptr ch) : ClientBase(ch) {} + + virtual ~GetProcessor() {} + + virtual void Process() { + if (response_call_back_) { + response_call_back_(var_h_, reply_); + } + } + + sendrecv::VariableMessage reply_; + RequestGetCallBack response_call_back_ = ProcGetResponse; +}; + +class RPCClient { + public: + bool AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + + bool AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + bool wait(); + + private: + bool Proceed(); + std::shared_ptr GetChannel(const std::string& ep); + + private: + grpc::CompletionQueue cq_; + std::map> channels_; + int64_t req_count_ = 0; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc new file mode 100644 index 0000000000000000000000000000000000000000..e8d561a57ff59e9221400241f881cb26fb6c6f06 --- /dev/null +++ b/paddle/operators/detail/grpc_server.cc @@ -0,0 +1,237 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detail/grpc_server.h" + +using grpc::ServerAsyncResponseWriter; + +namespace paddle { +namespace operators { +namespace detail { + +enum CallStatus { PROCESS = 0, FINISH }; + +// reference: +// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server +class RequestBase { + public: + explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq) + : service_(service), cq_(cq), status_(PROCESS) {} + virtual ~RequestBase() {} + virtual void Process() { assert(false); } + + CallStatus Status() { return status_; } + void SetStatus(CallStatus status) { status_ = status; } + + protected: + grpc::ServerContext ctx_; + sendrecv::SendRecvService::AsyncService* service_; + grpc::ServerCompletionQueue* cq_; + CallStatus status_; +}; + +typedef std::pair MessageWithName; + +class RequestSend final : public RequestBase { + public: + explicit RequestSend(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + SimpleBlockQueue* queue) + : RequestBase(service, cq), queue_(queue), responder_(&ctx_) { + service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_, + this); + } + + virtual ~RequestSend() {} + + virtual void Process() { + MessageWithName msg_with_name = + std::make_pair(request_.varname(), std::move(request_)); + queue_->Push(std::move(msg_with_name)); + // TODO(gongwb): check var's info. + responder_.Finish(reply_, grpc::Status::OK, this); + } + + protected: + sendrecv::VariableMessage request_; + sendrecv::VoidMessage reply_; + SimpleBlockQueue* queue_; + ServerAsyncResponseWriter responder_; +}; + +class RequestGet final : public RequestBase { + public: + explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq, framework::Scope* scope) + : RequestBase(service, cq), responder_(&ctx_), scope_(scope) { + service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + virtual ~RequestGet() {} + + virtual void Process() { + // proc request. + std::string var_name = request_.varname(); + auto* var = scope_->FindVar(var_name); + SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); + // TODO(gongwb): check var's info. + responder_.Finish(reply_, grpc::Status::OK, this); + } + + protected: + sendrecv::VariableMessage request_; + sendrecv::VariableMessage reply_; + ServerAsyncResponseWriter responder_; + framework::Scope* scope_; +}; + +void AsyncGRPCServer::RunSyncUpdate() { + grpc::ServerBuilder builder; + builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); + builder.RegisterService(&service_); + + cq_send_ = builder.AddCompletionQueue(); + cq_get_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + LOG(INFO) << "Server listening on " << address_ << std::endl; + + std::function send_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); + std::function get_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); + + t_send_.reset( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, + cq_send_.get(), "cq_send", send_register))); + + t_get_.reset( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, + cq_get_.get(), "cq_get", get_register))); + + // wait server + server_->Wait(); + t_send_->join(); + t_get_->join(); +} + +void AsyncGRPCServer::ShutdownQueue() { + std::unique_lock lock(cq_mutex_); + cq_send_->Shutdown(); + cq_get_->Shutdown(); + is_shut_down_ = true; +} + +// This URL explains why shutdown is complicate: +// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c +void AsyncGRPCServer::ShutDown() { + server_->Shutdown(); + ShutdownQueue(); +} + +void AsyncGRPCServer::TryToRegisterNewSendOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + return; + } + RequestSend* send = + new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); + VLOG(4) << "create RequestSend status:" << send->Status(); +} + +void AsyncGRPCServer::TryToRegisterNewGetOne() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + return; + } + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_); + VLOG(4) << "create Requestget status:" << get->Status(); +} + +void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + delete last; + last = NULL; + return; + } + + last->SetStatus(FINISH); + return; +} + +void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, + std::string cq_name, + std::function TryToRegisterNewOne) { + TryToRegisterNewOne(); + + void* tag = NULL; + bool ok = false; + while (true) { + if (!cq->Next(&tag, &ok)) { + LOG(INFO) << cq_name << " get CompletionQueue shutdown!"; + break; + } + + if (wait && !done_) { + Wait(); + } + + RequestBase* base = (RequestBase*)tag; + if (!ok) { + VLOG(4) << cq_name << " recv no regular event"; + TryToRegisterNewOne(); + delete base; + continue; + } + + switch (base->Status()) { + case PROCESS: { + VLOG(4) << cq_name << " status:" << base->Status(); + TryToRegisterNewOne(); + base->Process(); + SetFinishOrDelete(base); + break; + } + case FINISH: { + VLOG(4) << cq_name << " status:" << base->Status(); + delete base; + break; + } + default: { assert(false); } + } + } +} + +void AsyncGRPCServer::Wait() { + std::unique_lock lock(this->mutex_); + condition_.wait(lock, [=] { return this->done_ == true; }); +} + +void AsyncGRPCServer::Reset() { + std::lock_guard lock(this->mutex_); + done_ = false; +} + +void AsyncGRPCServer::Done() { + { + std::lock_guard lock(this->mutex_); + done_ = true; + } + condition_.notify_all(); +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h new file mode 100644 index 0000000000000000000000000000000000000000..041fe05b2e9c37e8a91669b8f523c47b56e14cba --- /dev/null +++ b/paddle/operators/detail/grpc_server.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/framework/var_type.h" +#include "paddle/operators/detail/simple_block_queue.h" + +#include "paddle/operators/detail/send_recv.grpc.pb.h" +#include "paddle/operators/detail/send_recv.pb.h" + +#include +#include +#include +#include "paddle/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +typedef std::pair MessageWithName; +class RequestBase; + +class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { + public: + explicit AsyncGRPCServer(std::string address) { address_ = address; } + + void RunSyncUpdate(); + + void Reset(); + + void Done(); + + void SetScope(framework::Scope *scope) { scope_ = scope; } + + const MessageWithName Get() { return this->var_recv_queue_.Pop(); } + + void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } + + void ShutDown(); + + protected: + void Wait(); + void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, + std::string cq_name, + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(); + void TryToRegisterNewGetOne(); + void SetFinishOrDelete(RequestBase *&last); + void ShutdownQueue(); + + private: + std::mutex cq_mutex_; + volatile bool is_shut_down_ = false; + std::unique_ptr cq_send_; + std::unique_ptr cq_get_; + + sendrecv::SendRecvService::AsyncService service_; + std::unique_ptr server_; + + std::string address_; + framework::Scope *scope_; + // received variable from RPC, operators fetch variable from this queue. + SimpleBlockQueue var_recv_queue_; + + // condition of the sub program + std::mutex mutex_; + volatile mutable bool done_; + std::condition_variable condition_; + + std::unique_ptr t_send_; + std::unique_ptr t_get_; +}; + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc deleted file mode 100644 index 319404e56a5f3c407f313991240bbbb85fd39a2a..0000000000000000000000000000000000000000 --- a/paddle/operators/detail/recv_impl.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "send_recv_impl.h" - -namespace paddle { -namespace operators { -namespace detail { - -Status SendRecvServerImpl::SendVariable(ServerContext *context, - const VariableMessage *in_var, - VoidMessage *out_var) { - MessageWithName msg_with_name = - std::make_pair(in_var->varname(), std::move(*in_var)); - var_recv_queue_.Push(std::move(msg_with_name)); - return Status::OK; -} - -Status SendRecvServerImpl::GetVariable(ServerContext *context, - const VariableMessage *in_var, - VariableMessage *out_var) { - std::string get_var_name = in_var->varname(); - auto *var = scope_->FindVar(get_var_name); - - SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var); - return Status::OK; -} - -Status SendRecvServerImpl::Wait(ServerContext *context, - const VoidMessage *in_var, - VoidMessage *out_var) { - { - std::unique_lock lock(this->mutex_); - condition_.wait(lock, [=] { return this->done_ == true; }); - } - return Status::OK; -} - -void SendRecvServerImpl::Reset() { - std::lock_guard lock(this->mutex_); - done_ = false; -} - -void SendRecvServerImpl::Done() { - { - std::lock_guard lock(this->mutex_); - done_ = true; - } - condition_.notify_all(); -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc deleted file mode 100644 index ae85cf2cec2cd8e046c0c7fd3408f2212f225819..0000000000000000000000000000000000000000 --- a/paddle/operators/detail/send_impl.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "send_recv_impl.h" - -namespace paddle { -namespace operators { -namespace detail { - -bool RPCClient::SendVariable(const framework::Scope& scope, - const std::string& inname) { - ClientContext context; - VariableMessage msg; - VoidMessage out_msg; - // FIXME(typhoonzero): pass device context to here. - auto ctx = platform::CPUDeviceContext(); - auto* var = scope.FindVar(inname); - PADDLE_ENFORCE(var); - SerializeToMessage(inname, var, ctx, &msg); - - Status status = stub_->SendVariable(&context, msg, &out_msg); - if (!status.ok()) { - LOG(ERROR) << "gRPC error: " << status.error_message(); - return false; - } - return true; -} - -bool RPCClient::GetVariable(const framework::Scope& scope, - const std::string& outname) { - ClientContext context; - VariableMessage call_msg, ret_msg; - call_msg.set_varname(outname); - auto ctx = platform::CPUDeviceContext(); - Status status = stub_->GetVariable(&context, call_msg, &ret_msg); - auto* outvar = scope.FindVar(outname); - if (!status.ok()) { - LOG(ERROR) << "gRPC error: " << status.error_message(); - return false; - } - - std::istringstream iss(ret_msg.serialized()); - DeserializeFromMessage(ret_msg, ctx, outvar); - - return true; -} - -void RPCClient::Wait() { - ClientContext context; - VoidMessage call_msg, ret_msg; - stub_->Wait(&context, call_msg, &ret_msg); -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index f141c755ce14ef540aeab32c11c289179aff3f8c..8f962b4c69cc83dc2ab98b7dc27e18bc4b42bf18 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -21,8 +21,6 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // wait for one execution of the program - rpc Wait(VoidMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h deleted file mode 100644 index 1fe54f1f0536aed7d41bbdeeca076534abafe98d..0000000000000000000000000000000000000000 --- a/paddle/operators/detail/send_recv_impl.h +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/lod_tensor.h" -#include "paddle/framework/scope.h" -#include "paddle/framework/selected_rows.h" -#include "paddle/framework/var_type.h" -#include "paddle/operators/detail/simple_block_queue.h" - -#include "paddle/operators/detail/send_recv.grpc.pb.h" -#include "paddle/operators/detail/send_recv.pb.h" - -#include - -using grpc::Channel; -using grpc::Server; -using grpc::ServerContext; -using grpc::ServerReader; -using grpc::ServerBuilder; - -using grpc::ClientContext; -using grpc::ClientReader; -using grpc::ClientReaderWriter; -using grpc::ClientWriter; -using grpc::Status; -using sendrecv::SendRecvService; -using sendrecv::VariableMessage; -using sendrecv::VoidMessage; - -namespace paddle { -namespace operators { -namespace detail { - -typedef std::pair MessageWithName; - -class SendRecvServerImpl final : public SendRecvService::Service { - public: - explicit SendRecvServerImpl() {} - - Status SendVariable(ServerContext *context, const VariableMessage *in_var, - VoidMessage *out_var) override; - Status GetVariable(ServerContext *context, const VariableMessage *in_var, - VariableMessage *out_var) override; - Status Wait(ServerContext *context, const VoidMessage *in_var, - VoidMessage *out_var) override; - void Reset(); - void Done(); - void SetScope(framework::Scope *scope) { scope_ = scope; }; - - const MessageWithName Get() { return this->var_recv_queue_.Pop(); } - - void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } - - private: - // received variable from RPC, operators fetch variable from this queue. - SimpleBlockQueue var_recv_queue_; - framework::Scope *scope_; - // condition of the sub program - std::mutex mutex_; - bool done_; - std::condition_variable condition_; -}; - -// RPCClient is a class to send tensors to pserver sub-network -// using different hashing methods. -class RPCClient { - public: - RPCClient(std::shared_ptr channel) - : stub_(SendRecvService::NewStub(channel)) {} - - bool SendVariable(const framework::Scope &scope, const std::string &inname); - bool GetVariable(const framework::Scope &scope, const std::string &outname); - void Wait(); - - private: - std::unique_ptr stub_; -}; - -inline void SerializeToMessage(const std::string &name, - const framework::Variable *var, - const platform::DeviceContext &ctx, - VariableMessage *msg) { - msg->set_varname(name); - std::ostringstream oss; - switch (framework::ToVarType(var->Type())) { - case framework::proto::VarDesc_VarType_LOD_TENSOR: - msg->set_type(sendrecv::VarType::LOD_TENSOR); - framework::SerializeToStream(oss, var->Get(), ctx); - break; - case framework::proto::VarDesc_VarType_SELECTED_ROWS: - msg->set_type(sendrecv::VarType::SELECTED_ROWS); - framework::SerializeToStream(oss, var->Get(), - ctx); - break; - default: { - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } - msg->set_serialized(oss.str()); -} - -inline void DeserializeFromMessage(const VariableMessage &msg, - const platform::DeviceContext &ctx, - framework::Variable *var) { - using namespace paddle::framework::proto; - std::istringstream iss(msg.serialized()); - switch (msg.type()) { - case sendrecv::VarType::LOD_TENSOR: - DeserializeFromStream(iss, var->GetMutable(), ctx); - break; - case sendrecv::VarType::SELECTED_ROWS: { - DeserializeFromStream(iss, var->GetMutable(), - ctx); - break; - } - default: { - PADDLE_THROW("Deserialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } -} - -} // namespace detail -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/detail/sendrecvop_utils.cc b/paddle/operators/detail/sendrecvop_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..7635b9e8dbdff624bb42a9de346b8d05a980f9b6 --- /dev/null +++ b/paddle/operators/detail/sendrecvop_utils.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +void SerializeToMessage(const std::string& name, const framework::Variable* var, + const platform::DeviceContext& ctx, + sendrecv::VariableMessage* msg) { + msg->set_varname(name); + std::ostringstream oss; + switch (framework::ToVarType(var->Type())) { + case framework::proto::VarDesc_VarType_LOD_TENSOR: + msg->set_type(sendrecv::VarType::LOD_TENSOR); + framework::SerializeToStream(oss, var->Get(), ctx); + break; + case framework::proto::VarDesc_VarType_SELECTED_ROWS: + msg->set_type(sendrecv::VarType::SELECTED_ROWS); + framework::SerializeToStream(oss, var->Get(), + ctx); + break; + default: { + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } + msg->set_serialized(oss.str()); +} + +void DeserializeFromMessage(const sendrecv::VariableMessage& msg, + const platform::DeviceContext& ctx, + framework::Variable* var) { + std::istringstream iss(msg.serialized()); + switch (msg.type()) { + case sendrecv::VarType::LOD_TENSOR: + DeserializeFromStream(iss, var->GetMutable(), ctx); + break; + case sendrecv::VarType::SELECTED_ROWS: { + DeserializeFromStream(iss, var->GetMutable(), + ctx); + break; + } + default: { + PADDLE_THROW("Deserialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/detail/sendrecvop_utils.h b/paddle/operators/detail/sendrecvop_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..bc6581afab93c626c7c2439d699c6c2d858df9fa --- /dev/null +++ b/paddle/operators/detail/sendrecvop_utils.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/framework/data_type.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/framework/var_type.h" + +#include "paddle/operators/detail/send_recv.grpc.pb.h" +#include "paddle/operators/detail/send_recv.pb.h" + +namespace paddle { +namespace operators { +namespace detail { + +void SerializeToMessage(const std::string& name, const framework::Variable* var, + const platform::DeviceContext& ctx, + sendrecv::VariableMessage* msg); + +void DeserializeFromMessage(const sendrecv::VariableMessage& msg, + const platform::DeviceContext& ctx, + framework::Variable* var); +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e383f07fa9b53a3def10f6405a0d36f48f52ff08 --- /dev/null +++ b/paddle/operators/edit_distance_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/edit_distance_op.h" + +namespace paddle { +namespace operators { + +class EditDistanceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto hyp_dims = ctx->GetInputDim("Hyps"); + auto ref_dims = ctx->GetInputDim("Refs"); + PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, + "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1, + "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1."); + ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(framework::proto::DataType::FP32, + ctx.device_context()); + } +}; + +class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Hyps", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for hypothesis strings."); + AddInput("Refs", + "(2-D LoDTensor, 2nd dim. equal to 1) " + "The indices for reference strings."); + AddAttr("normalized", + "(bool, default false) Indicated whether to normalize " + "the edit distance by the length of reference string.") + .SetDefault(false); + AddOutput("Out", + "(2-D Tensor with shape [`batch_size` x 1]) " + "The output edit distances of EditDistance operator."); + AddComment(R"DOC( + +EditDistance operator computes the edit distances between a batch of hypothesis +strings and their references. + +Edit distance, also called Levenshtein distance, measures how dissimilar two strings +are by counting the minimum number of operations to transform one string into anthor. +Here the operations include insertion, deletion, and substitution. For example, +given hypothesis string A = "kitten" and reference B = "sitting", the edit distance +is 3 for A will be transformed into B at least after two substitutions and one +insertion: + + "kitten" -> "sitten" -> "sittin" -> "sitting" + +Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total +number denoted by `batch_size`, and the separation is specified by the LoD information. +And the `batch_size` reference strings are arranged in order in the same way in the +LoDTensor Input(Refs). + +Output(Out) contains the `batch_size` results and each stands for the edit stance +for a pair of strings respectively. If Attr(normalized) is true, the edit distance +will be divided by the length of reference string. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + edit_distance, ops::EditDistanceKernel); diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..cf5ebc5c38fd006d10de790e45e9bff3409bd20c --- /dev/null +++ b/paddle/operators/edit_distance_op.cu @@ -0,0 +1,149 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillFirstRow(T* dist, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < N + 1) { + dist[idx] = idx; + } +} + +template +__global__ void FillFirstColumn(T* dist, const int M, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < M + 1) { + dist[idx * (N + 1)] = idx; + } +} + +template +__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, + const int N, const int start) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = N; + int index = start + idx * offset; + int row = index / (N + 1); + int col = index % (N + 1); + if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { + int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; + int dels = dist[(row - 1) * (N + 1) + col] + 1; + int ins = dist[row * (N + 1) + col - 1] + 1; + int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; + dist[index] = min(dels, min(ins, subs)); + } +} + +template +__global__ void SetOutput(T* out, const T* dist, const int M, const int N, + bool normalized) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx == 0) { + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; + } +} + +template +class EditDistanceGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); + + auto normalized = ctx.Attr("normalized"); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + + auto num_strs = hyp_lod.size() - 1; + out_t->Resize({static_cast(num_strs), 1}); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; num++) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + if (m == 0 || n == 0) { + distance = std::max(m, n); + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance = distance / n; + } + memory::Copy(boost::get(ctx.GetPlace()), out + num, + platform::CPUPlace(), &distance, sizeof(T), stream); + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + + FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); + + FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { + int z_m = slice < m + 1 ? 0 : slice - m; + int z_n = slice < n + 1 ? 0 : slice - n; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + // the start index at which computes from + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, + m, n, start); + } + SetOutput<<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + edit_distance, + ops::EditDistanceGPUKernel); diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h new file mode 100644 index 0000000000000000000000000000000000000000..537e70281a5a750db480468a8f8e3c0465de6c5a --- /dev/null +++ b/paddle/operators/edit_distance_op.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class EditDistanceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + + auto* x1_t = ctx.Input("Hyps"); + auto* x2_t = ctx.Input("Refs"); + + auto normalized = ctx.Attr("normalized"); + + auto hyp_lod = x1_t->lod()[0]; + auto ref_lod = x2_t->lod()[0]; + PADDLE_ENFORCE( + hyp_lod.size() == ref_lod.size(), + "Input(Hyps) and Input(Refs) must have the same batch size."); + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], + "Reference string %d is empty.", i); + } + auto num_strs = hyp_lod.size() - 1; + + out_t->Resize({static_cast(num_strs), 1}); + out_t->mutable_data(ctx.GetPlace()); + auto out = out_t->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; ++num) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + + if (m == 0) { + distance = n; + } else if (n == 0) { + distance = m; + } else { + framework::Tensor dist_t; + dist_t.Resize({m + 1, n + 1}); + dist_t.mutable_data(ctx.GetPlace()); + auto dist = dist_t.data(); + auto x1 = x1_t->data() + hyp_lod[num]; + auto x2 = x2_t->data() + ref_lod[num]; + for (int64_t i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; + } + for (int64_t j = 0; j < n + 1; ++j) { + dist[j] = j; + } + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { + int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; + int dels = dist[(i - 1) * (n + 1) + j] + 1; + int ins = dist[i * (n + 1) + (j - 1)] + 1; + int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); + } + } + distance = dist[m * (n + 1) + n]; + } + + if (normalized) { + PADDLE_ENFORCE(n > 0, + "The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n); + distance = distance / n; + } + out[num] = distance; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 9331c7b563491902b2824898766cacb9bfdee2d9..55b33343af43802e1b6b95a32603bfee806c9764 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -24,7 +24,8 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/proto_desc.h" -#include "paddle/operators/detail/send_recv_impl.h" +#include "paddle/operators/detail/grpc_server.h" +#include "paddle/operators/detail/sendrecvop_utils.h" #include "paddle/operators/detail/simple_block_queue.h" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" @@ -32,6 +33,11 @@ limitations under the License. */ namespace paddle { namespace operators { +void RunServer(std::shared_ptr service) { + service->RunSyncUpdate(); + VLOG(4) << "RunServer thread end"; +} + static void CreateTensorFromMessageType(framework::Variable *var, sendrecv::VarType var_type) { if (var_type == sendrecv::VarType::LOD_TENSOR) { @@ -46,18 +52,6 @@ static void CreateTensorFromMessageType(framework::Variable *var, } } -void RunServer(Server **rpc_server, - std::shared_ptr service, - const std::string &server_address) { - ServerBuilder builder; - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(service.get()); - std::unique_ptr server(builder.BuildAndStart()); - *rpc_server = server.get(); - LOG(INFO) << "Server listening on " << server_address; - server->Wait(); -} - class RecvOp : public framework::OperatorBase { public: RecvOp(const std::string &type, const framework::VariableNameMap &inputs, @@ -65,10 +59,9 @@ class RecvOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) { if (!rpc_service_) { - rpc_service_.reset(new detail::SendRecvServerImpl()); std::string endpoint = Attr("endpoint"); - server_thread_.reset( - new std::thread(RunServer, &rpc_server_, rpc_service_, endpoint)); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + server_thread_.reset(new std::thread(RunServer, rpc_service_)); } } @@ -76,7 +69,7 @@ class RecvOp : public framework::OperatorBase { detail::MessageWithName term_msg; term_msg.first = LISTEN_TERMINATE_MESSAGE; rpc_service_->Push(term_msg); - rpc_server_->Shutdown(); + rpc_service_->ShutDown(); server_thread_->join(); } @@ -99,10 +92,12 @@ class RecvOp : public framework::OperatorBase { auto grad_list = Attr>("GradList"); auto trainer_count = Attr("Trainers"); size_t param_count = param_list.size(); + rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; while (!exit_flag) { + // TODO(gognwb): simply this loop. // Get from multiple trainers, we don't care about order in which // the gradient arrives, just add suffix 0~n then average the gradient. for (size_t i = 0; i < param_count * trainer_count; ++i) { @@ -110,6 +105,7 @@ class RecvOp : public framework::OperatorBase { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { + VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit"; exit_flag = true; break; } @@ -118,10 +114,12 @@ class RecvOp : public framework::OperatorBase { if (it != grad_list.end()) { param_var_name = param_list[it - grad_list.begin()]; } else { - LOG(ERROR) << "grad have no paired param found!"; + LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name + << "\""; } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; + auto *merged_grad = recv_scope.FindVar(grad_var_name); if (merged_grad == nullptr) { auto *ptr = recv_scope.Var(grad_var_name); @@ -141,9 +139,11 @@ class RecvOp : public framework::OperatorBase { auto &dev_ctx = *pool.Get(dev_place); detail::DeserializeFromMessage(v.second, dev_ctx, var); } + if (exit_flag) { break; } + rpc_service_->Reset(); std::string program_str = Attr("OptimizeProgram"); @@ -158,17 +158,14 @@ class RecvOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } + rpc_service_->Done(); grads_counter_.clear(); } // while(true) } protected: - // grpc server instance to track status and gracefully shutdown. - // borrow an pointer from server thread. - Server *rpc_server_{nullptr}; - // grpc send/recv service implement to register. - std::shared_ptr rpc_service_; + std::shared_ptr rpc_service_; std::shared_ptr server_thread_; mutable std::unordered_map grads_counter_; }; diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 95c207221a7b34732eca4cfd07fed0a8f1671981..4d145250bdc73607c8817e20fdb753f4c96e2391 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -19,59 +19,45 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/detail/send_recv_impl.h" -#include "paddle/operators/detail/simple_block_queue.h" +#include +#include "paddle/operators/detail/grpc_client.h" namespace paddle { namespace operators { -// TODO(typhoonzero): this is a simple implementation which only send -// one tensor class SendOp : public framework::OperatorBase { public: - SendOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) { - // init client when the operator is created at runtime. - std::vector endpoints = - Attr>("endpoints"); - for (auto ep : endpoints) { - client_map_[ep].reset(new detail::RPCClient( - grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()))); - } - } + SendOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); - // TODO(typhoonzero): use async calls to send multiple variable asyncly. - for (size_t i = 0; i < ins.size(); ++i) { - bool ret = client_map_[epmap[i]]->SendVariable(scope, ins[i]); - if (!ret) { - LOG(ERROR) << "send variable error: " << ins[i]; - } + + // FIXME(gongwb): DeviceContext? + auto ctx = platform::CPUDeviceContext(); + for (size_t i = 0; i < ins.size(); i++) { + client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } - // TODO(typhoonzero): support async optimization - client_map_[epmap[0]]->Wait(); - for (size_t i = 0; i < outs.size(); ++i) { - bool ret = client_map_[epmap[i]]->GetVariable(scope, outs[i]); - if (!ret) { - LOG(ERROR) << "GetVariable error: " << outs[i]; - } + + for (size_t i = 0; i < outs.size(); i++) { + client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } + + client_.wait(); } - protected: - mutable std::unordered_map> - client_map_; + private: + mutable detail::RPCClient client_; }; class SendOpMaker : public framework::OpProtoAndCheckerMaker { public: - SendOpMaker(OpProto *proto, OpAttrChecker *op_checker) + SendOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(Tensor) Input tensor to be send").AsDuplicable(); AddOutput("Out", "(Tensor) Output tensor to get from server") diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index fa94424bf9e8e719ec0822268685b0806a109d21..ea091694798475dfd9631910a750405be950c20c 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -140,7 +140,7 @@ void StartServerNet(bool is_sparse) { TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false); - sleep(3); // wait server to start + sleep(10); // wait server to start // local net f::Scope scope; p::CPUPlace place; diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d17b2686238b2d2f872331edfdbb095fb8693b87 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_erase_op.h" + +namespace paddle { +namespace operators { + +class SequenceEraseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceEraseOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceEraseOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1, + "Input(X) of SequenceEraseOp should be a 2-D LoDTensor " + "with the 2nd dimension equal to 1."); + ctx->SetOutputDim("Out", x_dims); + } +}; + +class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Input LoDTensor of SequenceEraseOp."); + AddOutput("Out", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Output LoDTensor of SequenceEraseOp."); + AddAttr>("tokens", + "(vector) Tokens need to be erased from " + "input sequences."); + AddComment(R"DOC( +Sequence Erase Operator. + +Sequence erase operator erases tokens specified by Attr(tokens) from the input +sequences Input(X), and outputs the remaining data and modifies the LoD +information at the same time. For example, given a 2-D LoDTensor + + X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T + +with lod = [[0, 3, 6, 10]], there are three sequences in the input: + + X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T. + +If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing +operation, the three sequences become + + X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T. + +Hence the LoDTensor Output(Out) should be + + Out = [[6, 1, 9, 6, 1, 0, 1]]^T, + +with lod = [[0, 1, 3, 7]]. + +An example usage for this operator is to remove the special tokens when +computing the edit distance between two strings, such as blank, start token, +and end token. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, + ops::SequenceEraseOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_erase, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5da8eba3e1ac1fb85dfc65c2fd801574599e02d9 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cu @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/operators/sequence_erase_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +__global__ void LabelErasedIdx(const T* in_dat, const int in_len, + const T* tokens, const int tokens_len, + int* num_erased) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + int erased = 0; + for (int i = 0; i < tokens_len; ++i) { + if (in_dat[index] == tokens[i]) { + erased = 1; + } + } + num_erased[index + 1] = erased; + if (index == 0) { + num_erased[0] = 0; + } + } +} + +template +__global__ void GetOutLod(const T* num_erased, const int* in_lod, + const int lod_len, int* out_lod0) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < lod_len) { + out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; + } +} + +template +__global__ void SetOutput(const T* in_dat, const int in_len, + const int* num_erased, T* out_dat) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + if (in_dat[index] != in_dat[index + 1]) { + out_dat[index - num_erased[index]] = in_dat[index]; + } + } +} + +template +class SequenceEraseOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto tokens_len = tokens.size(); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + thrust::host_vector host_tokens(tokens_len); + for (size_t i = 0; i < tokens.size(); ++i) { + host_tokens[i] = tokens[i]; + } + thrust::device_vector dev_tokens = host_tokens; + thrust::device_vector num_erased(in_len + 1); + + T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); + int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); + + auto stream = ctx.cuda_device_context().stream(); + LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr); + thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), + num_erased.begin() + 1); + + // Calc LoD + auto lod_len = lod0.size(); + thrust::host_vector host_lod(lod_len); + for (size_t i = 0; i < lod_len; ++i) { + host_lod[i] = lod0[i]; + } + thrust::device_vector dev_in_lod = host_lod; + thrust::device_vector dev_out_lod(lod_len); + int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); + int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data()); + GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); + thrust::host_vector host_out_lod = dev_out_lod; + std::vector out_lod0(lod_len, 0); + for (size_t i = 0; i < lod_len; i++) { + out_lod0[i] = host_out_lod[i]; + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + + // Set output + out->Resize({out_lod0.back(), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, + num_erased_ptr, out_dat); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(sequence_erase, + paddle::operators::SequenceEraseOpCUDAKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cb2d7be009dcbe0138818457249e95fbdd27fc0a --- /dev/null +++ b/paddle/operators/sequence_erase_op.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SequenceEraseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + std::vector num_erased(in_len + 1, 0); + std::vector out_lod0(1, 0); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + size_t num_out = 0; + for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) { + num_erased[j] = num_erased[j - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) != + tokens.end()) { + num_erased[j] += 1; + } else { + num_out += 1; + } + } + out_lod0.push_back(out_lod0.back() + num_out); + } + + auto out_len = in_len - num_erased[in_len]; + out->Resize({static_cast(out_len), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + + for (int64_t i = 0; i < in_len; ++i) { + if (num_erased[i] == num_erased[i + 1]) { + out_dat[i - num_erased[i]] = in_dat[i]; + } + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index 821754a0a632e15643eaeff4133174eb75c9700f..3f5b2a9b84350c7dee5cb461ba6207e20e95c11b 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/lod_rank_table.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/operators/array_operator.h" #include "paddle/operators/math/math_function.h" @@ -46,8 +47,21 @@ class ShrinkRNNMemoryOp : public ArrayOp { auto *out_var = scope.FindVar(Output("Out")); PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); auto &out_tensor = *out_var->GetMutable(); + + size_t height = dst_num_rows; + + // do shrink for the top level LoD + if (x_tensor.lod().size() > 0 && + x_tensor.lod()[0].size() > static_cast(dst_num_rows)) { + auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0, + dst_num_rows, 0); + height = lod_offset.second.second; + auto out_lod = out_tensor.mutable_lod(); + framework::AppendLoD(out_lod, lod_offset.first); + } + if (dst_num_rows != 0) { - out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows)); + out_tensor.ShareDataWith(x_tensor.Slice(0, height)); } } }; @@ -64,11 +78,11 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(LoDTensor) The shrinked RNN step memory."); AddComment( R"DOC( - In dynamic RNN, we are able to handle sequences of different lengths. - Because of the multiple lengths, the size of each step input can be + In dynamic RNN, we are able to handle sequences of different lengths. + Because of the multiple lengths, the size of each step input can be different, which may lead to a mismatching between the input of - the current step and the memory generated by the previous one. This - operator shrinks memory according to the size of the next step input, + the current step and the memory generated by the previous one. This + operator shrinks memory according to the size of the next step input, to make sure that they can match each other. )DOC"); } @@ -132,6 +146,7 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase { PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X"))); context->SetOutputDim(framework::GradVarName("X"), context->GetInputDim("X")); + context->ShareLoD("X", framework::GradVarName("X")); } }; diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index 8a352b0078d701f797f7202c85bd0e08201ac9b8..bb47ad614ed85923ce5d9704760ec6c5b5ae59ee 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -92,6 +92,9 @@ function threads_config() { if [ -z "$OPENBLAS_NUM_THREADS" ]; then export OPENBLAS_NUM_THREADS=$threads fi + if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then + export OPENBLAS_MAIN_FREE=1 + fi fi } diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 0de417df2cb942ce46ff8ac3acc61ae4999ed634..df710c33d0c0ca16d358dac1eb42327e9cd4c7ae 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -62,12 +62,15 @@ __all__ = [ cp.begin_parse() -def set_omp_mkl_env_vars(trainer_count): +def set_env_vars(trainer_count): '''Auto set CPU environment if have not set before. - export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. - export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For MKL: + export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status. + export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count. + For OpenBLAS: + export OPENBLAS_NUM_THREADS, OPENBLAS_MAIN_FREE according to trainer_count. ''' - import platform + import platform, paddle if not platform.system() in ['Linux', 'Darwin']: return @@ -103,16 +106,22 @@ def set_omp_mkl_env_vars(trainer_count): num_cores = num_physical_cores() num_processors = num_logical_processors() - if num_processors > num_cores: # Hyper Threading is enabled - set_env("OMP_DYNAMIC", "true") - set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") - else: - set_env("OMP_DYNAMIC", "false") - set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") + if paddle.version.mkl() == 'ON': + if num_processors > num_cores: # Hyper Threading is enabled + set_env("OMP_DYNAMIC", "true") + set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + else: + set_env("OMP_DYNAMIC", "false") + set_env("KMP_AFFINITY", "granularity=fine,compact,0,0") threads = num_processors / trainer_count threads = '1' if threads < 1 else str(threads) - set_env("OMP_NUM_THREADS", threads) - set_env("MKL_NUM_THREADS", threads) + if paddle.version.mkl() == 'ON': + set_env("OMP_NUM_THREADS", threads) + set_env("MKL_NUM_THREADS", threads) + else: + set_env("OPENBLAS_NUM_THREADS", threads) + if threads > 1: + set_env("OPENBLAS_MAIN_FREE", '1') def init(**kwargs): @@ -129,7 +138,7 @@ def init(**kwargs): for key in args_dict.keys(): args.append('--%s=%s' % (key, str(args_dict[key]))) - set_omp_mkl_env_vars(kwargs.get('trainer_count', 1)) + set_env_vars(kwargs.get('trainer_count', 1)) if 'use_gpu' in kwargs: cp.g_command_config_args['use_gpu'] = kwargs['use_gpu'] diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1534c5a886db3c9694637e4a4195427c3538bb7..48a6bee5588949f708e6c588152be9e174f3ad69 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,7 +14,7 @@ __all__ = [ 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', - 'sequence_first_step', 'sequence_last_step' + 'sequence_first_step', 'sequence_last_step', 'dropout' ] @@ -386,6 +386,21 @@ def cos_sim(X, Y, **kwargs): return out +def dropout(x, dropout_prob, is_test=False, seed=0, **kwargs): + helper = LayerHelper('dropout', **kwargs) + out = helper.create_tmp_variable(dtype=x.dtype) + mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True) + helper.append_op( + type='dropout', + inputs={'X': [x]}, + outputs={'Out': [out], + 'Mask': [mask]}, + attrs={'dropout_prob': dropout_prob, + 'is_test': is_test, + 'seed': seed}) + return out + + def cross_entropy(input, label, **kwargs): """ **Cross Entropy Layer** diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 544623c4bce0cb75ea727906c4879e986c8d1ce8..d3a5b70785947148d6e208b4d8dafec8bb52ff85 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -1,23 +1,12 @@ from ..registry import register_layer __activations__ = [ - 'abs', - 'ceil', - 'exp', - 'floor', - 'log', - 'relu', - 'round', - 'sigmoid', - 'sqrt', - 'square', - 'tanh', + 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round' ] __all__ = [ 'mean', 'mul', - 'dropout', 'reshape', 'scale', 'transpose', diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py new file mode 100644 index 0000000000000000000000000000000000000000..38e87728b387bb70a8921a2fe73a4e69701aabe9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -0,0 +1,94 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def Levenshtein(hyp, ref): + """ Compute the Levenshtein distance between two strings. + + :param hyp: hypothesis string in index + :type hyp: list + :param ref: reference string in index + :type ref: list + """ + m = len(hyp) + n = len(ref) + if m == 0: + return n + if n == 0: + return m + + dist = np.zeros((m + 1, n + 1)).astype("float32") + for i in range(0, m + 1): + dist[i][0] = i + for j in range(0, n + 1): + dist[0][j] = j + + for i in range(1, m + 1): + for j in range(1, n + 1): + cost = 0 if hyp[i - 1] == ref[j - 1] else 1 + deletion = dist[i - 1][j] + 1 + insertion = dist[i][j - 1] + 1 + substitution = dist[i - 1][j - 1] + cost + dist[i][j] = min(deletion, insertion, substitution) + return dist[m][n] + + +class TestEditDistanceOp(OpTest): + def setUp(self): + self.op_type = "edit_distance" + normalized = False + x1 = np.array([[0, 12, 3, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 12, 4, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 5] + x2_lod = [0, 3, 4] + + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref + self.attrs = {'normalized': normalized} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +class TestEditDistanceOpNormalized(OpTest): + def setUp(self): + self.op_type = "edit_distance" + normalized = True + x1 = np.array([[0, 10, 3, 6, 5, 8, 2]]).astype("int32") + x2 = np.array([[0, 10, 4, 6, 7, 8]]).astype("int32") + x1 = np.transpose(x1) + x2 = np.transpose(x2) + x1_lod = [0, 1, 3, 6] + x2_lod = [0, 2, 3, 5] + + num_strs = len(x1_lod) - 1 + distance = np.zeros((num_strs, 1)).astype("float32") + for i in range(0, num_strs): + distance[i] = Levenshtein( + hyp=x1[x1_lod[i]:x1_lod[i + 1]], + ref=x2[x2_lod[i]:x2_lod[i + 1]]) + if normalized is True: + len_ref = x2_lod[i + 1] - x2_lod[i] + distance[i] = distance[i] / len_ref + self.attrs = {'normalized': normalized} + self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.outputs = {'Out': distance} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bf257fefea0d98c6f4d9860dbac4ccedf59bcdd9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_erase(in_seq, lod0, tokens): + new_lod0 = [0] + out_seq = [] + for i in range(0, len(lod0) - 1): + num_out = 0 + for dat in in_seq[lod0[i]:lod0[i + 1]]: + if dat not in tokens: + out_seq.append(dat) + num_out += 1 + new_lod0.append(new_lod0[-1] + num_out) + return np.array(out_seq).astype("int32"), new_lod0 + + +class TestSequenceEraseOp(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 9, 13, 24, 30]] + tokens = [2, 3, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py index be1588fc2d09fa58882425eb3d080ef1560ebc79..a14721b9aacfa7437623024af41555fd26990499 100644 --- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py +++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py @@ -3,43 +3,86 @@ import paddle.v2.fluid.core as core from paddle.v2.fluid.executor import Executor import paddle.v2.fluid.layers as layers from paddle.v2.fluid.backward import append_backward -from paddle.v2.fluid.framework import default_main_program -import numpy +from paddle.v2.fluid.framework import default_main_program, switch_main_program +from paddle.v2.fluid.framework import Program +import numpy as np -main_program = default_main_program() - -class TestShrinkRNNMemory(unittest.TestCase): - def test_shrink_rnn_memory(self): +class TestShrinkRNNMemoryBase(unittest.TestCase): + def setUp(self): + self.main_program = Program() + switch_main_program(self.main_program) x = layers.data('x', shape=[100], dtype='float32') x.stop_gradient = False - table = layers.lod_rank_table(x=x) + rank_table_tensor = layers.data( + 'rank_table_tensor', shape=[1], dtype='float32', lod_level=1) + table = layers.lod_rank_table(x=rank_table_tensor) i = layers.zeros(dtype='int64', shape=[1]) - mem1 = layers.shrink_memory(x=x, i=i, table=table) + self.mem1 = layers.shrink_memory(x=x, i=i, table=table) i = layers.increment(x=i) i.stop_gradient = True - mem2 = layers.shrink_memory(x=mem1, i=i, table=table) + self.mem2 = layers.shrink_memory(x=self.mem1, i=i, table=table) i = layers.increment(x=i) i.stop_gradient = True - mem3 = layers.shrink_memory(x=mem2, i=i, table=table) + self.mem3 = layers.shrink_memory(x=self.mem2, i=i, table=table) + mem3_mean = layers.mean(x=self.mem3) + append_backward(loss=mem3_mean) + self.x_grad = self.main_program.global_block().var('x@GRAD') + + def sum_lodtensor(self, tensor): + sum_res = 0.0 + for i in xrange(np.product(tensor.get_dims())): + sum_res += tensor.get_float_element(i) + return sum_res + +class TestShrinkRNNMemoryReferLoD(TestShrinkRNNMemoryBase): + def test_refer_lod(self): cpu = core.CPUPlace() - tensor = core.LoDTensor() - tensor.set_lod([[0, 2, 5, 6]]) - tensor_np = numpy.random.random(size=(3, 100)).astype('float32') - tensor.set(tensor_np, cpu) + x_tensor = core.LoDTensor() + x_tensor.set_lod([[0, 2, 5, 6]]) + tensor_np = np.random.random(size=(6, 100)).astype('float32') + x_tensor.set(tensor_np, cpu) + + rank_table_tensor = core.LoDTensor() + rank_table_tensor.set_lod([[0, 1, 3, 6]]) + rank_table_tensor.set(np.random.random(size=(6, 1)).astype('float32'), + cpu) + exe = Executor(cpu) - outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]) - self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0])) - self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1])) - self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2])) + outs = exe.run( + feed={'x': x_tensor, + 'rank_table_tensor': rank_table_tensor}, + fetch_list=[self.mem1, self.mem2, self.mem3, self.x_grad], + return_numpy=False) + self.assertTrue(np.allclose(tensor_np[0:6], outs[0])) + self.assertTrue(np.allclose(tensor_np[0:5], outs[1])) + self.assertTrue(np.allclose(tensor_np[0:2], outs[2])) + self.assertAlmostEqual(1.0, self.sum_lodtensor(outs[3]), delta=0.01) - mem3_mean = layers.mean(x=mem3) - append_backward(loss=mem3_mean) - x_grad = exe.run( - feed={'x': tensor}, - fetch_list=[main_program.global_block().var('x@GRAD')])[0] - self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1) + +class TestShrinkRNNMemoryNoLoD(TestShrinkRNNMemoryBase): + def test_no_lod(self): + cpu = core.CPUPlace() + x_tensor = core.LoDTensor() + tensor_np = np.random.random(size=(3, 100)).astype('float32') + x_tensor.set(tensor_np, cpu) + + rank_table_tensor = core.LoDTensor() + rank_table_tensor.set_lod([[0, 1, 3, 6]]) + rank_table_tensor.set(np.random.random(size=(6, 1)).astype('float32'), + cpu) + + exe = Executor(cpu) + outs = exe.run( + feed={'x': x_tensor, + 'rank_table_tensor': rank_table_tensor}, + fetch_list=[self.mem1, self.mem2, self.mem3, self.x_grad], + return_numpy=False) + self.assertTrue(np.allclose(tensor_np[0:3], outs[0])) + self.assertTrue(np.allclose(tensor_np[0:2], outs[1])) + self.assertTrue(np.allclose(tensor_np[0:1], outs[2])) + self.assertAlmostEqual(1.0, self.sum_lodtensor(outs[3]), delta=0.01) if __name__ == '__main__': diff --git a/python/setup.py.in b/python/setup.py.in index 66ccfe808763d0e157f866ce08868e3fdebdea79..65ec58ecf98e693ecf02922129f6eef13cbe5303 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -31,6 +31,7 @@ patch = '%(patch)d' rc = '%(rc)d' istaged = %(istaged)s commit = '%(commit)s' +with_mkl = '%(with_mkl)s' def show(): if istaged: @@ -41,6 +42,9 @@ def show(): print 'rc:', rc else: print 'commit:', commit + +def mkl(): + return with_mkl ''' commit = git_commit() with open(filename, 'w') as f: @@ -51,7 +55,8 @@ def show(): 'rc': RC, 'version': '${PADDLE_VERSION}', 'commit': commit, - 'istaged': ISTAGED}) + 'istaged': ISTAGED, + 'with_mkl': '@WITH_MKL@'}) write_version_py(filename='@PADDLE_SOURCE_DIR@/python/paddle/version.py')