提交 7fb1f7a2 编写于 作者: D dangqingqing

Fix lstm_op and gru_op in debug mode.

上级 2523410e
...@@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>( ...@@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
template struct RowwiseAdd<platform::GPUPlace, float>; template struct RowwiseAdd<platform::GPUPlace, float>;
template struct RowwiseAdd<platform::GPUPlace, double>; template struct RowwiseAdd<platform::GPUPlace, double>;
template struct ColwiseSum<platform::GPUPlace, float>; template struct ColwiseSum<platform::GPUPlace, float>;
template struct ColwiseSum<platform::GPUPlace, double>; // template struct ColwiseSum<platform::GPUPlace, double>;
// The ColwiseSum<platform::GPUPlace, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
template <>
void ColwiseSum<platform::GPUPlace, double>::operator()(
const platform::DeviceContext& context, const framework::Tensor& input,
framework::Tensor* vector) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), size);
framework::Tensor one;
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
SetConstant<platform::GPUPlace, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::GPUPlace, double>(context, true, static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]), 1.0,
input.data<double>(), one.data<double>(),
0.0, vector->data<double>());
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册