提交 c474e7dd 编写于 作者: K Kevin 提交者: liuwei1031

fix overflow by int32 mul test=develop (#16794)

* fix overflow by int32 mul test=develop

* fix reference nullptr

* fix codestyle test=develop

* modify to point in ContextProjectFunctor test=develop

* modify to point in ContextProjectFunctor test=develop

* modify . to -> test=develop
上级 b7f20ed6
......@@ -87,7 +87,7 @@ template <typename DeviceContext, typename T>
class ContextProjectFunctor {
public:
void operator()(const DeviceContext& context, const LoDTensor& in,
const Tensor& padding_data, bool padding_trainable,
const Tensor* padding_data, bool padding_trainable,
const int context_start, const int context_length,
const int context_stride, const int up_pad,
const int down_pad, Tensor* col) {
......@@ -132,6 +132,7 @@ class ContextProjectFunctor {
}
}
if (padding_trainable) {
PADDLE_ENFORCE_NOT_NULL(padding_data);
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
......@@ -150,7 +151,7 @@ class ContextProjectFunctor {
k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size);
Tensor w_sub = padding_data->Slice(k, k + padding_size);
framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
}
......@@ -180,7 +181,7 @@ class ContextProjectFunctor {
Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice(
Tensor w_sub = padding_data->Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
framework::TensorCopy(w_sub, context.GetPlace(), context,
&out_t_sub);
......
......@@ -49,7 +49,7 @@ class SequenceConvKernel : public framework::OpKernel<T> {
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
auto sequence_width = static_cast<int64_t>(in->dims()[1]);
framework::DDim col_shape = {in->dims()[0],
context_length * sequence_width};
......@@ -62,7 +62,7 @@ class SequenceConvKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, &col, static_cast<T>(0));
math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;
seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
context_start, context_length, context_stride, up_pad,
down_pad, &col);
......@@ -93,7 +93,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]);
auto sequence_width = static_cast<int64_t>(in->dims()[1]);
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
......@@ -144,7 +144,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
padding_data = context.Input<Tensor>("PaddingData");
}
seq_project_functor(dev_ctx, *in, *padding_data, padding_trainable,
seq_project_functor(dev_ctx, *in, padding_data, padding_trainable,
context_start, context_length, context_stride, up_pad,
down_pad, &col);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册