未验证 提交 5e722245 编写于 作者: L Leo Guo 提交者: GitHub

Add unitest for set_value, set_value_grad. test=kunlun (#49773)

上级 5fd115f3
...@@ -143,7 +143,13 @@ void SetValueGradImpl(const Context& dev_ctx, ...@@ -143,7 +143,13 @@ void SetValueGradImpl(const Context& dev_ctx,
if (x_grad) { if (x_grad) {
// Set gradient of `Input` // Set gradient of `Input`
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); x_grad->Resize(out_grad.dims());
dev_ctx.template Alloc<T>(x_grad);
r = xpu::copy(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0)); DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
......
...@@ -119,8 +119,6 @@ void SetValueImpl(const Context& dev_ctx, ...@@ -119,8 +119,6 @@ void SetValueImpl(const Context& dev_ctx,
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
} }
auto place = dev_ctx.GetPlace();
// Here copy data from input to avoid data loss at PE and Graph level. // Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version. // TODO(liym27): Speed up in the future version.
// - Q: Why don't call ShareDataWith to speed up? // - Q: Why don't call ShareDataWith to speed up?
...@@ -132,7 +130,14 @@ void SetValueImpl(const Context& dev_ctx, ...@@ -132,7 +130,14 @@ void SetValueImpl(const Context& dev_ctx,
// be two ops points to the output in graph: op1 -> output <- set_value. // be two ops points to the output in graph: op1 -> output <- set_value.
// In this case, we have to find a way to handle the running order of // In this case, we have to find a way to handle the running order of
// set_value is what we want. // set_value is what we want.
Copy(dev_ctx, in, place, false, out); int r = XPU_SUCCESS;
out->Resize(in.dims());
dev_ctx.template Alloc<T>(out);
r = xpu::copy(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(in.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
in.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
DenseTensor slice_tensor = DenseTensor slice_tensor =
Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()}); Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
...@@ -179,7 +184,6 @@ void SetValueImpl(const Context& dev_ctx, ...@@ -179,7 +184,6 @@ void SetValueImpl(const Context& dev_ctx,
auto out_shape = phi::vectorize<int>(out->dims()); auto out_shape = phi::vectorize<int>(out->dims());
auto slice_shape = phi::vectorize<int>(slice_dims); auto slice_shape = phi::vectorize<int>(slice_dims);
int r = XPU_SUCCESS;
r = xpu::strided_slice(dev_ctx.x_context(), r = xpu::strided_slice(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out->data<T>()), reinterpret_cast<const XPUType*>(out->data<T>()),
reinterpret_cast<XPUType*>(slice_tensor.data<T>()), reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册