diff --git a/benchmark/paddle/image/plotlog.py b/benchmark/paddle/image/plotlog.py
new file mode 100644
index 0000000000000000000000000000000000000000..8679d4f272d1b7aaf8d5a397f07698a6b70e4fcd
--- /dev/null
+++ b/benchmark/paddle/image/plotlog.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import argparse
+import matplotlib.pyplot as plt
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('Parse Log')
+    parser.add_argument(
+        '--file_path', '-f', type=str, help='the path of the log file')
+    parser.add_argument(
+        '--sample_rate',
+        '-s',
+        type=float,
+        default=1.0,
+        help='the rate to take samples from the log')
+    parser.add_argument(
+        '--log_period', '-p', type=int, default=1, help='the period of log')
+
+    args = parser.parse_args()
+    return args
+
+
+def parse_file(file_name):
+    loss = []
+    error = []
+    with open(file_name) as f:
+        for i, line in enumerate(f):
+            line = line.strip()
+            if not line.startswith('pass'):
+                continue
+            line_split = line.split(' ')
+            if len(line_split) != 5:
+                continue
+
+            loss_str = line_split[2][:-1]
+            cur_loss = float(loss_str.split('=')[-1])
+            loss.append(cur_loss)
+
+            err_str = line_split[3][:-1]
+            cur_err = float(err_str.split('=')[-1])
+            error.append(cur_err)
+
+    accuracy = [1.0 - err for err in error]
+
+    return loss, accuracy
+
+
+def sample(metric, sample_rate):
+    interval = int(1.0 / sample_rate)
+    if interval > len(metric):
+        return metric[:1]
+
+    num = len(metric) // interval
+    idx = [interval * i for i in range(num)]
+    metric_sample = [metric[j] for j in idx]
+    return metric_sample
+
+
+def plot_metric(metric,
+                batch_id,
+                graph_title,
+                line_style='b-',
+                line_label='y',
+                line_num=1):
+    plt.figure()
+    plt.title(graph_title)
+    if line_num == 1:
+        plt.plot(batch_id, metric, line_style, label=line_label)
+    else:
+        for i in range(line_num):
+            plt.plot(batch_id, metric[i], line_style[i], label=line_label[i])
+    plt.xlabel('batch')
+    plt.ylabel(graph_title)
+    plt.legend()
+    plt.savefig(graph_title + '.jpg')
+    plt.close()
+
+
+def main():
+    args = parse_args()
+    assert args.sample_rate > 0. and args.sample_rate <= 1.0, "The sample rate should be in the range (0, 1]."
+ + loss, accuracy = parse_file(args.file_path) + batch = [args.log_period * i for i in range(len(loss))] + + batch_sample = sample(batch, args.sample_rate) + loss_sample = sample(loss, args.sample_rate) + accuracy_sample = sample(accuracy, args.sample_rate) + + plot_metric(loss_sample, batch_sample, 'loss', line_label='loss') + plot_metric( + accuracy_sample, + batch_sample, + 'accuracy', + line_style='g-', + line_label='accuracy') + + +if __name__ == '__main__': + main() diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index a8e1aca49c97df256b1269c286b0bce7732fa932..7cb4efa7bff7164464f1210a2b2188226c219ef6 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -63,7 +63,7 @@ ExternalProject_Add( MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}") INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) -ADD_LIBRARY(warpctc STATIC IMPORTED GLOBAL) +ADD_LIBRARY(warpctc SHARED IMPORTED GLOBAL) SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES}) ADD_DEPENDENCIES(warpctc extern_warpctc) diff --git a/paddle/framework/device_data_transform_test.cu b/paddle/framework/device_data_transform_test.cu index 9fb26f09c7ed6aff3bfc98cf3f829e50adbf48bf..5d89f5546fa87241dec6364d86a100ca51bce687 100644 --- a/paddle/framework/device_data_transform_test.cu +++ b/paddle/framework/device_data_transform_test.cu @@ -105,8 +105,7 @@ static void BuildVar(const std::string& param_name, TEST(Operator, CPUtoGPU) { using namespace paddle::framework; using namespace paddle::platform; - - ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true); + InitDevices(); paddle::framework::Scope scope; paddle::platform::CPUPlace cpu_place; diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index d8ef9a0fbaa7ba18d78060bd5b9605458cd9b1a2..c0418c9266e257bd7567861543e557f354451b17 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch"; Executor::Executor(const platform::Place& place) : place_(place) {} -void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { +static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { if (var_type == proto::VarDesc::LOD_TENSOR) { var->GetMutable(); } else if (var_type == proto::VarDesc::SELECTED_ROWS) { diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index 0b2b5780fed1ef48ba78f44112fb0a88b477b796..d869e18901b82959a40cc296aa0844c20ea63ac1 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -45,7 +45,5 @@ class Executor { const platform::Place place_; }; -void CreateTensor(Variable* var, proto::VarDesc::VarType var_type); - } // namespace framework } // namespace paddle diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index 2de5242831835b47893a5825e5532500ad5ec3f9..2082f8bb76fb62bc36f033fecbd4eaa76d12d949 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -87,7 +87,11 @@ class GradOpDescMakerBase { auto onames = this->Output(name); ret_val.reserve(onames.size()); std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val), - GradVarName); + [this](const std::string& fwd_var_name) -> std::string { + auto g_name = GradVarName(fwd_var_name); + (*this->grad_to_var_)[g_name] = fwd_var_name; + return g_name; + }); return ret_val; } diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index 
e7087e063cbe8839716e3648d55cd25cc778f06f..e12bac1d78e3f6bbc46849c06b53e3b93e147cfc 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -40,40 +40,23 @@ void InitGflags(std::vector<std::string> &argv) {
   });
 }
 
-bool InitDevices(const std::vector<std::string> &devices) {
-  // device format
-  // CPU
-  // GPU:1
-  // TODO(dzhwinter) : add device format annotation for users.
+void InitDevices() {
+  /* Init all available devices by default */
   std::vector<platform::Place> places;
-  for (auto &device : devices) {
-    auto p = string::Piece(device);
-    if (string::HasPrefix(p, "CPU")) {
-      places.emplace_back(platform::CPUPlace());
-    } else if (string::HasPrefix(p, "GPU")) {
+  places.emplace_back(platform::CPUPlace());
+
 #ifdef PADDLE_WITH_CUDA
-      auto pos = string::RFind(p, ':', string::Piece::npos);
-      auto number = device.substr(pos + 1);
-      places.emplace_back(platform::CUDAPlace(std::stoi(number)));
+  int count = platform::GetCUDADeviceCount();
+  for (int i = 0; i < count; ++i) {
+    places.emplace_back(platform::CUDAPlace(i));
+  }
 #else
-      LOG(WARNING)
-          << "'GPU' is not supported, Please re-compile with WITH_GPU option";
+  LOG(WARNING)
+      << "'GPU' is not supported, please re-compile with WITH_GPU option";
 #endif
-    } else {
-      return false;
-    }
-  }
-
-  if (std::find_if(places.begin(), places.end(),
-                   [&](const platform::Place &place) {
-                     return platform::is_cpu_place(place);
-                   }) == places.end()) {
-    places.emplace_back(platform::CPUPlace());
-    LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
-  }
   platform::DeviceContextPool::Init(places);
-  // framework::UseALL();
-  return true;
 }
 
 void InitGLOG(const std::string &prog_name) {
diff --git a/paddle/framework/init.h b/paddle/framework/init.h
index 9c84a03ded52632047841f95badbcf44bc9f48d1..c8fd964d006baf729888414ded2aec85ba5a024e 100644
--- a/paddle/framework/init.h
+++ b/paddle/framework/init.h
@@ -24,7 +24,7 @@ void InitGflags(std::vector<std::string> &argv);
 
 void InitGLOG(const std::string &prog_name);
 
-bool InitDevices(const std::vector<std::string> &devices);
+void InitDevices();
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc
index f0788051d4855a175d2d7ea1f1a0805c776c462b..f837a965d3be7d40c20803ae4462b3bfd91bffd0 100644
--- a/paddle/framework/init_test.cc
+++ b/paddle/framework/init_test.cc
@@ -14,18 +14,13 @@ limitations under the License.
*/ #include "gtest/gtest.h" #include "paddle/framework/init.h" +#include "paddle/platform/device_context.h" -TEST(Init, InitDevices) { +TEST(InitDevices, CPU) { using paddle::framework::InitDevices; - std::vector ds1 = {"CPU"}; - ASSERT_EQ(InitDevices(ds1), true); + using paddle::platform::DeviceContextPool; -#ifdef PADDLE_WITH_CUDA - std::vector ds2 = {"CPU", "GPU:0", "GPU:1"}; - ASSERT_EQ(InitDevices(ds2), true); - - // test re-init - std::vector ds3 = {"GPU:0", "GPU:1"}; - ASSERT_EQ(InitDevices(ds3), true); -#endif + InitDevices(); + DeviceContextPool& pool = DeviceContextPool::Instance(); + ASSERT_GE(pool.size(), 1U); } diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 506fde440533e83f093f26484f925416b89c75a0..7ae94c646537e0d7c4687b949a1b06cd3a7f3404 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -44,9 +44,19 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { } std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { - PADDLE_ENFORCE(platform::is_cpu_place(t.place())); PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code()); + if (!platform::is_cpu_place(t.place())) { + LoDTensor tt; + framework::Copy(t, platform::CPUPlace(), &tt); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(t.place()); + dev_ctx.Wait(); + + os << tt; + return os; + } + os << "dim: " << t.dims() << "\n"; os << "lod: " << t.lod() << "\n"; @@ -211,38 +221,23 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, DeserializeFromStream(is, static_cast(tensor), dev_ctx); } +// TODO(tonyyang-svail): make this function support LoD std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); - // PADDLE_ENFORCE(lod().empty() || (lod().size() == 1 && lod()[0].empty()) - // , "Disable parallel lod for now"); PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now"); PADDLE_ENFORCE(dims()[0] % places.size() == 0, "Batch size should be divided by places size"); std::vector lods; for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) { - size_t begin = place_idx * dims()[0] / places.size(); - size_t end = (place_idx + 1) * dims()[0] / places.size(); - auto src = Slice(static_cast(begin), static_cast(end)); + int begin = place_idx * dims()[0] / places.size(); + int end = (place_idx + 1) * dims()[0] / places.size(); - LoDTensor dst; - dst.Resize(src.dims()); + auto src = Slice(begin, end); auto &dst_place = places[place_idx]; - auto dst_ptr = dst.mutable_data(dst_place, src.type()); - - // TODO(tonyyang-svail): - // change the following to framework::Copy - auto src_place = src.place(); - auto src_ptr = src.data(); - auto size = src.numel() * SizeOfType(src.type()); - if (platform::is_cpu_place(src_place) && - platform::is_cpu_place(dst_place)) { - memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size); - } else { - PADDLE_THROW("Not Implemented"); - } + LoDTensor dst; + framework::Copy(src, dst_place, &dst); lods.emplace_back(dst); } @@ -250,28 +245,30 @@ std::vector LoDTensor::SplitLoDTensor( return lods; } +// TODO(tonyyang-svail): make this function support LoD void LoDTensor::MergeLoDTensor( - const std::vector &lod_tensors, platform::Place place) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + const std::vector &lod_tensors, + platform::Place dst_place) { PADDLE_ENFORCE(!lod_tensors.empty()); - framework::DDim new_dim = lod_tensors[0]->dims(); std::type_index new_type = 
lod_tensors[0]->type(); + auto new_layout = lod_tensors[0]->layout(); for (auto *lod : lod_tensors) { PADDLE_ENFORCE(new_dim == lod->dims()); PADDLE_ENFORCE(new_type == lod->type()); - PADDLE_ENFORCE(platform::is_cpu_place(lod->place())); + PADDLE_ENFORCE(new_layout == lod->layout()); } new_dim[0] *= lod_tensors.size(); Resize(new_dim); + set_layout(new_layout); - auto *dst_ptr = reinterpret_cast(mutable_data(place, new_type)); + mutable_data(dst_place, new_type); + int begin = 0; for (auto *src : lod_tensors) { - auto size = src->numel() * SizeOfType(src->type()); - memory::Copy(boost::get(place), dst_ptr, - boost::get(src->place()), - src->data(), size); - dst_ptr += size; + int end = begin + src->dims()[0]; + auto dst = Slice(begin, end); + framework::Copy(*src, dst_place, &dst); + begin = end; } } diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 52b87f48e5340ce4e265e2e77577f58daae039d2..baad9c6f98ac135c3650fe3113522850328c1298 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -115,5 +115,21 @@ TEST(LoD, AppendLoD) { EXPECT_EQ(origin, expected); } +TEST(LoD, ToAbsOffset) { + LoD relative_lod; + relative_lod.push_back(std::vector({0, 2})); + relative_lod.push_back(std::vector({0, 1, 3})); + relative_lod.push_back(std::vector({0, 2, 4, 5})); + + LoD abs_lod = paddle::framework::ToAbsOffset(relative_lod); + + LoD expected; + expected.push_back(std::vector({0, 5})); + expected.push_back(std::vector({0, 2, 5})); + expected.push_back(std::vector({0, 2, 4, 5})); + + EXPECT_EQ(abs_lod, expected); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 4cf784a0d0d319d09caa27b4e2b589bd7ac4f324..a5ffb162928bfd355d35d3f9b63aab59a88dd061 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -129,7 +129,7 @@ class OpDesc { } proto::OpDesc desc_; - // input arg name => output variable names + // input arg name => input variable names VariableNameMap inputs_; // output arg name => output variable names VariableNameMap outputs_; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index d002f3f238862a53ad7286570e2d0bbd2334c584..b69d7c7a7406eb3e18d385c568cb9c21b9b4107b 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -69,7 +69,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator, paddle::framework::OpWithoutKernelCheckerMaker); TEST(OperatorBase, all) { - paddle::framework::InitDevices({"CPU"}); + paddle::framework::InitDevices(); paddle::framework::proto::OpDesc op_desc; op_desc.set_type("test_operator"); BuildVar("input", {"IN1"}, op_desc.add_inputs()); @@ -195,7 +195,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, // test with single input TEST(OpKernel, all) { - paddle::framework::InitDevices({"CPU"}); + paddle::framework::InitDevices(); paddle::framework::proto::OpDesc op_desc; op_desc.set_type("op_with_kernel"); BuildVar("x", {"IN1"}, op_desc.add_inputs()); @@ -225,7 +225,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, TEST(OpKernel, multi_inputs) { using namespace paddle::framework; - paddle::framework::InitDevices({"CPU"}); + paddle::framework::InitDevices(); proto::OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); @@ -264,7 +264,7 @@ class OperatorClone : public paddle::framework::OperatorBase { }; TEST(Operator, Clone) { - paddle::framework::InitDevices({"CPU"}); + paddle::framework::InitDevices(); OperatorClone a("ABC", 
paddle::framework::VariableNameMap{}, paddle::framework::VariableNameMap{}, paddle::framework::AttributeMap{}); diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 7c56ccf17f94e29d06f529629c47f61b93d2bd22..f541d2ba693a169d074c070dd794a2dd4e52aabf 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -31,9 +31,10 @@ namespace framework { * * @note Copy supports CPU <-> GPU, GPU <-> GPU. */ - inline void Copy(const Tensor& src, const platform::Place& dst_place, const platform::DeviceContext& ctx, Tensor* dst) { + VLOG(3) << "Copy " << src.dims() << " from " << src.place() << " to " + << dst_place; src.check_memory_size(); dst->Resize(src.dims()); @@ -88,26 +89,25 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place, } /** - * @brief Copy supports CPU <-> CPU + * @brief Wrapper on + * Copy(const Tensor& src, const platform::Place& dst_place, + * const platform::DeviceContext& ctx, Tensor* dst); + * + * @param[in] src The external tensor. + * @param[in] dst_place The dst place. + * + * @note Copy supports CPU <-> GPU, GPU <-> GPU. */ inline void Copy(const Tensor& src, const platform::Place& dst_place, Tensor* dst) { - src.check_memory_size(); - dst->Resize(src.dims()); - dst->set_layout(src.layout()); - - auto src_place = src.place(); - auto src_ptr = src.data(); - - auto dst_ptr = dst->mutable_data(dst_place, src.type()); - - auto size = src.numel() * SizeOfType(src.type()); - - PADDLE_ENFORCE(platform::is_cpu_place(src_place) && - platform::is_cpu_place(dst_place)); - - memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + const platform::DeviceContext* dev_ctx; + if (platform::is_gpu_place(src.place())) { + dev_ctx = pool.Get(src.place()); + } else { + dev_ctx = pool.Get(dst_place); + } + Copy(src, dst_place, *dev_ctx, dst); } /** diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index aeab18d7214f8d9dd79bc3d2e0322490445b3b49..62ab6593ef23c195e3caa2336574796ecaf35bc8 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -74,7 +74,7 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().tensor(); default: - PADDLE_THROW("The type of var '", this->Name(), "' is unsupported."); + PADDLE_THROW("The type of var %s is unsupported.", this->Name()); } } diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 49e39358e81bbee64a618be88ee0fca6aa438b93..37b8b20ddfcf2566b8410f950308309e5b2b2a7c 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -169,7 +169,7 @@ void InferenceEngine::Execute(const std::vector& feeds, } auto* place = new platform::CPUPlace(); - framework::InitDevices({"CPU"}); + framework::InitDevices(); framework::Executor* executor = new framework::Executor(*place); framework::Scope* scope = new framework::Scope(); diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index c4bb6baee7ebf2941cee5915ca2723c298689261..1a73a94567e45b81a0b148965a834f03c7407ffe 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -114,5 +114,21 @@ void Free(platform::CUDAPlace place, void* p) { #endif +size_t Usage::operator()(const platform::CPUPlace& cpu) const { + return Used(cpu); +} + +size_t Usage::operator()(const platform::CUDAPlace& gpu) const { +#ifdef PADDLE_WITH_CUDA + return Used(gpu); +#else + 
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); +#endif +} + +size_t memory_usage(const platform::Place& p) { + return boost::apply_visitor(Usage(), p); +} + } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 11bbb881874ec50e1132547336fc6fb6b42bcc4f..7012b6d331d0c4631a3d120fbaf3db7c97298ac7 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -54,6 +54,13 @@ void Free(Place place, void* ptr); template size_t Used(Place place); +struct Usage : public boost::static_visitor { + size_t operator()(const platform::CPUPlace& cpu) const; + size_t operator()(const platform::CUDAPlace& gpu) const; +}; + +size_t memory_usage(const platform::Place& p); + /** * \brief Free memory block in one place. * diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc index f476bf71264da59a5c546968f4689145e1d8801b..b3f699f9b7eff54c06ff69023db082380c83467a 100644 --- a/paddle/memory/memory_test.cc +++ b/paddle/memory/memory_test.cc @@ -44,6 +44,9 @@ TEST(BuddyAllocator, CPUAllocation) { EXPECT_NE(p, nullptr); + paddle::platform::Place place = cpu; + EXPECT_EQ(paddle::memory::Used(cpu), paddle::memory::memory_usage(place)); + paddle::memory::Free(cpu, p); } @@ -99,6 +102,9 @@ TEST(BuddyAllocator, GPUAllocation) { EXPECT_NE(p, nullptr); + paddle::platform::Place place = gpu; + EXPECT_EQ(paddle::memory::Used(gpu), paddle::memory::memory_usage(place)); + paddle::memory::Free(gpu, p); } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f1ce52332327ed2a9f290ccf412199fd5a6bbb67..5889a50db09534b30f0f57b4e659df440901f3b1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -151,6 +151,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) +op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) # FIXME(typhoonzero): save/load depends lodtensor serialization functions diff --git a/paddle/operators/beam_search_op.cc b/paddle/operators/beam_search_op.cc index 2e0513b37a24b9737532b3a71f8f0724fbdd2c13..ed2e7b738acd81140467ff22ad077155bba2fde1 100644 --- a/paddle/operators/beam_search_op.cc +++ b/paddle/operators/beam_search_op.cc @@ -39,7 +39,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids, std::map> hash; framework::LoD new_lod; - auto *ids_data = selected_ids->mutable_data(platform::CPUPlace()); + auto *ids_data = selected_ids->mutable_data(platform::CPUPlace()); auto *scores_data = selected_scores->mutable_data(platform::CPUPlace()); @@ -66,7 +66,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids, void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids, std::vector> *items) { - auto *pre_ids_data = pre_ids.data(); + auto *pre_ids_data = pre_ids.data(); for (size_t offset = 0; offset < items->size(); offset++) { auto prefix_id = pre_ids_data[offset]; @@ -127,7 +127,7 @@ bool BeamSearch::NextItemSet(std::vector *items) { auto abs_lod = framework::ToAbsOffset(ids.lod()); PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL); - auto *ids_data = ids.data(); + auto *ids_data = ids.data(); auto *scores_data = scores.data(); size_t instance_dim = 1; diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 
ad84524e1785e3b6b4586c83001852b6dba7afe8..1468e3eb960a2b7c2e7af83ff701338596606922 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -230,7 +230,6 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { namespace ops = paddle::operators; REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, ops::ConvOpGrad); -namespace ops = paddle::operators; REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 7ebcfb9ab9f30e3b0f13d3646a59d008335b232d..fd59eef7d650b48feae68c89be54ec4e48cbcc7e 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -12,6 +12,7 @@ if(WITH_GPU) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor) + nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) @@ -27,6 +28,7 @@ else() cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor) cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor) + cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) @@ -38,3 +40,4 @@ cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) +cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index a0c02817fd8b03262f539288ec7e760fdb35bb05..1ba24325ffe3922081f59abfc1c67c95b514bcfa 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/operators/math/im2col.h" #include -#include template void testIm2col() { @@ -102,6 +101,7 @@ void testIm2col() { Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp); out_ocf_ptr = output_tmp.data(); } + for (int i = 0; i < 6; ++i) { EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]); } @@ -154,6 +154,9 @@ void testIm2col() { for (int i = 0; i < 6; ++i) { EXPECT_EQ(in_ptr[i], col2im_data[i]); } + + delete place; + delete context; } TEST(math, im2col) { diff --git a/paddle/operators/math/sequence_padding.cc b/paddle/operators/math/sequence_padding.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd66455eaef60209b9ca334480951a9f7687729b --- /dev/null +++ b/paddle/operators/math/sequence_padding.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence_padding.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class PaddingLoDTensorFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& seq, framework::Tensor& padding, + bool norm_by_times) { + auto lod = seq.lod(); + PADDLE_ENFORCE_GT(lod.size(), 0UL, + "The LoD of LoDTensor seq should not be null."); + + const size_t level = 0; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + + auto seq_dims = seq.dims(); + PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(), + "The first dimension of LoDTensor seq should be " + "equal to the sum of all sequences's length."); + + auto padding_dims = padding.dims(); + PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, + "The input padding should be a 3-D Tensor of shape " + "[max_sequence_length, num_sequences, sequence_width]."); + + const size_t max_sequence_length = MaximumSequenceLength(lod, level); + PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, + "The first dimension of Tensor padding should be the " + "maximum length of all sequences in LoDTensor seq."); + + const size_t num_sequences = abs_offset_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, + "The second dimension of Tensor padding should be the " + "number of sequences in LoDTensor seq."); + + const size_t sequence_width = seq.numel() / seq_dims[0]; + PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, + "The third dimension of Tensor padding should be the " + "width of sequence in LoDTensor seq."); + + const T* seq_data = seq.data(); + T* padding_data = padding.data(); + for (size_t i = 0; i < max_sequence_length; ++i) { + for (size_t j = 0; j < num_sequences; ++j) { + size_t start_pos = abs_offset_lod[level][j]; + size_t sequence_length = abs_offset_lod[level][j + 1] - start_pos; + if (i < sequence_length) { + // i > 0 => sequence_length > 0 + T scale = + norm_by_times ? 
(1.0f / static_cast(sequence_length)) : 1.0f; + for (size_t k = 0; k < sequence_width; ++k) { + padding_data[(i * num_sequences + j) * sequence_width + k] = + seq_data[(start_pos + i) * sequence_width + k] * scale; + } + } else { + memset(padding_data + (i * num_sequences + j) * sequence_width, 0, + sequence_width * sizeof(T)); + } + } + } + } +}; + +template +class UnpaddingLoDTensorFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + framework::LoDTensor& seq, const framework::Tensor& padding, + bool norm_by_times) { + auto lod = seq.lod(); + PADDLE_ENFORCE_GT(lod.size(), 0UL, + "The LoD of LoDTensor seq should not be null."); + + const size_t level = 0; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + + auto seq_dims = seq.dims(); + PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(), + "The first dimension of LoDTensor seq should be " + "equal to the sum of all sequences's length."); + + auto padding_dims = padding.dims(); + PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, + "The input padding should be a 3-D Tensor of shape " + "[max_sequnece_length, num_sequences, sequence_width]."); + + const size_t max_sequence_length = MaximumSequenceLength(lod, level); + PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, + "The first dimension of Tensor padding should be " + "the maximum length of all sequences in LoDTensor seq."); + + const size_t num_sequences = abs_offset_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, + "The second dimension of Tensor padding should be " + "the number of sequences in LoDTensor seq."); + + const size_t sequence_width = seq.numel() / seq_dims[0]; + PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, + "The third dimension of Tensor padding should be the " + "width of sequence in LoDTensor seq."); + + const T* padding_data = padding.data(); + T* seq_data = seq.data(); + for (size_t i = 0; i < num_sequences; ++i) { + size_t start_pos = abs_offset_lod[level][i]; + size_t sequence_length = abs_offset_lod[level][i + 1] - start_pos; + for (size_t j = 0; j < sequence_length; ++j) { + // sequence_width > j > 0 + T scale = + norm_by_times ? (1.0f / static_cast(sequence_length)) : 1.0f; + for (size_t k = 0; k < sequence_width; ++k) { + seq_data[(start_pos + j) * sequence_width + k] = + padding_data[(j * num_sequences + i) * sequence_width + k] * + scale; + } + } + } + } +}; + +template class PaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence_padding.cu b/paddle/operators/math/sequence_padding.cu new file mode 100644 index 0000000000000000000000000000000000000000..e4be178f81581dea2e84cf488b01d5f7f4cc0030 --- /dev/null +++ b/paddle/operators/math/sequence_padding.cu @@ -0,0 +1,209 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/sequence_padding.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void SequencePaddingKernel(T* padding, T* sequence, + const size_t* sequence_start_positions, + const size_t sequence_width, + const size_t max_sequence_length, + const size_t num_sequences) { + size_t padding_idx = blockIdx.y; + size_t start_pos = sequence_start_positions[padding_idx]; + size_t sequence_length = + sequence_start_positions[padding_idx + 1] - start_pos; + + size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t padding_base_idx = + (sequence_idx * num_sequences + padding_idx) * sequence_width; + size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width; + + if (sequence_idx < sequence_length) { + T scale = NormByTimes ? (1.0f / static_cast(sequence_length)) : 1.0f; + if (Padding) { + /* sequence -> padding */ + for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { + padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i]; + } + } else { + /* padding -> sequence */ + for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { + sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i]; + } + } + } else if (sequence_idx < max_sequence_length) { + if (Padding) { + /* sequence -> padding */ + for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { + padding[padding_base_idx + i] = 0; + } + } + } +} + +template +class PaddingLoDTensorFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::LoDTensor& seq, framework::Tensor& padding, + bool norm_by_times) { + auto lod = seq.lod(); + PADDLE_ENFORCE_GT(lod.size(), 0UL, + "The lod of LoDTensor seq should not be null."); + + const size_t level = 0; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + + auto seq_dims = seq.dims(); + PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(), + "The first dimension of LoDTensor seq should be " + "equal to the sum of all sequences's length."); + + auto padding_dims = padding.dims(); + PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, + "The input padding should be a 3-D Tensor of shape " + "[max_sequence_length, num_sequences, sequence_width]."); + + size_t max_sequence_length = MaximumSequenceLength(lod, level); + PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, + "The first dimension of Tensor padding should be the " + "maximum length of all sequences in LoDTensor seq."); + + const size_t num_sequences = abs_offset_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, + "The second dimension of Tensor padding should be the " + "number of sequences in LoDTensor seq."); + + const size_t sequence_width = seq.numel() / seq_dims[0]; + PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, + "The third dimension of Tensor padding should be the " + "width of sequence in LoDTensor seq."); + + if (!norm_by_times && num_sequences == 1UL) { + Copy(seq, context.GetPlace(), context, &padding); + padding.Resize(padding_dims); + return; + } + + const size_t kBlockSize = 512; + + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. 
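+     * For instance, with sequence_width = 100: ceil(100 / 8) = 13 threads
+     * would suffice at 8 elements apiece, which rounds up to a full warp,
+     * giving block_dim_x = 32 and block_dim_y = 512 / 32 = 16, i.e. each
+     * block covers 16 padded timesteps of 32-lane row copies.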
+ */ + size_t block_dim_x = + std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = num_sequences; + dim3 grid(grid_dim_x, grid_dim_y); + + const T* seq_data = seq.data(); + T* padding_data = padding.data(); + if (norm_by_times) { + SequencePaddingKernel<<>>( + padding_data, const_cast(seq_data), abs_offset_lod[level].data(), + sequence_width, max_sequence_length, num_sequences); + } else { + SequencePaddingKernel<<>>( + padding_data, const_cast(seq_data), abs_offset_lod[level].data(), + sequence_width, max_sequence_length, num_sequences); + } + } +}; + +template +class UnpaddingLoDTensorFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + framework::LoDTensor& seq, const framework::Tensor& padding, + bool norm_by_times) { + auto lod = seq.lod(); + PADDLE_ENFORCE_GT(lod.size(), 0UL, + "The lod of LoDTensor seq should not be null."); + + const size_t level = 0; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + + auto seq_dims = seq.dims(); + PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(), + "The first dimension of LoDTensor seq should be " + "equal to the sum of all sequences's length."); + + auto padding_dims = padding.dims(); + PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, + "The input padding should be a 3-D Tensor of shape " + "[max_sequnece_length, num_sequences, sequence_width]."); + + size_t max_sequence_length = MaximumSequenceLength(lod, level); + PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, + "The first dimension of Tensor padding should be " + "the maximum length of all sequences in LoDTensor seq."); + + const size_t num_sequences = abs_offset_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, + "The second dimension of Tensor padding should be " + "the number of sequences in LoDTensor seq."); + + const size_t sequence_width = seq.numel() / seq_dims[0]; + PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, + "The third dimension of Tensor padding should be the " + "width of sequence in LoDTensor seq."); + + if (!norm_by_times && num_sequences == 1UL) { + Copy(padding, context.GetPlace(), context, &seq); + seq.Resize(seq_dims); + return; + } + + const size_t kBlockSize = 512; + + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. 
+ */ + size_t block_dim_x = + std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = num_sequences; + dim3 grid(grid_dim_x, grid_dim_y); + + const T* padding_data = padding.data(); + T* seq_data = seq.data(); + if (norm_by_times) { + SequencePaddingKernel<<>>( + const_cast(padding_data), seq_data, abs_offset_lod[level].data(), + sequence_width, max_sequence_length, num_sequences); + } else { + SequencePaddingKernel<<>>( + const_cast(padding_data), seq_data, abs_offset_lod[level].data(), + sequence_width, max_sequence_length, num_sequences); + } + } +}; + +template class PaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence_padding.h b/paddle/operators/math/sequence_padding.h new file mode 100644 index 0000000000000000000000000000000000000000..8f586c5eb469ea260dbdf9cc4e7f9b4b4a46a8cc --- /dev/null +++ b/paddle/operators/math/sequence_padding.h @@ -0,0 +1,79 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/lod_tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +inline static size_t MaximumSequenceLength(const framework::LoD& lod, + const size_t level) { + const size_t num_sequences = lod[level].size() - 1; + size_t max_sequence_length = 0; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + for (size_t i = 0; i < num_sequences; ++i) { + max_sequence_length = + std::max(max_sequence_length, + abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]); + } + return max_sequence_length; +} + +/* + * \brief Padding/Unpadding LoDTensor to/from normal Tensor of the shape + * [max_sequence_length, num_sequences, sequence_width]. + * + * Padding sequence: + * padding[i] = seq[lod[level][i]] + * Unpadding sequence: + * seq[lod[level][i]] = padding[i] + * + * All sequences will be padded to the same length and stored in a transposed + * shape. + * Example: + * seq (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3) + * padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0) + * + * \param context device context of this functor. + * \param seq LoDTensor which is stored in sequence format, the shape + * is [total_sequence_length, sequence_width] where + * total_sequence_length is the sum of all sequences' + * length. + * \param padding Tensor which is padded to the same length, the shape is + * [max_sequence_length, num_sequences, sequence_width]. + * \param norm_by_times whether dividing sequence's length. + * + * \note transposition is also done in this functor. 
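+ *
+ * For example, with lod[level] = {0, 4, 6, 9, 10} (the four sequences
+ * shown above), element k of timestep t in sequence j maps between
+ * seq[(lod[level][j] + t) * sequence_width + k] and
+ * padding[(t * num_sequences + j) * sequence_width + k].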
+ */ +template +class PaddingLoDTensorFunctor { + public: + void operator()(const DeviceContext& context, const framework::LoDTensor& seq, + framework::Tensor& padding, bool norm_by_times); +}; + +template +class UnpaddingLoDTensorFunctor { + public: + void operator()(const DeviceContext& context, framework::LoDTensor& seq, + const framework::Tensor& padding, bool norm_by_times); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence_padding_test.cc b/paddle/operators/math/sequence_padding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9799bcd65dc65d5741813374c68a2640eaf4556c --- /dev/null +++ b/paddle/operators/math/sequence_padding_test.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence_padding.h" +#include + +template +void TestSequencePadding(const paddle::framework::LoD& lod, + const size_t sequence_width) { + paddle::framework::LoDTensor cpu_seq; + paddle::framework::LoDTensor cpu_seq_back; + paddle::framework::LoDTensor seq; + paddle::framework::LoDTensor seq_back; + paddle::framework::Tensor padding; + + const size_t level = lod.size() - 1; + auto seq_dims = + paddle::framework::make_ddim({static_cast(lod[level].back()), + static_cast(sequence_width)}); + + cpu_seq.set_lod(lod); + cpu_seq.mutable_data(seq_dims, paddle::platform::CPUPlace()); + for (size_t i = 0; i < cpu_seq.numel(); ++i) { + cpu_seq.data()[i] = static_cast(i); + } + + auto* place = new Place(); + DeviceContext* context = new DeviceContext(*place); + if (paddle::platform::is_cpu_place(*place)) { + seq = cpu_seq; + } else { + Copy(cpu_seq, *place, *context, &seq); + seq.set_lod(lod); + } + + const size_t max_sequence_length = + paddle::operators::math::MaximumSequenceLength(lod, level); + const size_t num_sequences = lod[level].size() - 1; + auto padding_dims = + paddle::framework::make_ddim({static_cast(max_sequence_length), + static_cast(num_sequences), + static_cast(sequence_width)}); + padding.mutable_data(padding_dims, *place); + paddle::operators::math::PaddingLoDTensorFunctor()( + *context, seq, padding, false); + + seq_back.set_lod(lod); + seq_back.mutable_data(seq_dims, *place); + paddle::operators::math::UnpaddingLoDTensorFunctor()( + *context, seq_back, padding, false); + + if (paddle::platform::is_cpu_place(*place)) { + cpu_seq_back = seq_back; + } else { + Copy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back); + cpu_seq_back.set_lod(lod); + } + + EXPECT_EQ(cpu_seq.numel(), cpu_seq_back.numel()); + EXPECT_EQ(cpu_seq.dims(), cpu_seq_back.dims()); + for (size_t i = 0; i < cpu_seq.numel(); ++i) { + EXPECT_EQ(cpu_seq.data()[i], cpu_seq_back.data()[i]); + } + + delete place; + delete context; +}; + +TEST(Seq2BatchPadding, CPU) { + paddle::framework::LoD lod1; + lod1.push_back(std::vector{0, 10}); + TestSequencePadding(lod1, 16); + + paddle::framework::LoD lod2; + lod2.push_back(std::vector{0, 2, 
7, 10}); + TestSequencePadding(lod2, 128); +} + +#ifdef PADDLE_WITH_CUDA +TEST(SequencePadding, CUDA) { + paddle::framework::LoD lod1; + lod1.push_back(std::vector{0, 10}); + TestSequencePadding(lod1, 16); + + paddle::framework::LoD lod2; + lod2.push_back(std::vector{0, 2, 7, 10}); + TestSequencePadding(lod2, 128); +} +#endif diff --git a/paddle/operators/norm_op.cc b/paddle/operators/norm_op.cc index b198b76cd49ca7c05b047d42df149d2b1e461b8e..0eeafcaae0a01ed7090800ecb06708773ede67aa 100644 --- a/paddle/operators/norm_op.cc +++ b/paddle/operators/norm_op.cc @@ -39,7 +39,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { "M = C * H * W"); AddComment(R"DOC( "Input shape: $(N, C, H, W)$ - Sclae shape: $(C, 1)$ + Scale shape: $(C, 1)$ Output shape: $(N, C, H, W)$ Where forward diff --git a/paddle/operators/norm_op.h b/paddle/operators/norm_op.h index 7bee48919e9fdf85595bbc7ad540ca45a1dbfe5c..5759d6f1f07d50d4c65c93578ba7d7916a2d9b4e 100644 --- a/paddle/operators/norm_op.h +++ b/paddle/operators/norm_op.h @@ -66,7 +66,7 @@ class NormKernel : public framework::OpKernel { context.GetPlace()); auto tmp = framework::EigenVector::Flatten(tmp_tensor); - // get colsum and sqrt , inverse + // get colsum and sqrt , inverse auto dim = Eigen::array({{0}}); tmp.device(*place) = x_square_batch_eigen.sum(dim); tmp.device(*place) = (tmp + epsilon).sqrt().inverse(); diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index a6bc70f4c89fb24cef5aefcb69b97fbaa9dc9d9c..e1bec0421e76143bef669a4f6fa373cdf01226b2 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -39,6 +39,7 @@ void SplitTensorAndMoveTensorToScopes( const std::vector &sub_scopes, const std::vector &places, const std::vector &names) { + PADDLE_ENFORCE_EQ(sub_scopes.size(), places.size()); for (auto &argu : names) { auto *var = scope.FindVar(argu); const auto &tensor = var->Get(); @@ -54,6 +55,15 @@ void SplitTensorAndMoveTensorToScopes( } } +void WaitOnPlaces(const std::vector places) { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + + for (auto &place : places) { + auto &dev_ctx = *pool.Get(place); + dev_ctx.Wait(); + } +} + class ParallelDoOp : public framework::OperatorBase { public: ParallelDoOp(const std::string &type, @@ -71,10 +81,7 @@ class ParallelDoOp : public framework::OperatorBase { auto *block = Attr(kParallelBlock); auto *program = block->Program(); - // TODO(tonyyang-svail): get places from input - std::vector places; - places.emplace_back(platform::CPUPlace()); - places.emplace_back(platform::CPUPlace()); + auto &places = scope.FindVar(Input(kPlaces))->Get(); auto &sub_scopes = *scope.FindVar(Output(kParallelScopes)) ->GetMutable>(); @@ -82,8 +89,22 @@ class ParallelDoOp : public framework::OperatorBase { sub_scopes.push_back(&scope.NewScope()); } + // split input SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places, Inputs(kInputs)); + // copy parameter + for (auto ¶m : Inputs(kParameters)) { + PADDLE_ENFORCE(scope.FindVar(param)->IsType(), + "Only support parameter type as LoDTensor"); + auto &src = scope.FindVar(param)->Get(); + for (size_t i = 0; i < places.size(); ++i) { + auto &place = places[i]; + auto *sub_scope = sub_scopes[i]; + auto *dst = sub_scope->Var(param)->GetMutable(); + framework::Copy(src, place, dst); + } + } + WaitOnPlaces(places); std::vector> workers; workers.reserve(places.size()); @@ -93,12 +114,6 @@ class ParallelDoOp : public framework::OperatorBase { auto &place = places[place_idx]; 
auto *cur_scope = sub_scopes[place_idx]; - // copy parameter - // some version of boost lacks != for boost::variant - if (!(dev_ctx.GetPlace() == place)) { - PADDLE_THROW("Not Implemented"); - } - workers.emplace_back(framework::Async([program, cur_scope, place, block] { framework::Executor executor(place); executor.Run(*program, cur_scope, block->ID(), @@ -108,6 +123,7 @@ class ParallelDoOp : public framework::OperatorBase { for (auto &worker : workers) { worker.wait(); } + WaitOnPlaces(places); // merge output for (auto &o_name : Outputs(kOutputs)) { @@ -121,6 +137,7 @@ class ParallelDoOp : public framework::OperatorBase { scope.FindVar(o_name)->GetMutable(); lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace()); } + WaitOnPlaces(places); } }; @@ -161,15 +178,14 @@ class ParallelDoGradOp : public OperatorBase { auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) ->Get>(); - // TODO(tonyyang-svail): get places from input - std::vector places; - places.emplace_back(platform::CPUPlace()); - places.emplace_back(platform::CPUPlace()); + auto &places = scope.FindVar(Input(kPlaces))->Get(); // feed output@grad SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places, Inputs(framework::GradVarName(kOutputs))); + WaitOnPlaces(places); + // for debugging for (auto &s : Inputs(framework::GradVarName(kOutputs))) { VLOG(3) << s; VLOG(3) << scope.FindVar(s)->Get(); @@ -196,10 +212,11 @@ class ParallelDoGradOp : public OperatorBase { for (auto &worker : workers) { worker.wait(); } + WaitOnPlaces(places); // merge grad for (auto &s : Outputs(framework::GradVarName(kParameters))) { - VLOG(3) << s; + VLOG(3) << "merge grad " << s; auto &t = sub_scopes[0]->FindVar(s)->Get(); VLOG(3) << t; @@ -216,7 +233,8 @@ class ParallelDoGradOp : public OperatorBase { auto sum_op = framework::OpRegistry::CreateOp( "sum", {{"X", {s, s_buf}}}, {{"Out", {s}}}, framework::AttributeMap{}); - sum_op->Run(*sub_scopes[0], place); + sum_op->Run(*sub_scopes[0], places[0]); + WaitOnPlaces(places); } VLOG(3) << t; @@ -236,8 +254,10 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { for (auto &input_param : this->InputNames()) { VLOG(3) << input_param; grad->SetInput(input_param, this->Input(input_param)); - grad->SetOutput(framework::GradVarName(input_param), - this->InputGrad(input_param, false)); + if (input_param != kPlaces) { + grad->SetOutput(framework::GradVarName(input_param), + this->InputGrad(input_param, false)); + } } for (auto &output_param : this->OutputNames()) { diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 6f65b87d3b06c1d8d453f42194a277accdbc1164..9331c7b563491902b2824898766cacb9bfdee2d9 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -32,6 +32,20 @@ limitations under the License. 
 */
 
 namespace paddle {
 namespace operators {
 
+static void CreateTensorFromMessageType(framework::Variable *var,
+                                        sendrecv::VarType var_type) {
+  if (var_type == sendrecv::VarType::LOD_TENSOR) {
+    var->GetMutable<framework::LoDTensor>();
+  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
+    var->GetMutable<framework::SelectedRows>();
+  } else {
+    PADDLE_THROW(
+        "VariableMessage type %d is not in "
+        "[LoDTensor, SelectedRows]",
+        var_type);
+  }
+}
+
 void RunServer(Server **rpc_server,
                std::shared_ptr<detail::SendRecvServerImpl> service,
                const std::string &server_address) {
@@ -111,10 +125,10 @@ class RecvOp : public framework::OperatorBase {
       auto *merged_grad = recv_scope.FindVar(grad_var_name);
       if (merged_grad == nullptr) {
         auto *ptr = recv_scope.Var(grad_var_name);
-        framework::CreateTensor(ptr,
-                                framework::ToVarType(merged_grad->Type()));
+        CreateTensorFromMessageType(ptr, v.second.type());
         VLOG(3) << "Create Variable " << grad_var_name
-                << " on recv scope, which pointer is " << ptr;
+                << " on recv scope, which pointer is " << ptr << " type is "
+                << v.second.type();
       }
 
       if (trainer_count > 1) {
diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc
index 34e1a12591515e2363651a2722f52963f7ae43b5..549d9620ef20e2be7d030f41f1e7567cb1922f3b 100644
--- a/paddle/operators/sequence_pool_op.cc
+++ b/paddle/operators/sequence_pool_op.cc
@@ -115,12 +115,32 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("sequence_pool_grad");
+    op_desc_ptr->SetInput("X", Input("X"));
+    if (boost::get<std::string>(GetAttr("pooltype")) == "MAX") {
+      op_desc_ptr->SetInput("MaxIndex", Output("MaxIndex"));
+    }
+    op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op_desc_ptr->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
-            sequence_pool_grad, ops::SequencePoolGradOp);
+REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
+                  ops::SequencePoolGradOpMaker);
+REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_pool,
     ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index e7306bc5f13377813e0bd49846bc834d501602eb..cef1f1fc99d6a34db068257f94f935bf4435119d 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -31,6 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(x_dims.size() == 2UL,
                    "The input of softmax op must be a matrix.");
     ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
index 2c43097d71751f3b5ac3b6366de095a22bac00ee..48201b344de0d3bd2b121a12389876dad095f10d 100644
--- a/paddle/operators/sum_op.h
+++ b/paddle/operators/sum_op.h
@@ -70,6 +70,7 @@ class SumKernel : public framework::OpKernel<T> {
     } else if (out_var->IsType<framework::SelectedRows>()) {
       PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
       auto *out = context.Output<SelectedRows>("Out");
+      out->mutable_rows()->clear();
       auto *out_value = out->mutable_value();
 
       // Runtime InferShape
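The `out->mutable_rows()->clear()` line added to SumKernel above guards against
stale state: when an output SelectedRows variable is reused across runs,
previously accumulated row indices would otherwise leak into the next sum.
A minimal Python sketch of the idea (plain dicts stand in for SelectedRows;
the names are illustrative, not Paddle API):

    # Sum several SelectedRows-like inputs into a reused output.
    def sum_selected_rows(inputs, out):
        out["rows"].clear()  # the fix modeled here: drop leftover rows
        out["values"].clear()
        for rows, values in inputs:
            out["rows"].extend(rows)
            out["values"].extend(values)
        return out

    out = {"rows": [7], "values": [0.5]}  # stale state from an earlier call
    print(sum_selected_rows([([0, 2], [1.0, 2.0])], out))
    # {'rows': [0, 2], 'values': [1.0, 2.0]} -- without clear(), row 7 lingers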
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
index bb72210bb67f925af3e450961069f0737dbde35e..a8ddd729732eb41618ea2788a6eb70a16a5ac70f 100644
--- a/paddle/operators/top_k_op.cc
+++ b/paddle/operators/top_k_op.cc
@@ -41,6 +41,8 @@ class TopkOp : public framework::OperatorWithKernel {
     dims[dims.size() - 1] = k;
     ctx->SetOutputDim("Out", dims);
     ctx->SetOutputDim("Indices", dims);
+    ctx->ShareLoD("X", "Out");
+    ctx->ShareLoD("X", "Indices");
   }
 };
 
diff --git a/paddle/operators/warpctc_op.cc b/paddle/operators/warpctc_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bd0c5f9957663fb05670269aa6c39a976de7bc9c
--- /dev/null
+++ b/paddle/operators/warpctc_op.cc
@@ -0,0 +1,141 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/warpctc_op.h"
+
+namespace paddle {
+namespace operators {
+
+class WarpCTCOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Logits"),
+                   "Input(Logits) of WarpCTCOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"),
+                   "Input(Label) of WarpCTCOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"),
+                   "Output(WarpCTCGrad) of WarpCTCOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
+                   "Output(Loss) of WarpCTCOp should not be null.");
+
+    auto logits_dims = ctx->GetInputDim("Logits");
+    int sequence_width =
+        static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
+    int blank = ctx->Attrs().Get<int>("blank");
+    PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
+                   "The value of Attr(blank) should be in interval [0, %d).",
+                   sequence_width);
+    // TODO(liuyiqun): it is tricky to set the wrong dimension here.
+    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
+        ctx.device_context());
+  }
+};
+
+class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Logits",
+             "(LoDTensor, default: LoDTensor<float>), the unscaled "
+             "probabilities of variable-length sequences, which is a 2-D "
+             "Tensor with LoD information. Its shape is "
+             "[Lp, num_classes + 1], where Lp is the sum of all input "
+             "sequences' length and num_classes is the true number of classes "
+             "(not including the blank label).");
+    AddInput("Label",
+             "(LoDTensor, default: LoDTensor<int>), the ground truth "
+             "of variable-length sequence, which is a 2-D Tensor with LoD "
+             "information. It is of the shape [Lg, 1], where Lg is the sum of "
It is of the shape [Lg, 1], where Lg is the sum of " + "all labels' length."); + AddOutput("WarpCTCGrad", + "(Tensor, default: Tensor<float>), a temporary " + "output Tensor to store the gradients of warp-ctc, which is " + "computed together with the loss in one call. It is a 3-D Tensor of " + "the shape [max_sequence_length, batch_size, num_classes + 1].") + .AsIntermediate(); + AddOutput("Loss", + "(Tensor, default: Tensor<float>), the Connectionist " + "Temporal Classification (CTC) loss, which is a 2-D Tensor of " + "the shape [batch_size, 1]"); + AddAttr<int>("blank", + "(int, default: 0), the blank label of Connectionist " + "Temporal Classification (CTC) loss, which is in the " + "half-open interval [0, num_classes + 1).") + .SetDefault(0); + AddAttr<bool>("norm_by_times", + "(bool, default: false), whether to " + "normalize the gradients by the number of time-step, " + "which is also the sequence's length.") + .SetDefault(false); + AddComment(R"DOC( +An operator integrating the open-source +[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in +[Deep Speech 2: End-to-End Speech Recognition in English and Mandarin]( +https://arxiv.org/pdf/1512.02595v1.pdf), +to compute Connectionist Temporal Classification (CTC) loss. +It can be aliased as softmax with ctc, since a native softmax activation is +integrated into the warp-ctc library, to normalize values for each row of the +input tensor. + +More details of the CTC loss can be found by referring to +[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with +Recurrent Neural Networks]( +http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf). +)DOC"); + } +}; + +class WarpCTCGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"), + "Input(WarpCTCGrad) of WarpCTCGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), + "Output(Logits@GRAD) of WarpCTCGradOp should not be null."); + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Logits")); + ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input<framework::LoDTensor>("Logits")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, warpctc_grad, + ops::WarpCTCGradOp); +REGISTER_OP_CPU_KERNEL( + warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>); +REGISTER_OP_CPU_KERNEL( + warpctc_grad, + ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>); diff --git a/paddle/operators/warpctc_op.cu.cc b/paddle/operators/warpctc_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d8527ac75fd4964f95198417a213ff6155aac2a --- /dev/null +++ b/paddle/operators/warpctc_op.cu.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/warpctc_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>); +REGISTER_OP_CUDA_KERNEL( + warpctc_grad, + ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>); diff --git a/paddle/operators/warpctc_op.h b/paddle/operators/warpctc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..41899c7fe0c3089c4fc7c160c8896dec0e3cd6dd --- /dev/null +++ b/paddle/operators/warpctc_op.h @@ -0,0 +1,218 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/sequence_padding.h" +#include "paddle/platform/dynload/warpctc.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template <typename DeviceContext> +class WarpCTCFunctor { + public: + /* + * \brief Compute the connectionist temporal classification loss, + * and optionally compute the gradient with respect to the inputs. + * + * If gradient is nullptr, it only computes the ctc loss; + * otherwise it computes both the ctc loss and the gradient. + * + * \param ctx execution context of this functor + * \param input batch matrix of input probabilities, in + * max_sequence_length x num_sequences x + * sequence_width, (row-major) format + * \param gradient batch matrix of gradient, with the same shape as + * input. + * \param cpu_labels labels always in CPU memory. + * \param cpu_label_lengths length of all labels in CPU memory. + * \param cpu_input_lengths length of all sequences in CPU memory. + * \param sequence_width number of possible output symbols. + * \param num_sequences number of sequences. + * \param blank blank label used in ctc loss function. + * \param cpu_loss cost of each sequence in CPU memory. + */ + void operator()(const framework::ExecutionContext& ctx, const float* input, + float* gradient, const int* cpu_labels, + const int* cpu_label_lengths, const int* cpu_input_lengths, + const size_t sequence_width, const size_t num_sequences, + const size_t blank, float* cpu_loss) { + // Init warp-ctc options + init(ctx, blank); + + // Compute the required workspace size. + // There are no memory allocation operations within warp-ctc. 
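+ // warp-ctc exposes a two-phase API: first query the workspace size needed + // for the given label/input lengths, allocate that buffer on the caller's + // side, and then pass it to compute_ctc_loss, which fills in the loss and + // (when a gradient pointer is given) the gradient in a single call.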
+ size_t workspace_bytes = 0; + ctcStatus_t status = platform::dynload::get_workspace_size( + cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width), + static_cast<int>(num_sequences), options_, &workspace_bytes); + PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status, + "warp-ctc [version %d] Error in get_workspace_size: ", + warpctc_version_, + platform::dynload::ctcGetStatusString(status)); + PADDLE_ENFORCE_GT(workspace_bytes, 0UL, + "Bytes of workspace got by warp-ctc function, " + "get_workspace_size(), should be larger than 0."); + + Tensor workspace; + size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; + float* workspace_data = workspace.mutable_data<float>( + framework::make_ddim({static_cast<int64_t>(workspace_elements)}), + ctx.GetPlace()); + math::SetConstant<DeviceContext, float>()( + ctx.template device_context<DeviceContext>(), &workspace, + static_cast<float>(0)); + + // compute loss and gradient + status = platform::dynload::compute_ctc_loss( + input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, + static_cast<int>(sequence_width), static_cast<int>(num_sequences), + cpu_loss, workspace_data, options_); + PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status, + "warp-ctc [version %d] Error in compute_ctc_loss: ", + warpctc_version_, + platform::dynload::ctcGetStatusString(status)); + } + + protected: + void init(const framework::ExecutionContext& ctx, const size_t blank) { + warpctc_version_ = platform::dynload::get_warpctc_version(); + + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + options_.loc = CTC_GPU; + options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>( + ctx.device_context()) + .stream(); +#else + PADDLE_THROW("[warpctc init] GPU is not enabled."); +#endif + } else { + options_.loc = CTC_CPU; + options_.num_threads = 1; + } + + options_.blank_label = blank; + } + + private: + int warpctc_version_; + ctcOptions options_; +}; + +template <typename DeviceContext, typename T> +class WarpCTCKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* logits = ctx.Input<LoDTensor>("Logits"); + auto* label = ctx.Input<LoDTensor>("Label"); + auto* warpctc_grad = ctx.Output<LoDTensor>("WarpCTCGrad"); + auto* loss = ctx.Output<LoDTensor>("Loss"); + + const size_t level = 0; + + auto logits_lod = framework::ToAbsOffset(logits->lod()); + auto logits_dims = logits->dims(); + PADDLE_ENFORCE_EQ(logits_dims[0], + static_cast<int64_t>(logits_lod[level].back()), + "The first dimension of Input(Logits) should be equal to " + "the sum of all sequences' lengths."); + + auto label_lod = framework::ToAbsOffset(label->lod()); + auto label_dims = label->dims(); + PADDLE_ENFORCE_EQ( + label_dims[0], label->numel(), + "The width of each timestep in Input(Label) should be 1."); + + const size_t num_sequences = logits_lod[level].size() - 1; + PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1, + "The number of sequences of Input(Logits) should be " + "equal to that of Input(Label)."); + + const size_t sequence_width = logits->numel() / logits_dims[0]; + auto loss_dims = + framework::make_ddim({static_cast<int64_t>(num_sequences), 1}); + + // warpctc needs sequences data stored in transposed padding format + Tensor warpctc_logits; + const size_t max_sequence_length = + math::MaximumSequenceLength(logits_lod, level); + auto warpctc_logits_dims = + framework::make_ddim({static_cast<int64_t>(max_sequence_length), + static_cast<int64_t>(num_sequences), + static_cast<int64_t>(sequence_width)}); + warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace()); + math::PaddingLoDTensorFunctor<DeviceContext, T>()( + ctx.template device_context<DeviceContext>(), *logits, warpctc_logits, + false); + const T* 
warpctc_logits_data = warpctc_logits.data<T>(); + + std::vector<int> warpctc_label_lengths(num_sequences); + std::vector<int> warpctc_logits_lengths(num_sequences); + + for (size_t i = 0; i < num_sequences; ++i) { + warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i]; + warpctc_logits_lengths[i] = + logits_lod[level][i + 1] - logits_lod[level][i]; + } + + // warpctc computes loss and gradient in one call; the gradient data is also + // stored in batch format + T* warpctc_grad_data = + warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace()); + + // warpctc accesses labels in CPU memory + Tensor warpctc_label; + Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label); + const int* warpctc_label_data = warpctc_label.data<int>(); + + // warpctc stores loss in CPU memory + Tensor warpctc_loss; + T* warpctc_loss_data = + warpctc_loss.mutable_data<T>(loss_dims, platform::CPUPlace()); + + const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank")); + + WarpCTCFunctor<DeviceContext>()( + ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data, + warpctc_label_lengths.data(), warpctc_logits_lengths.data(), + sequence_width, num_sequences, blank, warpctc_loss_data); + + // Copy the loss back + Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss); + } +}; + +template <typename DeviceContext, typename T> +class WarpCTCGradKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad"); + auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits")); + + bool norm_by_times = ctx.Attr<bool>("norm_by_times"); + math::UnpaddingLoDTensorFunctor<DeviceContext, T>()( + ctx.template device_context<DeviceContext>(), *logits_grad, + *warpctc_grad, norm_by_times); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 65d827e0e0c5cfc3897c1fd0b971b766201cc1e2..7a3400919efe6f3bed40e45a245b556beab6fce4 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -211,59 +211,54 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { protected: std::unique_ptr<framework::OpDesc> Apply() const override { - auto *grad = new framework::OpDesc(); - grad->SetType("while_grad"); - grad->SetInput(kX, Input(kX)); + auto *while_grad = new framework::OpDesc(); + while_grad->SetType("while_grad"); + while_grad->SetInput(kX, Input(kX)); + while_grad->SetInput(kOutputs, Output(kOutputs)); + while_grad->SetInput(kStepScopes, Output(kStepScopes)); + + auto *grad_block = this->grad_block_[0]; + auto *fwd_block = grad_block->ParentBlock(); // Not all of IGs will be generated by inner gradient operators of while op. // Ignore IGs that are not generated by the inside block. 
- auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); - std::unordered_set<std::string> all_outs; - for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) { - for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) { - all_outs.insert(oname); + std::unordered_set<std::string> inner_op_outputs; + for (const auto *op : grad_block->AllOps()) { + for (auto &oname : op->OutputArgumentNames()) { + inner_op_outputs.insert(oname); } } + auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); for (auto &each_ig : igs) { - if (all_outs.find(each_ig) == all_outs.end()) { + if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) { VLOG(10) << "Ignore " << each_ig; each_ig = framework::kEmptyVarName; } } - - grad->SetOutput(framework::GradVarName(kX), igs); - - grad->SetInput(kOutputs, Output(kOutputs)); + while_grad->SetOutput(framework::GradVarName(kX), igs); // OG should be re-calculated by step blocks, since many outputs of while op // do not need to calculate gradients. std::unordered_set<std::string> block_ins; - auto *fwd_block = this->grad_block_[0]->ParentBlock(); - { - for (auto &p : Input(kX)) { - block_ins.insert(p); - } - for (auto &o : Output(kOutputs)) { - block_ins.insert(o); - } + block_ins.reserve(Input(kX).size() + Output(kOutputs).size()); + for (auto &p : Input(kX)) { + block_ins.insert(p); + } + for (auto &o : Output(kOutputs)) { + block_ins.insert(o); } std::unordered_set<std::string> extra_inputs; - for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) { - for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) { - if (block_ins.find(input_name) != block_ins.end()) { - continue; - } - - // If the input of Op is generated by the forward block, do not make it - // as input again. - if (fwd_block->FindVar(input_name) != nullptr) { + for (const auto *op : grad_block->AllOps()) { + for (auto &input_name : op->InputArgumentNames()) { + // If the input of Op has been recorded or is generated by the forward + // block, do not make it as input again. + if (block_ins.find(input_name) != block_ins.end() || + fwd_block->FindVar(input_name) != nullptr) { continue; } - extra_inputs.insert(input_name); } - - for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) { + for (auto &output_name : op->OutputArgumentNames()) { block_ins.insert(output_name); } } @@ -272,15 +267,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { extra_inputs_list.resize(extra_inputs.size()); std::copy(extra_inputs.begin(), extra_inputs.end(), extra_inputs_list.begin()); - grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); - grad->SetInput(kStepScopes, Output(kStepScopes)); - grad->SetAttrMap(this->Attrs()); - grad->SetBlockAttr(kStepBlock, *grad_block_[0]); + while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); + + while_grad->SetAttrMap(this->Attrs()); + while_grad->SetBlockAttr(kStepBlock, *grad_block); // record the original output gradient names, since the gradient name of // while operator could be renamed. 
- grad->SetAttr("original_output_grad", extra_inputs_list); + while_grad->SetAttr("original_output_grad", extra_inputs_list); - return std::unique_ptr<framework::OpDesc>(grad); + return std::unique_ptr<framework::OpDesc>(while_grad); } }; diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7a0040c9c229af79ea8be1049dfd6c0d1b4d19cf..9826a642768a3346110a7872a726aba15eac1fb0 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -185,6 +185,8 @@ class DeviceContextPool { const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place)); } + size_t size() const { return device_contexts_.size(); } + private: static DeviceContextPool* pool; constexpr static int LEFT_SHIFT = 8; diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index f4fda65907dc26e9edb91ee46f3b8bd2de7b3f3a..cf2081b434961c17c1b65509909699788d2b9ad9 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,3 +1,4 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader nccl) +cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) diff --git a/paddle/platform/dynload/cublas.cc b/paddle/platform/dynload/cublas.cc index 9cd2a1f565526f8dc45932ba6168f4e25c6ad238..6aca716657c5f629ce40e88c04dc25fbbc9b4f36 100644 --- a/paddle/platform/dynload/cublas.cc +++ b/paddle/platform/dynload/cublas.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include <paddle/platform/dynload/cublas.h> +#include "paddle/platform/dynload/cublas.h" namespace paddle { namespace platform { diff --git a/paddle/platform/dynload/warpctc.cc b/paddle/platform/dynload/warpctc.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b7d01a6e8f6cccc0082f65f9511085d2a3111b7 --- /dev/null +++ b/paddle/platform/dynload/warpctc.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/dynload/warpctc.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag warpctc_dso_flag; +void* warpctc_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +WARPCTC_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/warpctc.h b/paddle/platform/dynload/warpctc.h new file mode 100644 index 0000000000000000000000000000000000000000..acafcaff2ccc33c315f216e31b866bf4c8ae1fec --- /dev/null +++ b/paddle/platform/dynload/warpctc.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <dlfcn.h> +#include <mutex> +#include "ctc.h" +#include "paddle/platform/dynload/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag warpctc_dso_flag; +extern void* warpctc_dso_handle; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load warpctc routine + * via operator overloading. + */ +#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using warpctcFunc = decltype(__name(args...)) (*)(Args...); \ + std::call_once(warpctc_dso_flag, \ + paddle::platform::dynload::GetWarpCTCDsoHandle, \ + &warpctc_dso_handle); \ + void* p_##_name = dlsym(warpctc_dso_handle, #__name); \ + return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#define DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ + DYNAMIC_LOAD_WARPCTC_WRAP(__name) + +#define WARPCTC_ROUTINE_EACH(__macro) \ + __macro(get_warpctc_version); \ + __macro(ctcGetStatusString); \ + __macro(compute_ctc_loss); \ + __macro(get_workspace_size) + +WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP); + +#undef DYNAMIC_LOAD_WARPCTC_WRAP + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 108ff335bf6b920c648d4bfebbd6a40ffb6fd939..a7fb50ee4149a3c36077f83383f45f3106e7e0f1 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -34,11 +34,11 @@ int main(int argc, char** argv) { google::ParseCommandLineFlags(&new_argc, &new_argv_address, false); testing::InitGoogleTest(&argc, argv); paddle::memory::Used(paddle::platform::CPUPlace()); - std::vector devs = {"CPU"}; + #ifdef PADDLE_WITH_CUDA paddle::memory::Used(paddle::platform::CUDAPlace(0)); - devs.push_back("GPU:0"); #endif - paddle::framework::InitDevices(devs); + + paddle::framework::InitDevices(); return RUN_ALL_TESTS(); } diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 5e01b8719806f4bb0c0d985373a5f4b076e05bd5..ccd5998e3592a1f5dc795ee24875c1aed230587e 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -19,12 +19,13 @@ from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, CUDAPlace from distribute_transpiler import DistributeTranspiler import clip +from memory_optimization_transpiler import memory_optimize Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + [ 'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor', 'ParamAttr', - 'DataFeeder', 'clip', 'DistributeTranspiler' + 'DataFeeder', 'clip', 'DistributeTranspiler', 'memory_optimize' ] @@ -61,11 +62,7 @@ def __bootstrap__(): core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) core.init_glog(sys.argv[0]) - - if core.is_compile_gpu(): - core.init_devices(["CPU", "GPU:0"]) - else: - core.init_devices(["CPU"]) + core.init_devices() 
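+# init_devices() now enumerates the available places itself (CPU, plus CUDA devices when PaddlePaddle is compiled with GPU support), so callers no longer pass an explicit device list.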
__bootstrap__() diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 8d0eb53b8e1ff6b0a05218c6d0c4e017a21b3fbb..cea2d1e09068da20f4d2fdbfbd9a3e3a511ba267 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -1,8 +1,9 @@ from paddle.v2.fluid import framework as framework from . import core import collections +import copy -__all__ = ['append_backward'] +__all__ = ['append_backward', 'calc_gradient'] def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): @@ -65,6 +66,18 @@ def _all_in_set_(cands, s): return True +def _some_in_set_(cands, s): + """ + Test if some elements of 'cands' are in set 's' + """ + if len(cands) == 0: + return False + for c in cands: + if c in s: + return True + return False + + def _strip_grad_suffix_(name): """ Strip the grad suffix from the given variable name @@ -169,8 +182,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): return op_descs -def _append_backward_ops_(target, - block, +def _append_backward_ops_(block, + ops, target_block, no_grad_dict, grad_to_var, @@ -179,8 +192,8 @@ def _append_backward_ops_(target, Create all grad ops, and insert them into given block Args: - target(Variable): the target variable of forward pass block(Block): the block where forward ops are + ops(list[Op]): the forward operators whose backward ops need to be added target_block(Block): the block which is going to hold new generated grad ops no_grad_dict(dict): key(int) block index @@ -202,14 +215,14 @@ def _append_backward_ops_(target, # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program - for op in reversed(block.ops): + for op in reversed(ops): grad_sub_block_list = [] # If the op has its own sub-block, deal with the sub-block first if op.has_attr("sub_block"): sub_block = program.block(op.block_attr("sub_block")) grad_sub_block = program.create_block(parent_idx=sub_block.idx) - _append_backward_ops_(target, sub_block, grad_sub_block, - no_grad_dict, grad_to_var, callback) + _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, + no_grad_dict, grad_to_var) grad_sub_block_list.append(grad_sub_block.desc) # Getting op's corresponding grad_op @@ -224,14 +237,6 @@ def _append_backward_ops_(target, grad_op_descs = _remove_no_grad_branch_(grad_op_descs, no_grad_dict[block.idx]) - if target_block.idx == 0: - grad_op_descs.insert( - 0, - _create_op_desc_("fill_constant", {}, { - "Out": [_append_grad_suffix_(target.name)] - }, {"shape": [1], - "value": 1.0, - "dtype": target.dtype})) # append op_desc in grad_op_descs to target_block for op_desc in grad_op_descs: new_op_desc = target_block.desc.append_op() @@ -252,7 +257,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): In most cases, this dict is generated by _append_backward_ops_() grad_info_map(dict)(output argument): key(str): forward variable name - val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index + val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable """ for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) @@ -279,41 +284,63 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): _infer_var_data_type_(arg, block) +def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): + var_map = copy.copy(target_grad_map) + for op_idx in range(start_op_idx, 
block.desc.op_size()): + op_desc = block.desc.op(op_idx) + for name in op_desc.input_arg_names(): + if name in var_map: + op_desc.rename_input(name, var_map[name]) + + for name in op_desc.output_arg_names(): + if block.desc.find_var(name.encode("ascii")): + new_name = "%s_%s" % (name, core.unique_integer(name)) + op_desc.rename_output(name, new_name) + var_map[name] = new_name + + for g, ng in var_map.iteritems(): + if g in grad_to_var: + grad_to_var[ng] = grad_to_var[g] + grad_to_var.pop(g) + + +def _get_stop_gradients_(program): + no_grad_dict = dict() + assert isinstance(program, framework.Program) + for block in program.blocks: + assert isinstance(block, framework.Block) + block_no_grad_set = set() + for var in block.vars.itervalues(): + assert isinstance(var, framework.Variable) + if var.stop_gradient: + block_no_grad_set.add(_append_grad_suffix_(var.name)) + no_grad_dict[block.idx] = block_no_grad_set + return no_grad_dict + + def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None): """ Append backward part to main_program Args: loss(Variable): The variable generated by cost function. - parameter_list(list): Parameters that need to be updated by optimizer. - If None, it means all parameters need to be updated. + parameter_list(list[string]): Parameters that need to be updated by + optimizer. If None, it means all parameters need to be updated. no_grad_set(set): Variables that have no gradients in Block 0. - If None, the set will be generated inside the function and - contains all variables with `step_gradient=True` from all blocks. + All variables with `stop_gradient=True` from all blocks will be + automatically added. Return: - (list[Variable]): list of (parameters, gradients) pair. + (list[(Variable,Variable)]): list of (parameter, gradient) pair. 
""" assert isinstance(loss, framework.Variable) program = loss.block.program - no_grad_dict = dict() if no_grad_set is None: - assert isinstance(program, framework.Program) - for block in program.blocks: - assert isinstance(block, framework.Block) - block_no_grad_set = set() - for var in block.vars.itervalues(): - assert isinstance(var, framework.Variable) - if var.stop_gradient: - block_no_grad_set.add(_append_grad_suffix_(var.name)) - no_grad_dict[block.idx] = block_no_grad_set - elif isinstance(no_grad_set, set): - no_grad_dict = { - 0: set([_append_grad_suffix_(name) for name in no_grad_set]) - } - else: - raise ValueError("'no_grad_set' should be a set or None.") + no_grad_set = set() + no_grad_set = copy.copy(no_grad_set) + no_grad_dict = _get_stop_gradients_(program) + no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set)) grad_info_map = dict() root_block = program.block(0) @@ -322,8 +349,25 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None): current_block_idx = program.current_block_idx grad_to_var = dict() - _append_backward_ops_(loss, root_block, root_block, no_grad_dict, + op_desc = _create_op_desc_("fill_constant", {}, { + "Out": [_append_grad_suffix_(loss.name)] + }, {"shape": [1], + "value": 1.0, + "dtype": loss.dtype}) + root_block.desc.append_op().copy_from(op_desc) + + block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) + op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set) + no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set)) + + _append_backward_ops_(root_block, op_path, root_block, no_grad_dict, grad_to_var, callback) + + # Because calc_gradient may be called multiple times, + # we need rename the internal gradient variables so that they have + # different names. 
+ _rename_grad_(root_block, fwd_op_num, grad_to_var, {}) + _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map) program.current_block_idx = current_block_idx @@ -334,6 +378,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None): else: params = program.global_block().all_parameters() parameters = [param.name for param in params] + params_and_grads = [] for param in parameters: if param not in grad_info_map: @@ -351,3 +396,147 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None): else: params_and_grads.append((param_var, None)) return params_and_grads + + +def _as_list(x): + if x is None: + return [] + return list(x) if isinstance(x, collections.Sequence) else [x] + + +def _find_op_path_(block, outputs, inputs, no_grad_set): + """ + no_grad_set will also be changed + """ + input_names = set([inp.name for inp in inputs]) + output_names = set([out.name for out in outputs]) + + relevant_op_flags = [True] * len(block.ops) + + # All the inputs of the block are used if inputs is empty, + if inputs: + for i, op in enumerate(block.ops): + if _some_in_set_(op.desc.input_arg_names(), input_names): + for name in op.desc.output_arg_names(): + if name not in no_grad_set: + input_names.add(name) + else: + relevant_op_flags[i] = False + + for i, op in reversed(list(enumerate(block.ops))): + if _some_in_set_(op.desc.output_arg_names(), output_names): + for name in op.desc.input_arg_names(): + if name not in no_grad_set: + output_names.add(name) + else: + relevant_op_flags[i] = False + + op_path = [ + block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i] + ] + + if inputs: + for op in op_path: + for name in op.desc.input_arg_names(): + if name not in input_names: + no_grad_set.add(name) + + return op_path + + +def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): + """ + Backpropagate the gradients of targets to inputs. + + Args: + targets(Variable|list[Variable]): The target variables + inputs(Variable|list[Variable]): The input variables + no_grad_set(set[string]): The names of variables that have no gradients + in Block 0. All variables with `stop_gradient=True` from all blocks + will be automatically added. 
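+ target_gradients(Variable|list[Variable]|None): The starting gradients of + targets, matching them in number and shape; when None, a gradient + filled with ones is created for each target (see the + fill_constant_batch_size_like call below). 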
+ + Return: + (list[Variable]): list of gradients for inputs + If an input does not affect targets, the corresponding gradient variable + will be None + """ + targets = _as_list(targets) + inputs = _as_list(inputs) + target_gradients = _as_list(target_gradients) + + block = targets[0].block + prog = block.program + block_idx = block.idx + + if not target_gradients: + target_gradients = [None] * len(targets) + + if len(targets) != len(target_gradients): + raise ValueError( + "Should have the same number of target_gradients as targets") + + if no_grad_set is None: + no_grad_set = set() + no_grad_set = copy.copy(no_grad_set) + no_grad_dict = _get_stop_gradients_(prog) + no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set)) + + fwd_op_num = block.desc.op_size() + + target_grad_map = {} + for i, grad in enumerate(target_gradients): + target = targets[i] + if grad is None: + grad_name = _append_grad_suffix_(target.name) + op_desc = _create_op_desc_("fill_constant_batch_size_like", + {"Input": [target.name]}, + {"Out": [grad_name]}, { + "shape": target.shape, + "value": 1.0, + "dtype": target.dtype, + 'input_dim_idx': 0, + 'output_dim_idx': 0 + }) + block.desc.append_op().copy_from(op_desc) + else: + if target.block.idx != block_idx or target.block.program != prog: + raise ValueError("all targets must be in the same block") + if target.shape != grad.shape: + raise ValueError( + "The shapes of target and grad are different: %s %s" % ( + target.name, grad.name)) + target_grad_map[_append_grad_suffix_(target.name)] = grad.name + + for input in inputs: + if input.block.program != prog: + raise ValueError("input must be in the same program as targets") + + block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) + op_path = _find_op_path_(block, targets, inputs, block_no_grad_set) + no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set)) + grad_to_var = dict() + grad_info_map = dict() + _append_backward_ops_(block, op_path, block, no_grad_dict, grad_to_var) + + # Because calc_gradient may be called multiple times, + # we need to rename the internal gradient variables so that they have + # different names. 
+ _rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map) + + _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) + prog.sync_with_cpp() + + grad_vars = [] + for input_var in inputs: + if input_var.name not in grad_info_map: + grad_vars.append(None) + else: + grad_info = grad_info_map[input_var.name] + grad_block = grad_info[1] + grad_var = grad_block.var(grad_info[0]) + grad_vars.append(grad_var) + + if len(grad_vars) == 1: + return grad_vars[0] + else: + return grad_vars diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 85c1e6eb7ba37f71d79f1dd6c34539d0b1dcbf11..2fb388acfc0a9f19b26c92de95de6a0dc0d9c018 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -773,6 +773,9 @@ class Program(object): proto = framework_pb2.ProgramDesc.FromString(str(protostr)) return _debug_string_(proto, throw_on_error) + def get_desc(self): + return self.desc + def clone(self): p = Program() p.desc = core.ProgramDesc(self.desc) diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 23fe13f9bbf3e81802ac86415472e6aa603711b1..544623c4bce0cb75ea727906c4879e986c8d1ce8 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -1,7 +1,17 @@ from ..registry import register_layer __activations__ = [ - 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round' + 'abs', + 'ceil', + 'exp', + 'floor', + 'log', + 'relu', + 'round', + 'sigmoid', + 'sqrt', + 'square', + 'tanh', ] __all__ = [ diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index 9ce25a9e0831a49ef3bbc5026181856e6c4cdfcc..5f12ecfc14f7521948acdf27f1d6249e8052abc5 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -1,7 +1,8 @@ from ..layer_helper import LayerHelper +from ..param_attr import ParamAttr __all__ = [ - 'create_tensor', 'cast', 'concat', 'sums', 'assign', + 'create_tensor', 'create_parameter', 'cast', 'concat', 'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros' ] @@ -11,6 +12,33 @@ def create_tensor(dtype, name=None): return helper.create_variable(name=helper.name, dtype=dtype) +def create_parameter(shape, + dtype, + attr=None, + is_bias=False, + default_initializer=None): + """ + Create a parameter + Args: + shape(list[int]): shape of the parameter + dtype(string): element type of the parameter + attr(ParamAttr): attributes of the parameter + is_bias(bool): This can affect which default initializer is chosen + when default_initializer is None. If is_bias, + initializer.Constant(0.0) will be used. Otherwise, + Xavier() will be used. + default_initializer(Initializer): initializer for the parameter + + Returns: + Parameter: the created parameter + """ + helper = LayerHelper("create_parameter") + if attr is None: + attr = ParamAttr() + return helper.create_parameter(attr, shape, dtype, is_bias, + default_initializer) + + def cast(x, dtype): """ This function takes in the input with input_dtype @@ -180,7 +208,8 @@ def fill_constant_batch_size_like(input, Examples: .. 
code-block:: python - data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64') + data = fluid.layers.fill_constant_batch_size_like( + input=like, shape=[1], value=0, dtype='int64') """ helper = LayerHelper("fill_constant_batch_size_like", **locals()) out = helper.create_tmp_variable(dtype=dtype) diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..571fce7fac616356ae0368b407e90537caa42977 --- /dev/null +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -0,0 +1,115 @@ +from collections import defaultdict +import framework +from framework import Program, default_main_program, Parameter, Variable +import backward +from backward import _rename_arg_ + + +class ControlFlowGraph(object): + def __init__(self, program): + self._program = program + self._successors = defaultdict(set) + self._predecessors = defaultdict(set) + self._uses = defaultdict(set) + self._defs = defaultdict(set) + self._live_in = defaultdict(set) + self._live_out = defaultdict(set) + + def _add_connections(self, connections): + for node1, node2 in connections: + self._add(node1, node2) + + def _add(self, node1, node2): + self._successors[node1].add(node2) + self._predecessors[node2].add(node1) + + def _build_graph(self): + program_desc = self._program.get_desc() + block_size = program_desc.num_blocks() + + # TODO(qijun) handle Program with if/while operators + self.global_block = program_desc.block(0) + self.op_size = self.global_block.op_size() + + op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)] + self._add_connections(op_node_connections) + + self.ops = [self.global_block.op(i) for i in range(self.op_size)] + + for i in range(self.op_size): + self._uses[i].update(self.ops[i].input_arg_names()) + self._defs[i].update(self.ops[i].output_arg_names()) + + def _reach_fixed_point(self, live_in, live_out): + if len(live_in) != len(self._live_in): + return False + if len(live_out) != len(self._live_out): + return False + for i in range(self.op_size): + if live_in[i] != self._live_in[i]: + return False + for i in range(self.op_size): + if live_out[i] != self._live_out[i]: + return False + return True + + def _dataflow_analyze(self): + self._build_graph() + live_in = defaultdict(set) + live_out = defaultdict(set) + while True: + for i in range(self.op_size): + live_in[i] = set(self._live_in[i]) + live_out[i] = set(self._live_out[i]) + self._live_in[i] = self._uses[i] | ( + self._live_out[i] - self._defs[i]) + for s in self._successors[i]: + self._live_out[i] |= self._live_in[s] + + if self._reach_fixed_point(live_in, live_out): + break + + def _get_diff(self, a, b): + u = a & b + return a - u, b - u + + def memory_optimize(self): + self._build_graph() + self._dataflow_analyze() + self.pool = [] + for i in range(self.op_size): + if self.pool: + out_pair = [(x, self.global_block.var(str(x)).shape()) + for x in self._defs[i]] + for x, x_shape in out_pair: + for index, cache_pair in enumerate(self.pool): + cache_var = cache_pair[0] + cache_shape = cache_pair[1] + if x_shape == cache_shape: + print( + "Hit Cache !!!! 
cache pool index is %d, var name is %s, cached var name is %s, var shape is %s " + % (index, x, cache_var, str(cache_shape))) + self.pool.pop(index) + _rename_arg_(self.ops, x, cache_var, begin_idx=i) + self._dataflow_analyze() + break + + in_diff, out_diff = self._get_diff(self._live_in[i], + self._live_out[i]) + can_optimize = filter( + lambda x: not self.global_block.var(str(x)).persistable(), + in_diff) + if can_optimize: + for var_name in can_optimize: + self.pool.append(( + var_name, self.global_block.var(str(var_name)).shape())) + + def get_program(self): + return self._program + + +def memory_optimize(input_program): + graph = ControlFlowGraph(input_program) + graph.memory_optimize() + result_program = graph.get_program() + return result_program diff --git a/python/paddle/v2/fluid/tests/book_distribute/test_dist_word2vec.py b/python/paddle/v2/fluid/tests/book_distribute/test_dist_word2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..b41853784d607c566fc596ab93f2282520778a4b --- /dev/null +++ b/python/paddle/v2/fluid/tests/book_distribute/test_dist_word2vec.py @@ -0,0 +1,96 @@ +from __future__ import print_function +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +import os + +PASS_NUM = 100 +EMBED_SIZE = 32 +HIDDEN_SIZE = 256 +N = 5 +BATCH_SIZE = 32 +IS_SPARSE = True +TRAINERS = 2 + +word_dict = paddle.dataset.imikolov.build_dict() +dict_size = len(word_dict) + +first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64') +second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64') +third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64') +forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64') +next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64') + +embed_first = fluid.layers.embedding( + input=first_word, + size=[dict_size, EMBED_SIZE], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr='shared_w') +embed_second = fluid.layers.embedding( + input=second_word, + size=[dict_size, EMBED_SIZE], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr='shared_w') +embed_third = fluid.layers.embedding( + input=third_word, + size=[dict_size, EMBED_SIZE], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr='shared_w') +embed_forth = fluid.layers.embedding( + input=forth_word, + size=[dict_size, EMBED_SIZE], + dtype='float32', + is_sparse=IS_SPARSE, + param_attr='shared_w') + +concat_embed = fluid.layers.concat( + input=[embed_first, embed_second, embed_third, embed_forth], axis=1) +hidden1 = fluid.layers.fc(input=concat_embed, size=HIDDEN_SIZE, act='sigmoid') +predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax') +cost = fluid.layers.cross_entropy(input=predict_word, label=next_word) +avg_cost = fluid.layers.mean(x=cost) +sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) +optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) +train_reader = paddle.batch( + paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) + +t = fluid.DistributeTranspiler() +# all parameter server endpoints list for splitting parameters +pserver_endpoints = os.getenv("PSERVERS") +# server endpoint for current node +current_endpoint = os.getenv("SERVER_ENDPOINT") +# run as trainer or parameter server +training_role = os.getenv("TRAINING_ROLE", + "TRAINER") # get the training role: trainer/pserver +t.transpile( + optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS) if 
training_role == "PSERVER": + if not current_endpoint: + print("need env SERVER_ENDPOINT") + exit(1) + pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) + exe.run(fluid.default_startup_program()) + exe.run(pserver_prog) +elif training_role == "TRAINER": + feeder = fluid.DataFeeder( + feed_list=[first_word, second_word, third_word, forth_word, next_word], + place=place) + exe.run(fluid.default_startup_program()) + for pass_id in range(PASS_NUM): + for data in train_reader(): + avg_cost_np = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost]) + print("avg_cost_np", avg_cost_np) + if avg_cost_np[0] < 5.0: + exit( + 0) # if avg cost less than 10.0, we think our code is good. +else: + print("environment var TRAINER_ROLE should be TRAINER os PSERVER") +exit(1) diff --git a/python/paddle/v2/fluid/tests/test_batch_norm_op.py b/python/paddle/v2/fluid/tests/test_batch_norm_op.py index abbd48d2b843cedb77caffc13413d2f9695defa6..ac9418549f45f818257d74045cabb9c581816968 100644 --- a/python/paddle/v2/fluid/tests/test_batch_norm_op.py +++ b/python/paddle/v2/fluid/tests/test_batch_norm_op.py @@ -341,9 +341,6 @@ class TestBatchNormOp(OpTest): if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): places.append(core.CUDAPlace(0)) - core.init_devices(["CPU", "GPU:0"]) - else: - core.init_devices(["CPU"]) for place in places: for data_format in ["NCHW", "NHWC"]: test_with_place(place, data_format, [2, 3, 4, 5]) diff --git a/python/paddle/v2/fluid/tests/test_beam_search_op.py b/python/paddle/v2/fluid/tests/test_beam_search_op.py index 595f132fa85f0a65f15d9ac31ad320e567c96358..319a7e49e35b0515e69703b2d03080cd9ffcae9d 100644 --- a/python/paddle/v2/fluid/tests/test_beam_search_op.py +++ b/python/paddle/v2/fluid/tests/test_beam_search_op.py @@ -37,13 +37,13 @@ class BeamSearchOpTester(unittest.TestCase): print 'lod', selected_ids.lod() def _create_pre_ids(self): - np_data = np.array([[1, 2, 3, 4]], dtype='int32') + np_data = np.array([[1, 2, 3, 4]], dtype='int64') tensor = create_tensor(self.scope, "pre_ids", np_data) def _create_ids(self): self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]] np_data = np.array( - [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32') + [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64') tensor = create_tensor(self.scope, "ids", np_data) tensor.set_lod(self.lod) diff --git a/python/paddle/v2/fluid/tests/test_calc_gradient.py b/python/paddle/v2/fluid/tests/test_calc_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..c34c8ff6d143ff2c8ae0def935d2b44982c0764e --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_calc_gradient.py @@ -0,0 +1,25 @@ +import unittest + +import paddle.v2.fluid as fluid +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid.optimizer as optimizer +from paddle.v2.fluid.backward import calc_gradient + + +class TestCalcGradient(unittest.TestCase): + def test_calc_gradient(self): + x = layers.create_parameter(dtype="float32", shape=[5, 10]) + y = layers.create_parameter(dtype="float32", shape=[10, 8]) + mul_out = layers.mul(x=x, y=y) + mean_out = layers.mean(x=mul_out) + a = calc_gradient(mean_out, mul_out) + b = calc_gradient(mean_out, x) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + exe.run(fluid.default_main_program(), feed={}, fetch_list=[a, b]) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/python/paddle/v2/fluid/tests/test_memory_optimization_transpiler.py b/python/paddle/v2/fluid/tests/test_memory_optimization_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..5cce75ddb8df50a35156fc2b8b411823711989c0 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_memory_optimization_transpiler.py @@ -0,0 +1,33 @@ +from __future__ import print_function +import unittest + +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.optimizer as optimizer +from paddle.v2.fluid.framework import Program, program_guard +from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize + + +class TestControlFlowGraph(unittest.TestCase): + def setUp(self): + program = Program() + with program_guard(program, startup_program=Program()): + x = layers.data(name='x', shape=[13], dtype='float32') + y_predict = layers.fc(input=x, size=1, act=None) + y = layers.data(name='y', shape=[1], dtype='float32') + cost = layers.square_error_cost(input=y_predict, label=y) + avg_cost = layers.mean(x=cost) + opt = optimizer.SGD(learning_rate=0.001) + opt = opt.minimize(avg_cost) + + self.program = program + + def test_control_flow_graph(self): + print("before optimization") + print(str(self.program)) + result_program = memory_optimize(self.program) + print("after optimization") + print(str(result_program)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py index 2788f4e519b31b45250fbb923b2309e8bb1f6fa1..59ed041e7fa1dd68c0f8d610f2575886442d1b4d 100644 --- a/python/paddle/v2/fluid/tests/test_parallel_op.py +++ b/python/paddle/v2/fluid/tests/test_parallel_op.py @@ -18,7 +18,7 @@ class ParallelOpTest(unittest.TestCase): append_batch_size=False, stop_gradient=False) - places = fluid.default_main_program().global_block().create_var() + places = layers.get_places(device_count=4) pd = layers.ParallelDo(places=places) with pd.do(): diff --git a/python/paddle/v2/fluid/tests/test_sequence_softmax_op.py b/python/paddle/v2/fluid/tests/test_sequence_softmax_op.py index b54a56aa6d3f76baa4d1fc6ba8f963332deba002..8bffdd585699bfae2262bcfcd0387d22fa1e62db 100644 --- a/python/paddle/v2/fluid/tests/test_sequence_softmax_op.py +++ b/python/paddle/v2/fluid/tests/test_sequence_softmax_op.py @@ -1,13 +1,7 @@ import unittest import numpy as np from op_test import OpTest - - -def stable_softmax(x): - """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - np.max(x).clip(-64.) 
- exps = np.exp(shiftx) - return exps / np.sum(exps) +from test_softmax_op import stable_softmax class TestSequenceSoftmaxOp(OpTest): diff --git a/python/paddle/v2/fluid/tests/test_warpctc_op.py b/python/paddle/v2/fluid/tests/test_warpctc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..59390d5303b9642ede0d421e908a1b129c68a072 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_warpctc_op.py @@ -0,0 +1,200 @@ +import sys +import unittest +import numpy as np +from op_test import OpTest +from test_softmax_op import stable_softmax + + +class CTCForward(object): + def __init__(self, softmax, softmax_lod, labels, labels_lod, blank, + norm_by_times): + self.softmax = softmax + self.softmax_lod = softmax_lod + assert labels.shape[1] == 1 + self.labels = labels + self.labels_lod = labels_lod + self.blank = blank + self.norm_by_times = norm_by_times + + self.level = 0 + self.num_classes = softmax.shape[1] + self.batch_size = len(softmax_lod[self.level]) - 1 + assert self.batch_size == len(labels_lod[self.level]) - 1 + + self.loss = np.zeros([self.batch_size, 1], dtype="float32") + self.gradient = np.zeros(self.softmax.shape, dtype="float32") + + # float64 + self.EXP_MAX = sys.float_info.max + self.EXP_MIN = sys.float_info.min + self.LOG_ZERO = np.log(self.EXP_MIN) + self.LOG_INFINITY = np.log(self.EXP_MAX) + + def safe_exp(self, x): + if x <= self.LOG_ZERO: + return 0.0 + if x >= self.LOG_INFINITY: + return self.EXP_MAX + return np.exp(x) + + def safe_log(self, x): + if x <= self.EXP_MIN: + return self.LOG_ZERO + return np.log(x) + + # x = lna and y = lnb are in log scale, ln(a / b) = lna - lnb + def log_div(self, x, y): + res = x - y + if res <= self.LOG_ZERO: + return self.LOG_ZERO + if res >= self.LOG_INFINITY: + return self.LOG_INFINITY + return res + + # x = lna and y = lnb are in log scale, ln(a * b) = lna + lnb + def log_mul(self, x, y): + res = x + y + if res <= self.LOG_ZERO: + return self.LOG_ZERO + if res >= self.LOG_INFINITY: + return self.LOG_INFINITY + return res + + # x = lna and y = lnb are in log scale, + # ln(a + b) = lna + ln(1 + exp(lnb - lna)), where b > a + def log_add(self, x, y): + if x < y: + t = y + y = x + x = t + return x + self.safe_log(1 + self.safe_exp(y - x)) + + def segment_range(self, time, total_times, total_segments): + start = max(0, total_segments - (2 * (total_times - time))) + end = min(total_segments, 2 * (time + 1)) + return start, end + + def forward_a_sequence(self, softmax_a_sequence, labels_a_sequence): + total_times = softmax_a_sequence.shape[0] + total_segments = labels_a_sequence.shape[0] * 2 + 1 + + required_times = labels_a_sequence.shape[0] + old_label = -1 + for i in range(labels_a_sequence.shape[0]): + # two contiguous labels with the same value + if labels_a_sequence[i, 0] == old_label: + required_times = required_times + 1 + old_label = labels_a_sequence[i, 0] + + if total_times < required_times: + return 0 + + # calculate the forward and backward variables, + # reference Chapter 7.3 of "Alex Graves, Supervised Sequence + # Labelling with Recurrent Neural Networks" + log_acts = np.zeros([total_times, self.num_classes], dtype="float32") + for i in range(total_times): + for j in range(self.num_classes): + log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j]) + + # calculate the forward variables + forward_vars = np.zeros([total_times, total_segments], dtype="float32") + for i in range(total_times): + for j in range(total_segments): + forward_vars[i, j] = self.LOG_ZERO + + for i in range(total_times): + # dp 
initialization at t0 + if i == 0: + forward_vars[i, 0] = log_acts[0, self.blank] + if total_segments > 1: + forward_vars[i, 1] = log_acts[0, labels_a_sequence[i, 0]] + continue + + # dp from t1 + start, end = self.segment_range(i, total_times, total_segments) + for k in range(end - start): + j = k + start + if j & 1 == 1: + label_idx = j / 2 + label_val = labels_a_sequence[label_idx, 0] + fv = self.log_add(forward_vars[i - 1, j], + forward_vars[i - 1, j - 1]) + if j > 1 and label_val != labels_a_sequence[label_idx - 1, + 0]: + fv = self.log_add(fv, forward_vars[i - 1, j - 2]) + fv = self.log_mul(fv, log_acts[i, label_val]) + else: + fv = forward_vars[i - 1, j] + if j > 0: + fv = self.log_add(fv, forward_vars[i - 1, j - 1]) + fv = self.log_mul(fv, log_acts[i, self.blank]) + forward_vars[i, j] = fv + + # sum the last two values as log_prob + log_prob = forward_vars[total_times - 1, total_segments - 1] + if total_segments > 1: + log_prob = self.log_add( + log_prob, forward_vars[total_times - 1, total_segments - 2]) + + return -log_prob + + def forward(self): + for i in range(self.batch_size): + softmax_start_i = self.softmax_lod[self.level][i] + softmax_end_i = self.softmax_lod[self.level][i + 1] + labels_start_i = self.labels_lod[self.level][i] + labels_end_i = self.labels_lod[self.level][i + 1] + + softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :] + labels_a_sequence = self.labels[labels_start_i:labels_end_i, :] + self.loss[i] = self.forward_a_sequence(softmax_a_sequence, + labels_a_sequence) + return self.loss + + +class TestWarpCTCOp(OpTest): + def setUp(self): + self.op_type = "warpctc" + + batch_size = 4 + num_classes = 8 + logits_lod = [[0, 4, 5, 8, 11]] + logits = np.random.uniform(0.1, 1.0, + [11, num_classes]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels_lod = [[0, 3, 4, 8, 12]] + # labels should not be blank + labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32") + + blank = num_classes - 1 + norm_by_times = False + + ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank, + norm_by_times) + loss = ctc.forward() + + max_sequence_length = 0 + for i in range(batch_size): + max_sequence_length = max(max_sequence_length, + logits_lod[0][i + 1] - logits_lod[0][i]) + gradient = np.zeros( + [max_sequence_length, batch_size, num_classes], dtype="float32") + + self.inputs = { + "Logits": (logits, logits_lod), + "Label": (labels, labels_lod) + } + self.outputs = {"Loss": loss} + self.attrs = {"blank": blank, "norm_by_times": norm_by_times} + + def test_check_output(self): + self.check_output() + + +# def test_check_grad(self): +# self.outputs["WarpCTCGrad"] = None +# self.check_grad(["Logits"], "Loss", max_relative_error=0.01) + +if __name__ == "__main__": + unittest.main()
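To make the CTCForward reference above concrete, here is a minimal sketch (assuming it runs in the context of this test file, where CTCForward and stable_softmax are already available) that evaluates the CTC loss of a single 3-step sequence with one label; all shapes and values are illustrative only:

    import numpy as np

    # One sequence of 3 timesteps over 3 classes, with class 2 as the blank.
    logits = np.random.uniform(0.1, 1.0, [3, 3]).astype("float32")
    softmax = np.apply_along_axis(stable_softmax, 1, logits)
    labels = np.array([[0]], dtype="int32")  # a single label "0"

    # softmax_lod [[0, 3]] and labels_lod [[0, 1]] each describe one sequence.
    ctc = CTCForward(softmax, [[0, 3]], labels, [[0, 1]], 2, False)
    loss = ctc.forward()  # shape [1, 1]: -log p(label | logits)
    print(loss)

The same pattern, batched over the LoD offsets, is exactly what TestWarpCTCOp.setUp uses to produce the expected Loss output.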