未验证 提交 8630ba2e 编写于 作者: Q Qingsheng Li 提交者: GitHub

Fix sequence expand op (#11618)

* Set zero outside functor
上级 01fbcb0b
......@@ -151,9 +151,6 @@ struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
set_zero(context, dx, static_cast<T>(0));
int dout_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
......@@ -187,6 +184,10 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
g_x->mutable_data<T>(context.GetPlace());
g_x->set_lod(x->lod());
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, g_x, static_cast<T>(0));
auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
// just copy the gradient
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册