提交 7333df85 编写于 作者: Q QI JUN 提交者: dzhwinter

fix pool_op bug (#7879)

上级 69a8438e
...@@ -139,10 +139,8 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -139,10 +139,8 @@ class PoolGradKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad); paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
temp.device( set_constant(dev_ctx, in_x_grad, 0.0);
*context.template device_context<DeviceContext>().eigen_device()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册