Commit e5e206e2 authored by: Y Yang Yu

Merge branch 'develop' of github.com:baidu/Paddle into feature/refine_get_places_op

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import argparse
import matplotlib.pyplot as plt
def parse_args():
    parser = argparse.ArgumentParser('Parse Log')
    parser.add_argument(
        '--file_path', '-f', type=str, help='the path of the log file')
    parser.add_argument(
        '--sample_rate',
        '-s',
        type=float,
        default=1.0,
        help='the rate to take samples from log')
    parser.add_argument(
        '--log_period', '-p', type=int, default=1, help='the period of log')
    args = parser.parse_args()
    return args


def parse_file(file_name):
    loss = []
    error = []
    with open(file_name) as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line.startswith('pass'):
                continue
            line_split = line.split(' ')
            if len(line_split) != 5:
                continue
            loss_str = line_split[2][:-1]
            cur_loss = float(loss_str.split('=')[-1])
            loss.append(cur_loss)
            err_str = line_split[3][:-1]
            cur_err = float(err_str.split('=')[-1])
            error.append(cur_err)
    accuracy = [1.0 - err for err in error]
    return loss, accuracy


def sample(metric, sample_rate):
    interval = int(1.0 / sample_rate)
    if interval > len(metric):
        return metric[:1]
    num = len(metric) // interval
    idx = [interval * i for i in range(num)]
    metric_sample = [metric[id] for id in idx]
    return metric_sample


def plot_metric(metric,
                batch_id,
                graph_title,
                line_style='b-',
                line_label='y',
                line_num=1):
    plt.figure()
    plt.title(graph_title)
    if line_num == 1:
        plt.plot(batch_id, metric, line_style, label=line_label)
    else:
        for i in range(line_num):
            plt.plot(batch_id, metric[i], line_style[i], label=line_label[i])
    plt.xlabel('batch')
    plt.ylabel(graph_title)
    plt.legend()
    plt.savefig(graph_title + '.jpg')
    plt.close()


def main():
    args = parse_args()
    assert args.sample_rate > 0. and args.sample_rate <= 1.0, \
        "The sample rate should be in the range (0, 1]."
    loss, accuracy = parse_file(args.file_path)
    batch = [args.log_period * i for i in range(len(loss))]
    batch_sample = sample(batch, args.sample_rate)
    loss_sample = sample(loss, args.sample_rate)
    accuracy_sample = sample(accuracy, args.sample_rate)
    plot_metric(loss_sample, batch_sample, 'loss', line_label='loss')
    plot_metric(
        accuracy_sample,
        batch_sample,
        'accuracy',
        line_style='g-',
        line_label='accuracy')


if __name__ == '__main__':
    main()
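For reference, below is a minimal Python sketch of the log-line shape that parse_file() above can consume. The sample line is hypothetical, constructed only to satisfy the parsing rules (starts with 'pass', five space-separated fields, fields three and four shaped like 'loss=<value>,' and 'error=<value>,'); it is not copied from an actual PaddlePaddle training log.

# Hypothetical log line, shaped to match the parsing logic in parse_file()
# above; not taken from a real training log.
sample_line = "pass=0 batch=100 loss=2.345, error=0.456, time=1.20"
tokens = sample_line.split(' ')
assert sample_line.startswith('pass') and len(tokens) == 5
cur_loss = float(tokens[2][:-1].split('=')[-1])  # -> 2.345
cur_err = float(tokens[3][:-1].split('=')[-1])   # -> 0.456
print(cur_loss, 1.0 - cur_err)                   # loss and accuracy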
......@@ -63,7 +63,7 @@ ExternalProject_Add(
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR})
ADD_LIBRARY(warpctc STATIC IMPORTED GLOBAL)
ADD_LIBRARY(warpctc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES})
ADD_DEPENDENCIES(warpctc extern_warpctc)
......
......@@ -105,8 +105,7 @@ static void BuildVar(const std::string& param_name,
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);
InitDevices();
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
......
......@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
Executor::Executor(const platform::Place& place) : place_(place) {}
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
......
......@@ -45,7 +45,5 @@ class Executor {
const platform::Place place_;
};
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
} // namespace framework
} // namespace paddle
......@@ -87,7 +87,11 @@ class GradOpDescMakerBase {
auto onames = this->Output(name);
ret_val.reserve(onames.size());
std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
GradVarName);
[this](const std::string& fwd_var_name) -> std::string {
auto g_name = GradVarName(fwd_var_name);
(*this->grad_to_var_)[g_name] = fwd_var_name;
return g_name;
});
return ret_val;
}
......
......@@ -40,40 +40,23 @@ void InitGflags(std::vector<std::string> &argv) {
});
}
bool InitDevices(const std::vector<std::string> &devices) {
// device format
// CPU
// GPU:1
// TODO(dzhwinter) : add device format annotation for users.
void InitDevices() {
/* Init all available devices by default */
std::vector<platform::Place> places;
for (auto &device : devices) {
auto p = string::Piece(device);
if (string::HasPrefix(p, "CPU")) {
places.emplace_back(platform::CPUPlace());
} else if (string::HasPrefix(p, "GPU")) {
places.emplace_back(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1);
places.emplace_back(platform::CUDAPlace(std::stoi(number)));
int count = platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
#else
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
#endif
} else {
return false;
}
}
if (std::find_if(places.begin(), places.end(),
[&](const platform::Place &place) {
return platform::is_cpu_place(place);
}) == places.end()) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
platform::DeviceContextPool::Init(places);
// framework::UseALL();
return true;
}
void InitGLOG(const std::string &prog_name) {
......
......@@ -24,7 +24,7 @@ void InitGflags(std::vector<std::string> &argv);
void InitGLOG(const std::string &prog_name);
bool InitDevices(const std::vector<std::string> &devices);
void InitDevices();
} // namespace framework
} // namespace paddle
......@@ -14,18 +14,13 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/platform/device_context.h"
TEST(Init, InitDevices) {
TEST(InitDevices, CPU) {
using paddle::framework::InitDevices;
std::vector<std::string> ds1 = {"CPU"};
ASSERT_EQ(InitDevices(ds1), true);
using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
std::vector<std::string> ds2 = {"CPU", "GPU:0", "GPU:1"};
ASSERT_EQ(InitDevices(ds2), true);
// test re-init
std::vector<std::string> ds3 = {"GPU:0", "GPU:1"};
ASSERT_EQ(InitDevices(ds3), true);
#endif
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_GE(pool.size(), 1U);
}
......@@ -44,9 +44,19 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
}
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
PADDLE_ENFORCE(platform::is_cpu_place(t.place()));
PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code());
if (!platform::is_cpu_place(t.place())) {
LoDTensor tt;
framework::Copy(t, platform::CPUPlace(), &tt);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(t.place());
dev_ctx.Wait();
os << tt;
return os;
}
os << "dim: " << t.dims() << "\n";
os << "lod: " << t.lod() << "\n";
......@@ -211,38 +221,23 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
DeserializeFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
// TODO(tonyyang-svail): make this function support LoD
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
check_memory_size();
// PADDLE_ENFORCE(lod().empty() || (lod().size() == 1 && lod()[0].empty())
// , "Disable parallel lod for now");
PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now");
PADDLE_ENFORCE(dims()[0] % places.size() == 0,
"Batch size should be divided by places size");
std::vector<LoDTensor> lods;
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
size_t begin = place_idx * dims()[0] / places.size();
size_t end = (place_idx + 1) * dims()[0] / places.size();
auto src = Slice(static_cast<int>(begin), static_cast<int>(end));
int begin = place_idx * dims()[0] / places.size();
int end = (place_idx + 1) * dims()[0] / places.size();
LoDTensor dst;
dst.Resize(src.dims());
auto src = Slice(begin, end);
auto &dst_place = places[place_idx];
auto dst_ptr = dst.mutable_data(dst_place, src.type());
// TODO(tonyyang-svail):
// change the following to framework::Copy
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} else {
PADDLE_THROW("Not Implemented");
}
LoDTensor dst;
framework::Copy(src, dst_place, &dst);
lods.emplace_back(dst);
}
......@@ -250,28 +245,30 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
return lods;
}
// TODO(tonyyang-svail): make this function support LoD
void LoDTensor::MergeLoDTensor(
const std::vector<const LoDTensor *> &lod_tensors, platform::Place place) {
PADDLE_ENFORCE(platform::is_cpu_place(place));
const std::vector<const LoDTensor *> &lod_tensors,
platform::Place dst_place) {
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type();
auto new_layout = lod_tensors[0]->layout();
for (auto *lod : lod_tensors) {
PADDLE_ENFORCE(new_dim == lod->dims());
PADDLE_ENFORCE(new_type == lod->type());
PADDLE_ENFORCE(platform::is_cpu_place(lod->place()));
PADDLE_ENFORCE(new_layout == lod->layout());
}
new_dim[0] *= lod_tensors.size();
Resize(new_dim);
set_layout(new_layout);
auto *dst_ptr = reinterpret_cast<uint8_t *>(mutable_data(place, new_type));
mutable_data(dst_place, new_type);
int begin = 0;
for (auto *src : lod_tensors) {
auto size = src->numel() * SizeOfType(src->type());
memory::Copy(boost::get<platform::CPUPlace>(place), dst_ptr,
boost::get<platform::CPUPlace>(src->place()),
src->data<void>(), size);
dst_ptr += size;
int end = begin + src->dims()[0];
auto dst = Slice(begin, end);
framework::Copy(*src, dst_place, &dst);
begin = end;
}
}
......
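To make the new SplitLoDTensor/MergeLoDTensor logic above easier to follow, here is a small standalone Python sketch of the same indexing: the batch (first) dimension is divided evenly across the target places, and merging copies each piece back into the corresponding slice of the destination. This is an illustration only, not Paddle code.

# Standalone sketch: split a batch evenly across N "places", then merge back.
# The batch size must be divisible by the number of places, and LoD is ignored,
# matching the PADDLE_ENFORCE checks and the TODO in the C++ code above.
def split_batch(batch, num_places):
    assert len(batch) % num_places == 0, "Batch size should be divided by places size"
    chunk = len(batch) // num_places
    return [batch[i * chunk:(i + 1) * chunk] for i in range(num_places)]

def merge_batch(chunks):
    merged = []
    for chunk in chunks:
        merged.extend(chunk)  # in C++: framework::Copy into the dst slice
    return merged

batch = list(range(8))          # 8 rows
chunks = split_batch(batch, 2)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
assert merge_batch(chunks) == batch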
......@@ -115,5 +115,21 @@ TEST(LoD, AppendLoD) {
EXPECT_EQ(origin, expected);
}
TEST(LoD, ToAbsOffset) {
LoD relative_lod;
relative_lod.push_back(std::vector<size_t>({0, 2}));
relative_lod.push_back(std::vector<size_t>({0, 1, 3}));
relative_lod.push_back(std::vector<size_t>({0, 2, 4, 5}));
LoD abs_lod = paddle::framework::ToAbsOffset(relative_lod);
LoD expected;
expected.push_back(std::vector<size_t>({0, 5}));
expected.push_back(std::vector<size_t>({0, 2, 5}));
expected.push_back(std::vector<size_t>({0, 2, 4, 5}));
EXPECT_EQ(abs_lod, expected);
}
} // namespace framework
} // namespace paddle
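The ToAbsOffset test above may be easier to read next to a standalone sketch of the conversion. The snippet below is plain Python, not the Paddle implementation; it walks from the lowest LoD level upward, mapping each index through the level below, and the assertion reuses the exact values from the test.

# Standalone sketch of converting a relative LoD to absolute offsets,
# checked against the values in TEST(LoD, ToAbsOffset) above.
def to_abs_offset(lod):
    if len(lod) < 2:
        return lod
    abs_lod = [list(level) for level in lod]
    for lvl in range(len(lod) - 2, -1, -1):
        abs_lod[lvl] = [abs_lod[lvl + 1][idx] for idx in lod[lvl]]
    return abs_lod

relative_lod = [[0, 2], [0, 1, 3], [0, 2, 4, 5]]
assert to_abs_offset(relative_lod) == [[0, 5], [0, 2, 5], [0, 2, 4, 5]]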
......@@ -129,7 +129,7 @@ class OpDesc {
}
proto::OpDesc desc_;
// input arg name => output variable names
// input arg name => input variable names
VariableNameMap inputs_;
// output arg name => output variable names
VariableNameMap outputs_;
......
......@@ -69,7 +69,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator,
paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::InitDevices({"CPU"});
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
......@@ -195,7 +195,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::InitDevices({"CPU"});
paddle::framework::InitDevices();
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
......@@ -225,7 +225,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
paddle::framework::InitDevices({"CPU"});
paddle::framework::InitDevices();
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
......@@ -264,7 +264,7 @@ class OperatorClone : public paddle::framework::OperatorBase {
};
TEST(Operator, Clone) {
paddle::framework::InitDevices({"CPU"});
paddle::framework::InitDevices();
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
......
......@@ -31,9 +31,10 @@ namespace framework {
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline void Copy(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx, Tensor* dst) {
VLOG(3) << "Copy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
dst->Resize(src.dims());
......@@ -88,26 +89,25 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place,
}
/**
* @brief Copy supports CPU <-> CPU
* @brief Wrapper on
* Copy(const Tensor& src, const platform::Place& dst_place,
* const platform::DeviceContext& ctx, Tensor* dst);
*
* @param[in] src The external tensor.
* @param[in] dst_place The dst place.
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline void Copy(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
PADDLE_ENFORCE(platform::is_cpu_place(src_place) &&
platform::is_cpu_place(dst_place));
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx;
if (platform::is_gpu_place(src.place())) {
dev_ctx = pool.Get(src.place());
} else {
dev_ctx = pool.Get(dst_place);
}
Copy(src, dst_place, *dev_ctx, dst);
}
/**
......
......@@ -74,7 +74,7 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("The type of var '", this->Name(), "' is unsupported.");
PADDLE_THROW("The type of var %s is unsupported.", this->Name());
}
}
......
......@@ -169,7 +169,7 @@ void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds,
}
auto* place = new platform::CPUPlace();
framework::InitDevices({"CPU"});
framework::InitDevices();
framework::Executor* executor = new framework::Executor(*place);
framework::Scope* scope = new framework::Scope();
......
......@@ -114,5 +114,21 @@ void Free<platform::CUDAPlace>(platform::CUDAPlace place, void* p) {
#endif
size_t Usage::operator()(const platform::CPUPlace& cpu) const {
return Used(cpu);
}
size_t Usage::operator()(const platform::CUDAPlace& gpu) const {
#ifdef PADDLE_WITH_CUDA
return Used(gpu);
#else
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
#endif
}
size_t memory_usage(const platform::Place& p) {
return boost::apply_visitor(Usage(), p);
}
} // namespace memory
} // namespace paddle
......@@ -54,6 +54,13 @@ void Free(Place place, void* ptr);
template <typename Place>
size_t Used(Place place);
struct Usage : public boost::static_visitor<size_t> {
size_t operator()(const platform::CPUPlace& cpu) const;
size_t operator()(const platform::CUDAPlace& gpu) const;
};
size_t memory_usage(const platform::Place& p);
/**
* \brief Free memory block in one place.
*
......
......@@ -44,6 +44,9 @@ TEST(BuddyAllocator, CPUAllocation) {
EXPECT_NE(p, nullptr);
paddle::platform::Place place = cpu;
EXPECT_EQ(paddle::memory::Used(cpu), paddle::memory::memory_usage(place));
paddle::memory::Free(cpu, p);
}
......@@ -99,6 +102,9 @@ TEST(BuddyAllocator, GPUAllocation) {
EXPECT_NE(p, nullptr);
paddle::platform::Place place = gpu;
EXPECT_EQ(paddle::memory::Used(gpu), paddle::memory::memory_usage(place));
paddle::memory::Free(gpu, p);
}
......
......@@ -151,6 +151,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
......
......@@ -39,7 +39,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
std::map<size_t /*offset*/, std::vector<Item>> hash;
framework::LoD new_lod;
auto *ids_data = selected_ids->mutable_data<int>(platform::CPUPlace());
auto *ids_data = selected_ids->mutable_data<int64_t>(platform::CPUPlace());
auto *scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace());
......@@ -66,7 +66,7 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int>();
auto *pre_ids_data = pre_ids.data<int64_t>();
for (size_t offset = 0; offset < items->size(); offset++) {
auto prefix_id = pre_ids_data[offset];
......@@ -127,7 +127,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
auto abs_lod = framework::ToAbsOffset(ids.lod());
PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL);
auto *ids_data = ids.data<int>();
auto *ids_data = ids.data<int64_t>();
auto *scores_data = scores.data<float>();
size_t instance_dim = 1;
......
......@@ -230,7 +230,6 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
ops::ConvOpGrad);
namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);
......
......@@ -12,6 +12,7 @@ if(WITH_GPU)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
......@@ -27,6 +28,7 @@ else()
cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor)
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
......@@ -38,3 +40,4 @@ cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include <gtest/gtest.h>
#include <iostream>
template <typename DeviceContext, typename Place>
void testIm2col() {
......@@ -102,6 +101,7 @@ void testIm2col() {
Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp);
out_ocf_ptr = output_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
}
......@@ -154,6 +154,9 @@ void testIm2col() {
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
delete place;
delete context;
}
TEST(math, im2col) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& seq, framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The LoD of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
const size_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* seq_data = seq.data<T>();
T* padding_data = padding.data<T>();
for (size_t i = 0; i < max_sequence_length; ++i) {
for (size_t j = 0; j < num_sequences; ++j) {
size_t start_pos = abs_offset_lod[level][j];
size_t sequence_length = abs_offset_lod[level][j + 1] - start_pos;
if (i < sequence_length) {
// i > 0 => sequence_length > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (size_t k = 0; k < sequence_width; ++k) {
padding_data[(i * num_sequences + j) * sequence_width + k] =
seq_data[(start_pos + i) * sequence_width + k] * scale;
}
} else {
memset(padding_data + (i * num_sequences + j) * sequence_width, 0,
sequence_width * sizeof(T));
}
}
}
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
framework::LoDTensor& seq, const framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The LoD of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
const size_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* padding_data = padding.data<T>();
T* seq_data = seq.data<T>();
for (size_t i = 0; i < num_sequences; ++i) {
size_t start_pos = abs_offset_lod[level][i];
size_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
for (size_t j = 0; j < sequence_length; ++j) {
// sequence_width > j > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (size_t k = 0; k < sequence_width; ++k) {
seq_data[(start_pos + j) * sequence_width + k] =
padding_data[(j * num_sequences + i) * sequence_width + k] *
scale;
}
}
}
}
};
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool NormByTimes, bool Padding>
__global__ void SequencePaddingKernel(T* padding, T* sequence,
const size_t* sequence_start_positions,
const size_t sequence_width,
const size_t max_sequence_length,
const size_t num_sequences) {
size_t padding_idx = blockIdx.y;
size_t start_pos = sequence_start_positions[padding_idx];
size_t sequence_length =
sequence_start_positions[padding_idx + 1] - start_pos;
size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y;
size_t padding_base_idx =
(sequence_idx * num_sequences + padding_idx) * sequence_width;
size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width;
if (sequence_idx < sequence_length) {
T scale = NormByTimes ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
if (Padding) {
/* sequence -> padding */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i];
}
} else {
/* padding -> sequence */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i];
}
}
} else if (sequence_idx < max_sequence_length) {
if (Padding) {
/* sequence -> padding */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
padding[padding_base_idx + i] = 0;
}
}
}
}
template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::LoDTensor& seq, framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The lod of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
size_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
Copy(seq, context.GetPlace(), context, &padding);
padding.Resize(padding_dims);
return;
}
const size_t kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t block_dim_x =
std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = num_sequences;
dim3 grid(grid_dim_x, grid_dim_y);
const T* seq_data = seq.data<T>();
T* padding_data = padding.data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
sequence_width, max_sequence_length, num_sequences);
} else {
SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
sequence_width, max_sequence_length, num_sequences);
}
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
framework::LoDTensor& seq, const framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The lod of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
size_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
Copy(padding, context.GetPlace(), context, &seq);
seq.Resize(seq_dims);
return;
}
const size_t kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t block_dim_x =
std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = num_sequences;
dim3 grid(grid_dim_x, grid_dim_y);
const T* padding_data = padding.data<T>();
T* seq_data = seq.data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
sequence_width, max_sequence_length, num_sequences);
} else {
SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
sequence_width, max_sequence_length, num_sequences);
}
}
};
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
inline static size_t MaximumSequenceLength(const framework::LoD& lod,
const size_t level) {
const size_t num_sequences = lod[level].size() - 1;
size_t max_sequence_length = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
for (size_t i = 0; i < num_sequences; ++i) {
max_sequence_length =
std::max(max_sequence_length,
abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]);
}
return max_sequence_length;
}
/*
* \brief Padding/Unpadding LoDTensor to/from normal Tensor of the shape
* [max_sequence_length, num_sequences, sequence_width].
*
* Padding sequence:
* padding[i] = seq[lod[level][i]]
* Unpadding sequence:
* seq[lod[level][i]] = padding[i]
*
* All sequences will be padded to the same length and stored in a transposed
* shape.
* Example:
* seq (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
*
* \param context device context of this functor.
* \param seq LoDTensor which is stored in sequence format, the shape
* is [total_sequence_length, sequence_width] where
* total_sequence_length is the sum of all sequences'
* length.
* \param padding Tensor which is padded to the same length, the shape is
* [max_sequence_length, num_sequences, sequence_width].
* \param norm_by_times whether to normalize by each sequence's length.
*
* \note transposition is also done in this functor.
*/
template <typename DeviceContext, typename T>
class PaddingLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
framework::Tensor& padding, bool norm_by_times);
};
template <typename DeviceContext, typename T>
class UnpaddingLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, framework::LoDTensor& seq,
const framework::Tensor& padding, bool norm_by_times);
};
} // namespace math
} // namespace operators
} // namespace paddle
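As a companion to the layout comment in sequence_padding.h above, here is a minimal standalone Python sketch of the same transposition, with the sequence width collapsed to a single element per time step for brevity. It reproduces the seq/padding example from the header's doc comment; it illustrates the layout only and is not the functor implementation.

# Rearrange a flat LoD sequence into the transposed padded layout
# [max_sequence_length, num_sequences], zero-filling shorter sequences.
def pad_sequences(seq, offsets):
    num_seqs = len(offsets) - 1
    lengths = [offsets[i + 1] - offsets[i] for i in range(num_seqs)]
    max_len = max(lengths)
    padding = [[0] * num_seqs for _ in range(max_len)]
    for j in range(num_seqs):
        for i in range(lengths[j]):
            padding[i][j] = seq[offsets[j] + i]
    return padding

# seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3), absolute offsets [0, 4, 6, 9, 10]
seq = ['s0'] * 4 + ['s1'] * 2 + ['s2'] * 3 + ['s3']
print(pad_sequences(seq, [0, 4, 6, 9, 10]))
# -> [['s0', 's1', 's2', 's3'], ['s0', 's1', 's2', 0], ['s0', 0, 's2', 0], ['s0', 0, 0, 0]]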
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
#include <gtest/gtest.h>
template <typename DeviceContext, typename Place, typename T>
void TestSequencePadding(const paddle::framework::LoD& lod,
const size_t sequence_width) {
paddle::framework::LoDTensor cpu_seq;
paddle::framework::LoDTensor cpu_seq_back;
paddle::framework::LoDTensor seq;
paddle::framework::LoDTensor seq_back;
paddle::framework::Tensor padding;
const size_t level = lod.size() - 1;
auto seq_dims =
paddle::framework::make_ddim({static_cast<int64_t>(lod[level].back()),
static_cast<int64_t>(sequence_width)});
cpu_seq.set_lod(lod);
cpu_seq.mutable_data<T>(seq_dims, paddle::platform::CPUPlace());
for (size_t i = 0; i < cpu_seq.numel(); ++i) {
cpu_seq.data<T>()[i] = static_cast<T>(i);
}
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
seq = cpu_seq;
} else {
Copy(cpu_seq, *place, *context, &seq);
seq.set_lod(lod);
}
const size_t max_sequence_length =
paddle::operators::math::MaximumSequenceLength(lod, level);
const size_t num_sequences = lod[level].size() - 1;
auto padding_dims =
paddle::framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
padding.mutable_data<T>(padding_dims, *place);
paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
*context, seq, padding, false);
seq_back.set_lod(lod);
seq_back.mutable_data<T>(seq_dims, *place);
paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
*context, seq_back, padding, false);
if (paddle::platform::is_cpu_place(*place)) {
cpu_seq_back = seq_back;
} else {
Copy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back);
cpu_seq_back.set_lod(lod);
}
EXPECT_EQ(cpu_seq.numel(), cpu_seq_back.numel());
EXPECT_EQ(cpu_seq.dims(), cpu_seq_back.dims());
for (size_t i = 0; i < cpu_seq.numel(); ++i) {
EXPECT_EQ(cpu_seq.data<T>()[i], cpu_seq_back.data<T>()[i]);
}
delete place;
delete context;
};
TEST(Seq2BatchPadding, CPU) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePadding<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod1, 16);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePadding<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod2, 128);
}
#ifdef PADDLE_WITH_CUDA
TEST(SequencePadding, CUDA) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePadding<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod1, 16);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePadding<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod2, 128);
}
#endif
......@@ -39,7 +39,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
"M = C * H * W");
AddComment(R"DOC(
"Input shape: $(N, C, H, W)$
Sclae shape: $(C, 1)$
Scale shape: $(C, 1)$
Output shape: $(N, C, H, W)$
Where
forward
......
......@@ -66,7 +66,7 @@ class NormKernel : public framework::OpKernel<T> {
context.GetPlace());
auto tmp = framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor);
// get colsum and sqrt , inverse
// get colsum and sqrt , inverse
auto dim = Eigen::array<int, 1>({{0}});
tmp.device(*place) = x_square_batch_eigen.sum(dim);
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
......
......@@ -39,6 +39,7 @@ void SplitTensorAndMoveTensorToScopes(
const std::vector<framework::Scope *> &sub_scopes,
const std::vector<platform::Place> &places,
const std::vector<std::string> &names) {
PADDLE_ENFORCE_EQ(sub_scopes.size(), places.size());
for (auto &argu : names) {
auto *var = scope.FindVar(argu);
const auto &tensor = var->Get<LoDTensor>();
......@@ -54,6 +55,15 @@ void SplitTensorAndMoveTensorToScopes(
}
}
void WaitOnPlaces(const std::vector<platform::Place> places) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
for (auto &place : places) {
auto &dev_ctx = *pool.Get(place);
dev_ctx.Wait();
}
}
class ParallelDoOp : public framework::OperatorBase {
public:
ParallelDoOp(const std::string &type,
......@@ -71,10 +81,7 @@ class ParallelDoOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program();
// TODO(tonyyang-svail): get places from input
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
places.emplace_back(platform::CPUPlace());
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
->GetMutable<std::vector<framework::Scope *>>();
......@@ -82,8 +89,22 @@ class ParallelDoOp : public framework::OperatorBase {
sub_scopes.push_back(&scope.NewScope());
}
// split input
SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
Inputs(kInputs));
// copy parameter
for (auto &param : Inputs(kParameters)) {
PADDLE_ENFORCE(scope.FindVar(param)->IsType<LoDTensor>(),
"Only support parameter type as LoDTensor");
auto &src = scope.FindVar(param)->Get<LoDTensor>();
for (size_t i = 0; i < places.size(); ++i) {
auto &place = places[i];
auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
framework::Copy(src, place, dst);
}
}
WaitOnPlaces(places);
std::vector<std::future<void>> workers;
workers.reserve(places.size());
......@@ -93,12 +114,6 @@ class ParallelDoOp : public framework::OperatorBase {
auto &place = places[place_idx];
auto *cur_scope = sub_scopes[place_idx];
// copy parameter
// some version of boost lacks != for boost::variant
if (!(dev_ctx.GetPlace() == place)) {
PADDLE_THROW("Not Implemented");
}
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
framework::Executor executor(place);
executor.Run(*program, cur_scope, block->ID(),
......@@ -108,6 +123,7 @@ class ParallelDoOp : public framework::OperatorBase {
for (auto &worker : workers) {
worker.wait();
}
WaitOnPlaces(places);
// merge output
for (auto &o_name : Outputs(kOutputs)) {
......@@ -121,6 +137,7 @@ class ParallelDoOp : public framework::OperatorBase {
scope.FindVar(o_name)->GetMutable<LoDTensor>();
lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
}
WaitOnPlaces(places);
}
};
......@@ -161,15 +178,14 @@ class ParallelDoGradOp : public OperatorBase {
auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
// TODO(tonyyang-svail): get places from input
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
places.emplace_back(platform::CPUPlace());
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
// feed output@grad
SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
Inputs(framework::GradVarName(kOutputs)));
WaitOnPlaces(places);
// for debugging
for (auto &s : Inputs(framework::GradVarName(kOutputs))) {
VLOG(3) << s;
VLOG(3) << scope.FindVar(s)->Get<LoDTensor>();
......@@ -196,10 +212,11 @@ class ParallelDoGradOp : public OperatorBase {
for (auto &worker : workers) {
worker.wait();
}
WaitOnPlaces(places);
// merge grad
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
VLOG(3) << s;
VLOG(3) << "merge grad " << s;
auto &t = sub_scopes[0]->FindVar(s)->Get<LoDTensor>();
VLOG(3) << t;
......@@ -216,7 +233,8 @@ class ParallelDoGradOp : public OperatorBase {
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {s, s_buf}}}, {{"Out", {s}}},
framework::AttributeMap{});
sum_op->Run(*sub_scopes[0], place);
sum_op->Run(*sub_scopes[0], places[0]);
WaitOnPlaces(places);
}
VLOG(3) << t;
......@@ -236,8 +254,10 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &input_param : this->InputNames()) {
VLOG(3) << input_param;
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, false));
if (input_param != kPlaces) {
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, false));
}
}
for (auto &output_param : this->OutputNames()) {
......
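The reworked ParallelDoOp above launches one executor per place asynchronously, joins all workers, and then waits on every device before merging outputs. Below is a minimal Python analogue of that fork/join control flow, using only the standard library; it is an illustration, not Paddle code.

# Fork/join sketch: one worker per place, then wait for all of them,
# analogous to framework::Async(...) followed by worker.wait() and WaitOnPlaces().
from concurrent.futures import ThreadPoolExecutor

def run_on_place(place_idx):
    # stand-in for Executor(place).Run(program, sub_scope, block_id)
    return place_idx

places = [0, 1]
with ThreadPoolExecutor(max_workers=len(places)) as pool:
    workers = [pool.submit(run_on_place, i) for i in places]
    results = [w.result() for w in workers]  # analogous to worker.wait()
print(results)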
......@@ -32,6 +32,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VraibleMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
void RunServer(Server **rpc_server,
std::shared_ptr<detail::SendRecvServerImpl> service,
const std::string &server_address) {
......@@ -111,10 +125,10 @@ class RecvOp : public framework::OperatorBase {
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
auto *ptr = recv_scope.Var(grad_var_name);
framework::CreateTensor(ptr,
framework::ToVarType(merged_grad->Type()));
CreateTensorFromMessageType(ptr, v.second.type());
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr;
<< " on recv scope, which pointer is " << ptr << " type is "
<< v.second.type();
}
if (trainer_count > 1) {
......
......@@ -115,12 +115,32 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
};
class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType("sequence_pool_grad");
op_desc_ptr->SetInput("X", Input("X"));
if (boost::get<std::string>(GetAttr("pooltype")) == "MAX") {
op_desc_ptr->SetInput("MaxIndex", Output("MaxIndex"));
}
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op_desc_ptr->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
sequence_pool_grad, ops::SequencePoolGradOp);
REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequencePoolGradOpMaker);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -31,6 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......
......@@ -70,6 +70,7 @@ class SumKernel : public framework::OpKernel<T> {
} else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
auto *out_value = out->mutable_value();
// Runtime InferShape
......
......@@ -41,6 +41,8 @@ class TopkOp : public framework::OperatorWithKernel {
dims[dims.size() - 1] = k;
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/warpctc_op.h"
namespace paddle {
namespace operators {
class WarpCTCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) of WarpCTCOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of WarpCTCOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"),
"Output(WarpCTCGrad) of WarpCTCOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of WarpCTCOp should not be null.");
auto logits_dims = ctx->GetInputDim("Logits");
int sequence_width =
static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
int blank = ctx->Attrs().Get<int>("blank");
PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
"The value of Attr(blank) should be in interval [0, %d).",
sequence_width);
// TODO(liuyiqun): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context());
}
};
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
"Tensor with LoD information. It's shape is "
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label).");
AddInput("Label",
"(LodTensor, default: LoDTensor<int>), the ground truth "
"of variable-length sequence, which is a 2-D Tensor with LoD "
"information. It is of the shape [Lg, 1], where Lg is th sum of "
"all labels' length.");
AddOutput("WarpCTCGrad",
"(Tensor, default: Tensor<float>), a temporary "
"output Tensor to store the gradients of warp-ctc, which is "
"computed with loss together in one call. It is a 3-D Tensor of "
"the shape [max_sequence_length, batch_size, num_classes + 1].")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default: Tensor<float>), the Connectionist "
"Temporal Classification (CTC) loss, which is a 2-D Tensor of "
"the shape [batch_size, 1]");
AddAttr<int>("blank",
"(int, default: 0), the blank label of Connectionist "
"Temporal Classification (CTC) loss, which is in the "
"half-opened interval [0, num_classes + 1).")
.SetDefault(0);
AddAttr<bool>("norm_by_times",
"(bool, default: false), whether to "
"normalize the gradients by the number of time-step, "
"which is also the sequence's length.")
.SetDefault(false);
AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
[Deep Speech 2: End-to-End Speech Recognition in English and Mandarin](
https://arxiv.org/pdf/1512.02595v1.pdf),
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation is
integrated into the warp-ctc library to normalize values for each row of the
input tensor.
More details of the CTC loss can be found by referring to
[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with
Recurrent Neural Networks](
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf).
)DOC");
}
};
class WarpCTCGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"),
"Input(WarpCTCGrad) of WarpCTCGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@GRAD) of WarpCTCGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Logits"));
ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, warpctc_grad,
ops::WarpCTCGradOp);
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>);
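To summarize the shape bookkeeping that WarpCTCOp and its kernel rely on, here is a small standalone Python sketch using hypothetical LoD values (the quantities Lp, num_sequences, sequence_width and the padded layout follow the comments and InferShape logic above; the numbers themselves are made up for illustration).

# Hypothetical absolute LoD offsets for 3 logit sequences, Lp = 12 time steps total.
logits_lod = [0, 5, 9, 12]
num_classes = 27                      # not counting the blank label
num_sequences = len(logits_lod) - 1
sequence_width = num_classes + 1
lengths = [logits_lod[i + 1] - logits_lod[i] for i in range(num_sequences)]
max_sequence_length = max(lengths)
print("Logits:", (logits_lod[-1], sequence_width))                           # (12, 28)
print("WarpCTCGrad:", (max_sequence_length, num_sequences, sequence_width))  # (5, 3, 28)
print("Loss:", (num_sequences, 1))                                           # (3, 1)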
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/warpctc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_padding.h"
#include "paddle/platform/dynload/warpctc.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext>
class WarpCTCFunctor {
public:
/*
* \brief Compute the connectionist temporal classification loss,
* and optionally compute the gradient with respect to the inputs.
*
* If gradient is nullptr, it only computes the ctc loss;
* otherwise it computes both the ctc loss and the gradient.
*
* \param ctx execution context of this functor
* \param input batch matrix of input probabilities, in
* max_sequence_length x num_sequences x
* sequence_width, (row-major) format
* \param gradient batch matrix of gradient, with the same shape as
* input.
* \param cpu_labels labels always in CPU memory.
* \param cpu_label_lengths length of all labels in CPU memory.
* \param cpu_input_lengths length of all sequences in CPU memory.
* \param sequence_width number of possible output symbols.
* \param num_sequences number of sequence.
* \param blank blank label used in ctc loss function.
* \param cpu_loss cost of each sequence in CPU memory.
*/
void operator()(const framework::ExecutionContext& ctx, const float* input,
float* gradient, const int* cpu_labels,
const int* cpu_label_lengths, const int* cpu_input_lengths,
const size_t sequence_width, const size_t num_sequences,
const size_t blank, float* cpu_loss) {
// Init warp-ctc options
init(ctx, blank);
// Compute the required workspace size.
// There is no memory allocated operations within warp-ctc.
size_t workspace_bytes = 0;
ctcStatus_t status = platform::dynload::get_workspace_size(
cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width),
static_cast<int>(num_sequences), options_, &workspace_bytes);
PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
"warp-ctc [version %d] Error in get_workspace_size: ",
warpctc_version_,
platform::dynload::ctcGetStatusString(status));
PADDLE_ENFORCE_GT(workspace_bytes, 0UL,
"Bytes of workspace got by warp-ctc function, "
"get_workspace_size(), should be larger than 0.");
Tensor workspace;
size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
float* workspace_data = workspace.mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
ctx.GetPlace());
math::SetConstant<DeviceContext, float>()(
ctx.template device_context<DeviceContext>(), &workspace,
static_cast<float>(0));
// compute loss and gradient
status = platform::dynload::compute_ctc_loss(
input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
cpu_loss, workspace_data, options_);
PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
"warp-ctc [version %d] Error in compute_ctc_loss: ",
warpctc_version_,
platform::dynload::ctcGetStatusString(status));
}
protected:
void init(const framework::ExecutionContext& ctx, const size_t blank) {
warpctc_version_ = platform::dynload::get_warpctc_version();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
options_.loc = CTC_GPU;
options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
#else
PADDLE_THROW("[warpctc init] GPU is not enabled.");
#endif
} else {
options_.loc = CTC_CPU;
options_.num_threads = 1;
}
options_.blank_label = blank;
}
private:
int warpctc_version_;
ctcOptions options_;
};
template <typename DeviceContext, typename T>
class WarpCTCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* logits = ctx.Input<LoDTensor>("Logits");
auto* label = ctx.Input<LoDTensor>("Label");
auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
auto* loss = ctx.Output<Tensor>("Loss");
const size_t level = 0;
auto logits_lod = framework::ToAbsOffset(logits->lod());
auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ(logits_dims[0],
static_cast<int64_t>(logits_lod[level].back()),
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths.");
auto label_lod = framework::ToAbsOffset(label->lod());
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(
label_dims[0], label->numel(),
"The width of each timestep in Input(Label) should be 1.");
const size_t num_sequences = logits_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label).");
const size_t sequence_width = logits->numel() / logits_dims[0];
auto loss_dims =
framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
// warpctc needs sequences data stored in transposed padding format
Tensor warpctc_logits;
const size_t max_sequence_length =
math::MaximumSequenceLength(logits_lod, level);
auto warpctc_logits_dims =
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits, warpctc_logits,
false);
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
warpctc_logits_lengths[i] =
logits_lod[level][i + 1] - logits_lod[level][i];
}
// warpctc computes the loss and the gradient in one call; the gradient data is
// also stored in batch format
T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
// warpctc accesses labels in CPU memory
Tensor warpctc_label;
Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
Tensor warpctc_loss;
T* warpctc_loss_data =
warpctc_loss.mutable_data<T>(loss_dims, platform::CPUPlace());
const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
WarpCTCFunctor<DeviceContext>()(
ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
sequence_width, num_sequences, blank, warpctc_loss_data);
// Copy the loss back
Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss);
}
};
template <typename DeviceContext, typename T>
class WarpCTCGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<Tensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits_grad,
*warpctc_grad, norm_by_times);
}
};
} // namespace operators
} // namespace paddle
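As a side note on the layout described in the WarpCTCFunctor comment above: warp-ctc expects logits in a time-major, zero-padded batch of shape max_sequence_length x num_sequences x sequence_width. Below is a minimal NumPy sketch of that transposed padding format; the sequence lengths and widths are invented for illustration and this is not the PaddingLoDTensorFunctor itself.
import numpy as np

# Two variable-length sequences, each timestep holding sequence_width = 4 scores.
sequences = [np.random.rand(3, 4).astype("float32"),
             np.random.rand(2, 4).astype("float32")]

max_sequence_length = max(s.shape[0] for s in sequences)
num_sequences = len(sequences)
sequence_width = sequences[0].shape[1]

# Time-major batch, zero-padded at the tail of the shorter sequences.
padded = np.zeros(
    (max_sequence_length, num_sequences, sequence_width), dtype="float32")
for i, seq in enumerate(sequences):
    padded[:seq.shape[0], i, :] = seq

print(padded.shape)  # (3, 2, 4)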
......@@ -211,59 +211,54 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad = new framework::OpDesc();
grad->SetType("while_grad");
grad->SetInput(kX, Input(kX));
auto *while_grad = new framework::OpDesc();
while_grad->SetType("while_grad");
while_grad->SetInput(kX, Input(kX));
while_grad->SetInput(kOutputs, Output(kOutputs));
while_grad->SetInput(kStepScopes, Output(kStepScopes));
auto *grad_block = this->grad_block_[0];
auto *fwd_block = grad_block->ParentBlock();
// Not all IGs will be generated by the inner gradient operators of the while op.
// Ignore IGs that are not generated by the inside block.
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
std::unordered_set<std::string> all_outs;
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
all_outs.insert(oname);
std::unordered_set<std::string> inner_op_outputs;
for (const auto *op : grad_block->AllOps()) {
for (auto &oname : op->OutputArgumentNames()) {
inner_op_outputs.insert(oname);
}
}
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
for (auto &each_ig : igs) {
if (all_outs.find(each_ig) == all_outs.end()) {
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
VLOG(10) << "Ignore " << each_ig;
each_ig = framework::kEmptyVarName;
}
}
grad->SetOutput(framework::GradVarName(kX), igs);
grad->SetInput(kOutputs, Output(kOutputs));
while_grad->SetOutput(framework::GradVarName(kX), igs);
// OG should be re-calculated by step blocks, since many outputs of while op
// do not need to calculate gradients.
std::unordered_set<std::string> block_ins;
auto *fwd_block = this->grad_block_[0]->ParentBlock();
{
for (auto &p : Input(kX)) {
block_ins.insert(p);
}
for (auto &o : Output(kOutputs)) {
block_ins.insert(o);
}
block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
for (auto &p : Input(kX)) {
block_ins.insert(p);
}
for (auto &o : Output(kOutputs)) {
block_ins.insert(o);
}
std::unordered_set<std::string> extra_inputs;
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
if (block_ins.find(input_name) != block_ins.end()) {
continue;
}
// If the input of Op is generated by the forward block, do not make it
// as input again.
if (fwd_block->FindVar(input_name) != nullptr) {
for (const auto *op : grad_block->AllOps()) {
for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward
// block, do not make it as input again.
if (block_ins.find(input_name) != block_ins.end() ||
fwd_block->FindVar(input_name) != nullptr) {
continue;
}
extra_inputs.insert(input_name);
}
for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
for (auto &output_name : op->OutputArgumentNames()) {
block_ins.insert(output_name);
}
}
......@@ -272,15 +267,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
extra_inputs_list.resize(extra_inputs.size());
std::copy(extra_inputs.begin(), extra_inputs.end(),
extra_inputs_list.begin());
grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
grad->SetInput(kStepScopes, Output(kStepScopes));
grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
while_grad->SetAttrMap(this->Attrs());
while_grad->SetBlockAttr(kStepBlock, *grad_block);
// record the original output gradient names, since the gradient name of
// while operator could be renamed.
grad->SetAttr("original_output_grad", extra_inputs_list);
while_grad->SetAttr("original_output_grad", extra_inputs_list);
return std::unique_ptr<framework::OpDesc>(grad);
return std::unique_ptr<framework::OpDesc>(while_grad);
}
};
......
......@@ -185,6 +185,8 @@ class DeviceContextPool {
const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
}
size_t size() const { return device_contexts_.size(); }
private:
static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
DEPS dynamic_loader nccl)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/platform/dynload/cublas.h>
#include "paddle/platform/dynload/cublas.h"
namespace paddle {
namespace platform {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/dynload/warpctc.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag warpctc_dso_flag;
void* warpctc_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
WARPCTC_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <dlfcn.h>
#include <mutex>
#include "ctc.h"
#include "paddle/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag warpctc_dso_flag;
extern void* warpctc_dso_handle;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load warpctc routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using warpctcFunc = decltype(__name(args...)) (*)(Args...); \
std::call_once(warpctc_dso_flag, \
paddle::platform::dynload::GetWarpCTCDsoHandle, \
&warpctc_dso_handle); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
DYNAMIC_LOAD_WARPCTC_WRAP(__name)
#define WARPCTC_ROUTINE_EACH(__macro) \
__macro(get_warpctc_version); \
__macro(ctcGetStatusString); \
__macro(compute_ctc_loss); \
__macro(get_workspace_size)
WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP);
#undef DYNAMIC_LOAD_WARPCTC_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
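The macro above opens the warp-ctc shared library once (std::call_once plus dlopen) and resolves each routine by name on first use (dlsym). For illustration only, here is a rough Python analogue of that lazy-loading pattern using ctypes; the library name libwarpctc.so is an assumption, and get_warpctc_version is one of the routines listed in WARPCTC_ROUTINE_EACH.
import ctypes
import threading

_warpctc_lock = threading.Lock()
_warpctc_handle = None

def _load_warpctc():
    # Load the shared library exactly once, mirroring std::call_once + dlopen.
    global _warpctc_handle
    with _warpctc_lock:
        if _warpctc_handle is None:
            _warpctc_handle = ctypes.CDLL("libwarpctc.so")  # assumed library name
    return _warpctc_handle

def get_warpctc_version():
    # Resolve the symbol by name on first call, mirroring dlsym in the macro.
    func = _load_warpctc().get_warpctc_version
    func.restype = ctypes.c_int
    return func()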
......@@ -34,11 +34,11 @@ int main(int argc, char** argv) {
google::ParseCommandLineFlags(&new_argc, &new_argv_address, false);
testing::InitGoogleTest(&argc, argv);
paddle::memory::Used(paddle::platform::CPUPlace());
std::vector<std::string> devs = {"CPU"};
#ifdef PADDLE_WITH_CUDA
paddle::memory::Used(paddle::platform::CUDAPlace(0));
devs.push_back("GPU:0");
#endif
paddle::framework::InitDevices(devs);
paddle::framework::InitDevices();
return RUN_ALL_TESTS();
}
......@@ -19,12 +19,13 @@ from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
import clip
from memory_optimization_transpiler import memory_optimize
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor', 'ParamAttr',
'DataFeeder', 'clip', 'DistributeTranspiler'
'DataFeeder', 'clip', 'DistributeTranspiler', 'memory_optimize'
]
......@@ -61,11 +62,7 @@ def __bootstrap__():
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
core.init_devices()
__bootstrap__()
from paddle.v2.fluid import framework as framework
from . import core
import collections
import copy
__all__ = ['append_backward']
__all__ = ['append_backward', 'calc_gradient']
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
......@@ -65,6 +66,18 @@ def _all_in_set_(cands, s):
return True
def _some_in_set_(cands, s):
"""
Test if some elements of 'cands' are in set 's'
"""
if len(cands) == 0:
return False
for c in cands:
if c in s:
return True
return False
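A trivial usage sketch of the helper above (the names are made up):
# 'b' overlaps with the set, so the first call is True; the second has no overlap.
assert _some_in_set_(['a', 'b'], set(['b', 'c']))
assert not _some_in_set_(['a'], set(['x', 'y']))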
def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given variable name
......@@ -169,8 +182,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
return op_descs
def _append_backward_ops_(target,
block,
def _append_backward_ops_(block,
ops,
target_block,
no_grad_dict,
grad_to_var,
......@@ -179,8 +192,8 @@ def _append_backward_ops_(target,
Create all grad ops, and insert them into given block
Args:
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
ops(Op): the forward operators whose backward ops need to be added
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
......@@ -202,14 +215,14 @@ def _append_backward_ops_(target,
# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
program = block.program
for op in reversed(block.ops):
for op in reversed(ops):
grad_sub_block_list = []
# If the op has its own sub-block, deal with the sub-block first
if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
_append_backward_ops_(target, sub_block, grad_sub_block,
no_grad_dict, grad_to_var, callback)
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var)
grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op
......@@ -224,14 +237,6 @@ def _append_backward_ops_(target,
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
if target_block.idx == 0:
grad_op_descs.insert(
0,
_create_op_desc_("fill_constant", {}, {
"Out": [_append_grad_suffix_(target.name)]
}, {"shape": [1],
"value": 1.0,
"dtype": target.dtype}))
# append op_desc in grad_op_descs to target_block
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
......@@ -252,7 +257,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
In most cases, this dict is generated by _append_backward_ops_()
grad_info_map(dict)(output argument):
key(str): forward variable name
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable
"""
for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx)
......@@ -279,41 +284,63 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
_infer_var_data_type_(arg, block)
def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
var_map = copy.copy(target_grad_map)
for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx)
for name in op_desc.input_arg_names():
if name in var_map:
op_desc.rename_input(name, var_map[name])
for name in op_desc.output_arg_names():
if block.desc.find_var(name.encode("ascii")):
new_name = "%s_%s" % (name, core.unique_integer(name))
op_desc.rename_output(name, new_name)
var_map[name] = new_name
for g, ng in var_map.iteritems():
if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g)
def _get_stop_gradients_(program):
no_grad_dict = dict()
assert isinstance(program, framework.Program)
for block in program.blocks:
assert isinstance(block, framework.Block)
block_no_grad_set = set()
for var in block.vars.itervalues():
assert isinstance(var, framework.Variable)
if var.stop_gradient:
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set
return no_grad_dict
def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
"""
Append backward part to main_program
Args:
loss(Variable): The variable generated by cost function.
parameter_list(list): Parameters that need to be updated by optimizer.
If None, it means all parameters need to be updated.
parameter_list(list[string]): Parameters that need to be updated by
optimizer. If None, it means all parameters need to be updated.
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
contains all variables with `stop_gradient=True` from all blocks.
All variables with `stop_gradient=True` from all blocks will be
automatically added.
Return:
(list[Variable]): list of (parameters, gradients) pair.
(list[(Variable,Variable)]): list of (parameter, gradient) pair.
"""
assert isinstance(loss, framework.Variable)
program = loss.block.program
no_grad_dict = dict()
if no_grad_set is None:
assert isinstance(program, framework.Program)
for block in program.blocks:
assert isinstance(block, framework.Block)
block_no_grad_set = set()
for var in block.vars.itervalues():
assert isinstance(var, framework.Variable)
if var.stop_gradient:
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_set, set):
no_grad_dict = {
0: set([_append_grad_suffix_(name) for name in no_grad_set])
}
else:
raise ValueError("'no_grad_set' should be a set or None.")
no_grad_set = set()
no_grad_set = copy.copy(no_grad_set)
no_grad_dict = _get_stop_gradients_(program)
no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
grad_info_map = dict()
root_block = program.block(0)
......@@ -322,8 +349,25 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
current_block_idx = program.current_block_idx
grad_to_var = dict()
_append_backward_ops_(loss, root_block, root_block, no_grad_dict,
op_desc = _create_op_desc_("fill_constant", {}, {
"Out": [_append_grad_suffix_(loss.name)]
}, {"shape": [1],
"value": 1.0,
"dtype": loss.dtype})
root_block.desc.append_op().copy_from(op_desc)
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set)
no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
_append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
grad_to_var, callback)
# Because calc_gradient may be called multiple times,
# we need to rename the internal gradient variables so that they have
# different names.
_rename_grad_(root_block, fwd_op_num, grad_to_var, {})
_append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
program.current_block_idx = current_block_idx
......@@ -334,6 +378,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
else:
params = program.global_block().all_parameters()
parameters = [param.name for param in params]
params_and_grads = []
for param in parameters:
if param not in grad_info_map:
......@@ -351,3 +396,147 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
else:
params_and_grads.append((param_var, None))
return params_and_grads
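A minimal usage sketch of append_backward() as documented above, assuming the usual fluid layers for a tiny regression network; the network itself is made up for illustration.
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.backward import append_backward

x = layers.data(name='x', shape=[13], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(x=cost)

# Returns a list of (parameter, gradient) pairs for the default main program.
params_and_grads = append_backward(avg_cost)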
def _as_list(x):
if x is None:
return []
return list(x) if isinstance(x, collections.Sequence) else [x]
def _find_op_path_(block, outputs, inputs, no_grad_set):
"""
no_grad_set will also be changed
"""
input_names = set([inp.name for inp in inputs])
output_names = set([out.name for out in outputs])
relevant_op_flags = [True] * len(block.ops)
# All the inputs of the block are used if inputs is empty.
if inputs:
for i, op in enumerate(block.ops):
if _some_in_set_(op.desc.input_arg_names(), input_names):
for name in op.desc.output_arg_names():
if name not in no_grad_set:
input_names.add(name)
else:
relevant_op_flags[i] = False
for i, op in reversed(list(enumerate(block.ops))):
if _some_in_set_(op.desc.output_arg_names(), output_names):
for name in op.desc.input_arg_names():
if name not in no_grad_set:
output_names.add(name)
else:
relevant_op_flags[i] = False
op_path = [
block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
]
if inputs:
for op in op_path:
for name in op.desc.input_arg_names():
if name not in input_names:
no_grad_set.add(name)
return op_path
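A standalone sketch of the pruning idea in _find_op_path_, using plain (inputs, outputs) tuples instead of real op descs: the backward sweep keeps only ops whose outputs transitively feed the requested targets. This mirrors the second loop above; the op list here is invented.
# Each "op" is (input_names, output_names); only ops feeding 'loss' are kept.
ops = [
    (['x', 'w1'], ['h']),
    (['h', 'w2'], ['loss']),
    (['x'], ['unused']),  # dead branch, pruned
]
output_names = set(['loss'])
keep = [True] * len(ops)
for i in reversed(range(len(ops))):
    ins, outs = ops[i]
    if any(o in output_names for o in outs):
        output_names.update(ins)
    else:
        keep[i] = False
print([i for i in range(len(ops)) if keep[i]])  # [0, 1]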
def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
"""
Backpropagate the gradients of targets to inputs.
Args:
targets(Variable|list[Variable]): The target variables
inputs(Variable|list[Variable]): The input variables
no_grad_set(set[string]): The names of variables that have no gradients
in Block 0. All variables with `stop_gradient=True` from all blocks
will be automatically added.
Return:
(list[Variable]): list of gradients for inputs
If an input does not affect targets, the corresponding gradient variable
will be None
"""
targets = _as_list(targets)
inputs = _as_list(inputs)
target_gradients = _as_list(target_gradients)
block = targets[0].block
prog = block.program
block_idx = block.idx
if not target_gradients:
target_gradients = [None] * len(targets)
if len(targets) != len(target_gradients):
raise ValueError(
"Should have the same number of target_gradients as targets")
if no_grad_set is None:
no_grad_set = set()
no_grad_set = copy.copy(no_grad_set)
no_grad_dict = _get_stop_gradients_(prog)
no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
fwd_op_num = block.desc.op_size()
target_grad_map = {}
for i, grad in enumerate(target_gradients):
target = targets[i]
if grad is None:
grad_name = _append_grad_suffix_(target.name)
op_desc = _create_op_desc_("fill_constant_batch_size_like",
{"Input": [target.name]},
{"Out": [grad_name]}, {
"shape": target.shape,
"value": 1.0,
"dtype": target.dtype,
'input_dim_idx': 0,
'output_dim_idx': 0
})
block.desc.append_op().copy_from(op_desc)
else:
if target.block.idx != block_idx or target.block.program != prog:
raise ValueError("all targets must be in the same block")
if target.shape != grad.shape:
raise ValueError(
"The shapes of target and grad are different: %s %s" % (
target.name, grad.name))
target_grad_map[_append_grad_suffix_(target.name)] = grad.name
for input in inputs:
if input.block.program != prog:
raise "input must be in the same program as targets"
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
grad_to_var = dict()
grad_info_map = dict()
_append_backward_ops_(block, op_path, block, no_grad_dict, grad_to_var)
# Because calc_gradient may be called multiple times,
# we need to rename the internal gradient variables so that they have
# different names.
_rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map)
_append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
prog.sync_with_cpp()
grad_vars = []
for input_var in inputs:
if input_var.name not in grad_info_map:
grad_vars.append(None)
else:
grad_info = grad_info_map[input_var.name]
grad_block = grad_info[1]
grad_var = grad_block.var(grad_info[0])
grad_vars.append(grad_var)
if len(grad_vars) == 1:
return grad_vars[0]
else:
return grad_vars
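A condensed usage sketch of calc_gradient(), following the unit test added further down in this change: the gradient of a mean over a matmul is requested with respect to an intermediate variable and a parameter.
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.backward import calc_gradient

x = layers.create_parameter(dtype="float32", shape=[5, 10])
y = layers.create_parameter(dtype="float32", shape=[10, 8])
mul_out = layers.mul(x=x, y=y)
mean_out = layers.mean(x=mul_out)

# Gradients of mean_out w.r.t. the intermediate mul_out and the parameter x.
grad_of_mul = calc_gradient(mean_out, mul_out)
grad_of_x = calc_gradient(mean_out, x)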
......@@ -773,6 +773,9 @@ class Program(object):
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return _debug_string_(proto, throw_on_error)
def get_desc(self):
return self.desc
def clone(self):
p = Program()
p.desc = core.ProgramDesc(self.desc)
......
from ..registry import register_layer
__activations__ = [
'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round'
'abs',
'ceil',
'exp',
'floor',
'log',
'relu',
'round',
'sigmoid',
'sqrt',
'square',
'tanh',
]
__all__ = [
......
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
__all__ = [
'create_tensor', 'cast', 'concat', 'sums', 'assign',
'create_tensor', 'create_parameter', 'cast', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros'
]
......@@ -11,6 +12,33 @@ def create_tensor(dtype, name=None):
return helper.create_variable(name=helper.name, dtype=dtype)
def create_parameter(shape,
dtype,
attr=None,
is_bias=False,
default_initializer=None):
"""
Create a parameter
Args:
shape(list[int]): shape of the parameter
dtype(string): element type of the parameter
attr(ParamAttr): attributes of the parameter
is_bias(bool): This can affect which default initializer is chosen
when default_initializer is None. If is_bias,
initializer.Constant(0.0) will be used. Otherwise,
Xavier() will be used.
default_initializer(Initializer): initializer for the parameter
Returns:
Parameter: the created parameter
"""
helper = LayerHelper("create_parameter")
if attr is None:
attr = ParamAttr()
return helper.create_parameter(attr, shape, dtype, is_bias,
default_initializer)
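A one-line usage sketch of the new create_parameter() layer, mirroring the calc_gradient unit test added in this change:
import paddle.v2.fluid.layers as layers

# A 5x10 float32 parameter; with is_bias=False the default Xavier initializer is used.
w = layers.create_parameter(dtype="float32", shape=[5, 10])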
def cast(x, dtype):
"""
This function takes in the input with input_dtype
......@@ -180,7 +208,8 @@ def fill_constant_batch_size_like(input,
Examples:
.. code-block:: python
data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64')
data = fluid.layers.fill_constant_batch_size_like(
input=like, shape=[1], value=0, dtype='int64')
"""
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
......
from collections import defaultdict
import framework
from framework import Program, default_main_program, Parameter, Variable
import backward
from backward import _rename_arg_
class ControlFlowGraph(object):
def __init__(self, Program):
self._program = Program
self._succesors = defaultdict(set)
self._presucessors = defaultdict(set)
self._uses = defaultdict(set)
self._defs = defaultdict(set)
self._live_in = defaultdict(set)
self._live_out = defaultdict(set)
def _add_connections(self, connections):
for node1, node2 in connections:
self._add(node1, node2)
def _add(self, node1, node2):
self._succesors[node1].add(node2)
self._presucessors[node2].add(node1)
def _build_graph(self):
program_desc = self._program.get_desc()
block_size = program_desc.num_blocks()
# TODO(qijun) handle Program with if/while operators
self.global_block = program_desc.block(0)
self.op_size = self.global_block.op_size()
op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
self._add_connections(op_node_connections)
self.ops = [self.global_block.op(i) for i in range(self.op_size)]
for i in range(self.op_size):
self._uses[i].update(self.ops[i].input_arg_names())
self._defs[i].update(self.ops[i].output_arg_names())
def _reach_fixed_point(self, live_in, live_out):
if len(live_in) != len(self._live_in):
return False
if len(live_out) != len(self._live_out):
return False
for i in range(self.op_size):
if live_in[i] != self._live_in[i]:
return False
for i in range(self.op_size):
if live_out[i] != self._live_out[i]:
return False
return True
def _dataflow_analyze(self):
self._build_graph()
live_in = defaultdict(set)
live_out = defaultdict(set)
while True:
for i in range(self.op_size):
live_in[i] = set(self._live_in[i])
live_out[i] = set(self._live_out[i])
self._live_in[i] = self._uses[i] | (
self._live_out[i] - self._defs[i])
for s in self._succesors[i]:
self._live_out[i] |= self._live_in[s]
if self._reach_fixed_point(live_in, live_out):
break
def _get_diff(self, a, b):
u = a & b
return a - u, b - u
def memory_optimize(self):
self._build_graph()
self._dataflow_analyze()
self.pool = []
for i in range(self.op_size):
if self.pool:
out_pair = [(x, self.global_block.var(str(x)).shape())
for x in self._defs[i]]
for x, x_shape in out_pair:
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if x_shape == cache_shape:
print(
"Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s "
% (index, x, cache_var, str(cache_shape)))
self.pool.pop(index)
_rename_arg_(self.ops, x, cache_var, begin_idx=i)
self._dataflow_analyze()
break
in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
can_optimize = filter(
lambda x: not self.global_block.var(str(x)).persistable(),
in_diff)
if can_optimize:
for var_name in can_optimize:
self.pool.append((
var_name, self.global_block.var(str(var_name)).shape()))
def get_program(self):
return self._program
def memory_optimize(input_program):
graph = ControlFlowGraph(input_program)
graph.memory_optimize()
result_program = graph.get_program()
return result_program
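_dataflow_analyze() above iterates the standard backward liveness equations, live_in[i] = use[i] | (live_out[i] - def[i]) and live_out[i] = the union of live_in over successors, until a fixed point. Below is a self-contained toy sketch of the same iteration on a hand-written three-op chain; the variable and op names are invented for illustration.
from collections import defaultdict

# A straight-line "program": op i reads uses[i] and writes defs[i].
uses = [set(['x']), set(['a']), set(['b'])]
defs = [set(['a']), set(['b']), set(['c'])]
succ = {0: [1], 1: [2], 2: []}

live_in = defaultdict(set)
live_out = defaultdict(set)
changed = True
while changed:
    changed = False
    for i in reversed(range(len(uses))):
        new_out = set()
        for s in succ[i]:
            new_out |= live_in[s]
        new_in = uses[i] | (new_out - defs[i])
        if new_in != live_in[i] or new_out != live_out[i]:
            live_in[i], live_out[i] = new_in, new_out
            changed = True

print(live_in[1])   # set(['a']): 'a' must stay alive while op 1 runs
print(live_out[0])  # set(['a']): 'a' is still needed after op 0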
from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import os
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 32
IS_SPARSE = True
TRAINERS = 2
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
embed_first = fluid.layers.embedding(
input=first_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr='shared_w')
embed_second = fluid.layers.embedding(
input=second_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr='shared_w')
embed_third = fluid.layers.embedding(
input=third_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr='shared_w')
embed_forth = fluid.layers.embedding(
input=forth_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr='shared_w')
concat_embed = fluid.layers.concat(
input=[embed_first, embed_second, embed_third, embed_forth], axis=1)
hidden1 = fluid.layers.fc(input=concat_embed, size=HIDDEN_SIZE, act='sigmoid')
predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax')
cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
t = fluid.DistributeTranspiler()
# all parameter server endpoints list for splitting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
feeder = fluid.DataFeeder(
feed_list=[first_word, second_word, third_word, forth_word, next_word],
place=place)
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
for data in train_reader():
avg_cost_np = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
print("avg_cost_np", avg_cost_np)
if avg_cost_np[0] < 5.0:
exit(
0)  # if avg cost less than 5.0, we think our code is good.
else:
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
exit(1)
......@@ -341,9 +341,6 @@ class TestBatchNormOp(OpTest):
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
for place in places:
for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format, [2, 3, 4, 5])
......
......@@ -37,13 +37,13 @@ class BeamSearchOpTester(unittest.TestCase):
print 'lod', selected_ids.lod()
def _create_pre_ids(self):
np_data = np.array([[1, 2, 3, 4]], dtype='int32')
np_data = np.array([[1, 2, 3, 4]], dtype='int64')
tensor = create_tensor(self.scope, "pre_ids", np_data)
def _create_ids(self):
self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]]
np_data = np.array(
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32')
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64')
tensor = create_tensor(self.scope, "ids", np_data)
tensor.set_lod(self.lod)
......
import unittest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.backward import calc_gradient
class TestCalcGradient(unittest.TestCase):
def test_calc_gradient(self):
x = layers.create_parameter(dtype="float32", shape=[5, 10])
y = layers.create_parameter(dtype="float32", shape=[10, 8])
mul_out = layers.mul(x=x, y=y)
mean_out = layers.mean(x=mul_out)
a = calc_gradient(mean_out, mul_out)
b = calc_gradient(mean_out, x)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(fluid.default_main_program(), feed={}, fetch_list=[a, b])
if __name__ == "__main__":
unittest.main()
from __future__ import print_function
import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.framework import Program, program_guard
from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize
class TestControlFlowGraph(unittest.TestCase):
def setUp(self):
program = Program()
with program_guard(program, startup_program=Program()):
x = layers.data(name='x', shape=[13], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
y = layers.data(name='y', shape=[1], dtype='float32')
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(x=cost)
opt = optimizer.SGD(learning_rate=0.001)
opt = opt.minimize(avg_cost)
self.program = program
def test_control_flow_graph(self):
print("before optimization")
print(str(self.program))
result_program = memory_optimize(self.program)
print("after optimization")
print(str(result_program))
if __name__ == "__main__":
unittest.main()
......@@ -18,7 +18,7 @@ class ParallelOpTest(unittest.TestCase):
append_batch_size=False,
stop_gradient=False)
places = fluid.default_main_program().global_block().create_var()
places = layers.get_places(device_count=4)
pd = layers.ParallelDo(places=places)
with pd.do():
......
import unittest
import numpy as np
from op_test import OpTest
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
from test_softmax_op import stable_softmax
class TestSequenceSoftmaxOp(OpTest):
......
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
class CTCForward(object):
def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
norm_by_times):
self.softmax = softmax
self.softmax_lod = softmax_lod
assert labels.shape[1] == 1
self.labels = labels
self.labels_lod = labels_lod
self.blank = blank
self.norm_by_times = norm_by_times
self.level = 0
self.num_classes = softmax.shape[1]
self.batch_size = len(softmax_lod[self.level]) - 1
assert self.batch_size == len(labels_lod[self.level]) - 1
self.loss = np.zeros([self.batch_size, 1], dtype="float32")
self.gradient = np.zeros(self.softmax.shape, dtype="float32")
# float64
self.EXP_MAX = sys.float_info.max
self.EXP_MIN = sys.float_info.min
self.LOG_ZERO = np.log(self.EXP_MIN)
self.LOG_INFINITY = np.log(self.EXP_MAX)
def safe_exp(self, x):
if x <= self.LOG_ZERO:
return 0.0
if x >= self.LOG_INFINITY:
return self.EXP_MAX
return np.exp(x)
def safe_log(self, x):
if x <= self.EXP_MIN:
return self.LOG_ZERO
return np.log(x)
# x = lna and y = lnb are in log scale, ln(a / b) = lna - lnb
def log_div(self, x, y):
res = x - y
if res <= self.LOG_ZERO:
return self.LOG_ZERO
if res >= self.LOG_INFINITY:
return self.LOG_INFINITY
return res
# x = lna and y = lnb are in log scale, ln(a * b) = lna + lnb
def log_mul(self, x, y):
res = x + y
if res <= self.LOG_ZERO:
return self.LOG_ZERO
if res >= self.LOG_INFINITY:
return self.LOG_INFINITY
return res
# x = lna and y = lnb are in log scale,
# ln(a + b) = lna + ln(1 + exp(lnb - lna)), where a >= b
def log_add(self, x, y):
if x < y:
t = y
y = x
x = t
return x + self.safe_log(1 + self.safe_exp(y - x))
def segment_range(self, time, total_times, total_segments):
start = max(0, total_segments - (2 * (total_times - time)))
end = min(total_segments, 2 * (time + 1))
return start, end
def forward_a_sequence(self, softmax_a_sequence, labels_a_sequence):
total_times = softmax_a_sequence.shape[0]
total_segments = labels_a_sequence.shape[0] * 2 + 1
required_times = labels_a_sequence.shape[0]
old_label = -1
for i in range(labels_a_sequence.shape[0]):
# two contiguous labels with the same value
if labels_a_sequence[i, 0] == old_label:
required_times = required_times + 1
old_label = labels_a_sequence[i, 0]
if total_times < required_times:
return 0
# calculate the forward and backward variables,
# reference Chapter 7.3 of "Alex Graves, Supervised Sequence
# Labelling with Recurrent Neural Networks"
log_acts = np.zeros([total_times, self.num_classes], dtype="float32")
for i in range(total_times):
for j in range(self.num_classes):
log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j])
# calculate the forward variables
forward_vars = np.zeros([total_times, total_segments], dtype="float32")
for i in range(total_times):
for j in range(total_segments):
forward_vars[i, j] = self.LOG_ZERO
for i in range(total_times):
# dp initialization at t0
if i == 0:
forward_vars[i, 0] = log_acts[0, self.blank]
if total_segments > 1:
forward_vars[i, 1] = log_acts[0, labels_a_sequence[i, 0]]
continue
# dp from t1
start, end = self.segment_range(i, total_times, total_segments)
for k in range(end - start):
j = k + start
if j & 1 == 1:
label_idx = j / 2
label_val = labels_a_sequence[label_idx, 0]
fv = self.log_add(forward_vars[i - 1, j],
forward_vars[i - 1, j - 1])
if j > 1 and label_val != labels_a_sequence[label_idx - 1,
0]:
fv = self.log_add(fv, forward_vars[i - 1, j - 2])
fv = self.log_mul(fv, log_acts[i, label_val])
else:
fv = forward_vars[i - 1, j]
if j > 0:
fv = self.log_add(fv, forward_vars[i - 1, j - 1])
fv = self.log_mul(fv, log_acts[i, self.blank])
forward_vars[i, j] = fv
# sum the last two values as log_prob
log_prob = forward_vars[total_times - 1, total_segments - 1]
if total_segments > 1:
log_prob = self.log_add(
log_prob, forward_vars[total_times - 1, total_segments - 2])
return -log_prob
def forward(self):
for i in range(self.batch_size):
softmax_start_i = self.softmax_lod[self.level][i]
softmax_end_i = self.softmax_lod[self.level][i + 1]
labels_start_i = self.labels_lod[self.level][i]
labels_end_i = self.labels_lod[self.level][i + 1]
softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :]
labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
return self.loss
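The helpers above keep all CTC forward arithmetic in log space; log_add uses the identity log(a + b) = x + log(1 + exp(y - x)) with x the larger of the two log values, which avoids overflow for small probabilities. A quick sanity-check sketch of that identity, with arbitrarily chosen values and without the EXP_MIN/EXP_MAX clipping:
import numpy as np

def log_add(x, y):
    # Same identity as CTCForward.log_add, minus the clipping against LOG_ZERO.
    if x < y:
        x, y = y, x
    return x + np.log(1.0 + np.exp(y - x))

a, b = 1e-3, 2e-4
assert np.isclose(log_add(np.log(a), np.log(b)), np.log(a + b))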
class TestWarpCTCOp(OpTest):
def setUp(self):
self.op_type = "warpctc"
batch_size = 4
num_classes = 8
logits_lod = [[0, 4, 5, 8, 11]]
logits = np.random.uniform(0.1, 1.0,
[11, num_classes]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels_lod = [[0, 3, 4, 8, 12]]
# labels should not be blank
labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")
blank = num_classes - 1
norm_by_times = False
ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
norm_by_times)
loss = ctc.forward()
max_sequence_length = 0
for i in range(batch_size):
max_sequence_length = max(max_sequence_length,
logits_lod[0][i + 1] - logits_lod[0][i])
gradient = np.zeros(
[max_sequence_length, batch_size, num_classes], dtype="float32")
self.inputs = {
"Logits": (logits, logits_lod),
"Label": (labels, labels_lod)
}
self.outputs = {"Loss": loss}
self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.outputs["WarpCTCGrad"] = None
# self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()