未验证 提交 02346930 编写于 作者: L Leo Chen 提交者: GitHub

fix gather_grad bug (#31607)

上级 5e851bff
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/platform/npu_info.h"
namespace paddle {
namespace operators {
......@@ -65,20 +66,17 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
.stream();
// step2: ZerosLike x in device
Tensor *tmp_zerox = const_cast<Tensor *>(x);
Tensor zeroslike_xout(x->type());
zeroslike_xout.Resize(x->dims());
zeroslike_xout.mutable_data<T>(ctx.GetPlace());
auto p = zeroslike_xout.mutable_data<T>(ctx.GetPlace());
auto runner_zeroslike =
NpuOpRunner("ZerosLike", {*x}, {zeroslike_xout}, {});
runner_zeroslike.Run(stream);
tmp_zerox = &zeroslike_xout;
platform::NPUMemsetAsync(static_cast<void *>(p), 0,
zeroslike_xout.numel() * sizeof(T), stream);
// step3: scatter(x_grad)
dx->mutable_data<T>(ctx.GetPlace());
auto runner_scatter = NpuOpRunner("TensorScatterUpdate",
{*tmp_zerox, *index, *dout}, {*dx}, {});
auto runner_scatter = NpuOpRunner(
"TensorScatterUpdate", {zeroslike_xout, *index, *dout}, {*dx}, {});
runner_scatter.Run(stream);
}
};
......
......@@ -23,6 +23,7 @@ import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(),
......@@ -105,5 +106,58 @@ class TestGatherAPI(unittest.TestCase):
pass
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestPowNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(8192, 768)).astype('float32')
index_np = np.random.randint(0, 8192, size=(1232, 1)).astype('int32')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[8192, 768], dtype='float32')
index = paddle.static.data(
name="index", shape=[1232, 1], dtype='int32')
a.stop_gradient = False
b = paddle.gather(a, index)
loss = fluid.layers.reduce_mean(b)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={"a": a_np,
"index": index_np},
fetch_list=[b, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res[0]))
return pred_res, loss_res
def test_npu(self):
npu_pred, npu_loss = self._test(True)
cpu_pred, cpu_loss = self._test(False)
self.assertTrue(np.allclose(npu_pred, cpu_pred))
self.assertTrue(np.allclose(npu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册