未验证 提交 ba70cc49 编写于 作者: L liuwei1031 提交者: GitHub
上级 ff7f911b
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#endif #endif
#include <algorithm> #include <algorithm>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <functional>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
...@@ -79,26 +80,63 @@ static void split(const std::string &str, char sep, ...@@ -79,26 +80,63 @@ static void split(const std::string &str, char sep,
pieces->push_back(str.substr(pos)); pieces->push_back(str.substr(pos));
} }
} }
template <typename T>
static T convert(const std::string &item,
std::function<T(const std::string &item)> func) {
T res;
try {
res = func(item);
} catch (std::invalid_argument &e) {
std::string message =
"invalid_argument exception when try to convert : " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
} catch (std::out_of_range &e) {
std::string message =
"out_of_range exception when try to convert : " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
} catch (...) {
std::string message = "unexpected exception when try to convert " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
}
return res;
}
static void split_to_float(const std::string &str, char sep, static void split_to_float(const std::string &str, char sep,
std::vector<float> *fs) { std::vector<float> *fs) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(str, sep, &pieces); split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs), std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
[](const std::string &v) { return std::stof(v); }); [](const std::string &v) {
return convert<float>(v, [](const std::string &item) {
return std::stof(item);
});
});
} }
static void split_to_int64(const std::string &str, char sep, static void split_to_int64(const std::string &str, char sep,
std::vector<int64_t> *is) { std::vector<int64_t> *is) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(str, sep, &pieces); split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is), std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is),
[](const std::string &v) { return std::stoi(v); }); [](const std::string &v) {
return convert<int64_t>(v, [](const std::string &item) {
return std::stoll(item);
});
});
} }
static void split_to_int(const std::string &str, char sep, static void split_to_int(const std::string &str, char sep,
std::vector<int> *is) { std::vector<int> *is) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(str, sep, &pieces); split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is), std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is),
[](const std::string &v) { return std::stoi(v); }); [](const std::string &v) {
return convert<int>(v, [](const std::string &item) {
return std::stoi(item);
});
});
} }
template <typename T> template <typename T>
std::string to_string(const std::vector<T> &vec) { std::string to_string(const std::vector<T> &vec) {
......
...@@ -34,7 +34,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel { ...@@ -34,7 +34,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(in_dim.size(), 4, PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW."); "Input(X) format must be 4D tensor, eg., NCHW.");
int img_channels = in_dim[1]; auto img_channels = in_dim[1];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels"); auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
......
...@@ -113,8 +113,9 @@ class Im2SequenceKernel : public framework::OpKernel<T> { ...@@ -113,8 +113,9 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
paddings[2], strides[0]); paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]); paddings[3], strides[1]);
out->mutable_data<T>({batch_size * output_height * output_width, out->mutable_data<T>(
img_channels * kernels[0] * kernels[1]}, {static_cast<int64_t>(batch_size) * output_height * output_width,
static_cast<int64_t>(img_channels) * kernels[0] * kernels[1]},
ctx.GetPlace()); ctx.GetPlace());
const std::vector<int> dilations({1, 1}); const std::vector<int> dilations({1, 1});
auto out_dims = out->dims(); auto out_dims = out->dims();
......
...@@ -144,7 +144,8 @@ class ContextProjectFunctor { ...@@ -144,7 +144,8 @@ class ContextProjectFunctor {
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
// add up trainable data // add up trainable data
out_t.Resize({sequence_height * context_length, sequence_width}); out_t.Resize({static_cast<int64_t>(sequence_height) * context_length,
sequence_width});
if (up_pad > 0) { // add up pad if (up_pad > 0) { // add up pad
int padding_rows = std::min( int padding_rows = std::min(
...@@ -191,7 +192,8 @@ class ContextProjectFunctor { ...@@ -191,7 +192,8 @@ class ContextProjectFunctor {
&out_t_sub); &out_t_sub);
} }
} }
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height,
static_cast<int64_t>(context_length) * sequence_width});
} }
} }
} }
...@@ -260,7 +262,8 @@ class ContextProjectGradFunctor { ...@@ -260,7 +262,8 @@ class ContextProjectGradFunctor {
static_cast<int>(lod_level_0[i + 1])); static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
out_t.Resize({sequence_height * context_length, sequence_width}); out_t.Resize({static_cast<int64_t>(sequence_height) * context_length,
sequence_width});
if (up_pad > 0) { if (up_pad > 0) {
int padding_rows = std::min( int padding_rows = std::min(
...@@ -308,7 +311,8 @@ class ContextProjectGradFunctor { ...@@ -308,7 +311,8 @@ class ContextProjectGradFunctor {
w_sub.data<T>()); w_sub.data<T>());
} }
} }
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height,
static_cast<int64_t>(context_length) * sequence_width});
} }
} }
} }
......
...@@ -32,17 +32,17 @@ namespace reader { ...@@ -32,17 +32,17 @@ namespace reader {
static inline void string_split(const std::string& s, const char delimiter, static inline void string_split(const std::string& s, const char delimiter,
std::vector<std::string>* output) { std::vector<std::string>* output) {
size_t start = 0; if (s.empty()) return;
size_t end = s.find_first_of(delimiter);
while (end <= std::string::npos) { size_t start = 0;
output->emplace_back(s.substr(start, end - start)); size_t end = s.find(delimiter);
if (end == std::string::npos) { while (end != std::string::npos) {
break; if (end > start) output->emplace_back(s.substr(start, end - start));
}
start = end + 1; start = end + 1;
end = s.find_first_of(delimiter, start); end = s.find(delimiter, start);
} }
auto term = s.substr(start);
if (!term.empty()) output->emplace_back(term);
} }
static inline void parse_line( static inline void parse_line(
......
...@@ -61,10 +61,10 @@ class UnpoolGradKernel : public framework::OpKernel<T> { ...@@ -61,10 +61,10 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
auto& device_ctx = context.template device_context<DeviceContext>(); auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero; math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0)); zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<DeviceContext, T> unpool2d_max_backward; math::Unpool2dMaxGradFunctor<DeviceContext, T> unpool2d_max_backward;
unpool2d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad); unpool2d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册