未验证 提交 ba70cc49 编写于 作者: L liuwei1031 提交者: GitHub
上级 ff7f911b
......@@ -21,6 +21,7 @@
#endif
#include <algorithm>
#include <chrono> // NOLINT
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
......@@ -79,26 +80,63 @@ static void split(const std::string &str, char sep,
pieces->push_back(str.substr(pos));
}
}
template <typename T>
static T convert(const std::string &item,
std::function<T(const std::string &item)> func) {
T res;
try {
res = func(item);
} catch (std::invalid_argument &e) {
std::string message =
"invalid_argument exception when try to convert : " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
} catch (std::out_of_range &e) {
std::string message =
"out_of_range exception when try to convert : " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
} catch (...) {
std::string message = "unexpected exception when try to convert " + item;
LOG(ERROR) << message;
PADDLE_THROW(message);
}
return res;
}
static void split_to_float(const std::string &str, char sep,
std::vector<float> *fs) {
std::vector<std::string> pieces;
split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
[](const std::string &v) { return std::stof(v); });
[](const std::string &v) {
return convert<float>(v, [](const std::string &item) {
return std::stof(item);
});
});
}
static void split_to_int64(const std::string &str, char sep,
std::vector<int64_t> *is) {
std::vector<std::string> pieces;
split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is),
[](const std::string &v) { return std::stoi(v); });
[](const std::string &v) {
return convert<int64_t>(v, [](const std::string &item) {
return std::stoll(item);
});
});
}
static void split_to_int(const std::string &str, char sep,
std::vector<int> *is) {
std::vector<std::string> pieces;
split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*is),
[](const std::string &v) { return std::stoi(v); });
[](const std::string &v) {
return convert<int>(v, [](const std::string &item) {
return std::stoi(item);
});
});
}
template <typename T>
std::string to_string(const std::vector<T> &vec) {
......
......@@ -34,7 +34,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW.");
int img_channels = in_dim[1];
auto img_channels = in_dim[1];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
......
......@@ -113,8 +113,9 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
out->mutable_data<T>({batch_size * output_height * output_width,
img_channels * kernels[0] * kernels[1]},
out->mutable_data<T>(
{static_cast<int64_t>(batch_size) * output_height * output_width,
static_cast<int64_t>(img_channels) * kernels[0] * kernels[1]},
ctx.GetPlace());
const std::vector<int> dilations({1, 1});
auto out_dims = out->dims();
......
......@@ -144,7 +144,8 @@ class ContextProjectFunctor {
sequence_height = static_cast<int>(out_t.dims()[0]);
// add up trainable data
out_t.Resize({sequence_height * context_length, sequence_width});
out_t.Resize({static_cast<int64_t>(sequence_height) * context_length,
sequence_width});
if (up_pad > 0) { // add up pad
int padding_rows = std::min(
......@@ -191,7 +192,8 @@ class ContextProjectFunctor {
&out_t_sub);
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
out_t.Resize({sequence_height,
static_cast<int64_t>(context_length) * sequence_width});
}
}
}
......@@ -260,7 +262,8 @@ class ContextProjectGradFunctor {
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
out_t.Resize({sequence_height * context_length, sequence_width});
out_t.Resize({static_cast<int64_t>(sequence_height) * context_length,
sequence_width});
if (up_pad > 0) {
int padding_rows = std::min(
......@@ -308,7 +311,8 @@ class ContextProjectGradFunctor {
w_sub.data<T>());
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
out_t.Resize({sequence_height,
static_cast<int64_t>(context_length) * sequence_width});
}
}
}
......
......@@ -32,17 +32,17 @@ namespace reader {
static inline void string_split(const std::string& s, const char delimiter,
std::vector<std::string>* output) {
size_t start = 0;
size_t end = s.find_first_of(delimiter);
if (s.empty()) return;
while (end <= std::string::npos) {
output->emplace_back(s.substr(start, end - start));
if (end == std::string::npos) {
break;
}
size_t start = 0;
size_t end = s.find(delimiter);
while (end != std::string::npos) {
if (end > start) output->emplace_back(s.substr(start, end - start));
start = end + 1;
end = s.find_first_of(delimiter, start);
end = s.find(delimiter, start);
}
auto term = s.substr(start);
if (!term.empty()) output->emplace_back(term);
}
static inline void parse_line(
......
......@@ -61,10 +61,10 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<DeviceContext, T> unpool2d_max_backward;
unpool2d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册