Unverified commit 4f834cb2, authored by zhangyikun02, committed by GitHub

change d2d copy to api copy in xpu kernel, test=kunlun (#48505)

Parent bc01d56e
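Note: the kernel hunks below replace the generic device-to-device Copy helper with a direct call to the xpu::copy primitive on the current XPU context, checking its return code with PADDLE_ENFORCE_XDNN_SUCCESS. A minimal, fragment-level sketch of that pattern (not compilable on its own; it assumes the surrounding kernel already provides dev_ctx, dz_data, dx_data and dx as in the diff, with XPUType being the device-side type mapped from T):

// Sketch of the replacement applied in both AddGradKernel branches below.
if (dx_data != dz_data) {
  int ret = xpu::copy(dev_ctx.x_context(),
                      reinterpret_cast<const XPUType*>(dz_data),
                      reinterpret_cast<XPUType*>(dx_data),
                      dx->numel());
  // Raises a Paddle error if the XDNN call returned a nonzero status.
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
}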
@@ -45,7 +45,11 @@ void AddGradKernel(const Context& dev_ctx,
    T* dx_data = dev_ctx.template Alloc<T>(dx);
    if (dx->dims() == dz_dims) {
      if (dx_data != dz_data) {
-       Copy(dev_ctx, *dz, dev_ctx.GetPlace(), false, dx);
+       int ret = xpu::copy(dev_ctx.x_context(),
+                           reinterpret_cast<const XPUType*>(dz_data),
+                           reinterpret_cast<XPUType*>(dx->data<T>()),
+                           dx->numel());
+       PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dz, which makes
@@ -73,7 +77,11 @@ void AddGradKernel(const Context& dev_ctx,
    T* dy_data = dy->mutable_data<T>(dev_ctx.GetPlace());
    if (dy->dims() == dz_dims) {
      if (dy_data != dz_data) {
-       Copy(dev_ctx, *dz, dev_ctx.GetPlace(), false, dy);
+       int ret = xpu::copy(dev_ctx.x_context(),
+                           reinterpret_cast<const XPUType*>(dz_data),
+                           reinterpret_cast<XPUType*>(dy->data<T>()),
+                           dy->numel());
+       PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
      }
    } else {
      std::vector<int> reduce_dims =
......
@@ -68,6 +68,15 @@ void WarpctcKernel(const Context& dev_ctx,
                        "but received %d. ",
                        sequence_width));
+  int lm_workspace = (max_sequence_length + 1) *
+                         (2 * max_target_seq_length + sequence_width + 1) *
+                         sizeof(T) +
+                     (7 * max_target_seq_length + 3) * sizeof(int);
+  PADDLE_ENFORCE_LE(lm_workspace,
+                    6144,
+                    phi::errors::InvalidArgument(
+                        "Input size is too large for xpu in warpctc kernel"));
  loss->Resize(phi::make_ddim({num_sequences, 1}));
  dev_ctx.template Alloc<T>(loss);
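The added guard estimates the local-memory workspace the XPU warpctc path will need and compares it against a 6144-byte budget. A small standalone sketch of the same arithmetic, using illustrative sizes that are not taken from the diff and assuming T is float (4-byte float and int):

// Standalone illustration of the workspace bound checked above; sizes are made up.
#include <cstdio>

int main() {
  const int max_sequence_length = 5;    // example value only
  const int max_target_seq_length = 3;  // example value only
  const int sequence_width = 6;         // example value only
  const size_t lm_workspace =
      (max_sequence_length + 1) *
          (2 * max_target_seq_length + sequence_width + 1) * sizeof(float) +
      (7 * max_target_seq_length + 3) * sizeof(int);
  // For these sizes: 6 * 13 * 4 + 24 * 4 = 312 + 96 = 408 bytes, well under 6144.
  std::printf("lm_workspace = %zu bytes (limit 6144)\n", lm_workspace);
  return lm_workspace <= 6144 ? 0 : 1;
}

In the kernel the same quantity feeds PADDLE_ENFORCE_LE, so inputs that would exceed the budget are rejected up front with an InvalidArgument error.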
......
@@ -239,7 +239,6 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper):
        logits = np.random.uniform(
            0.1, 1.0, [sum(self.logits_length), self.num_classes]
        ).astype(self.dtype)
-       print("logits.shape = ", logits.shape)
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        # labels should not be blank
        labels = np.random.randint(
@@ -416,7 +415,11 @@ class XPUTestWarpCTCOp(XPUOpTestWrapper):
            labels = paddle.to_tensor(labels)
            paddle.nn.functional.ctc_loss(
-               log_probs=softmax, labels=labels, reduction='none'
+               log_probs=softmax,
+               labels=labels,
+               input_lengths=None,
+               label_lengths=None,
+               reduction='none',
            )
            paddle.disable_static()
......