未验证 提交 f1711f24 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Fix custom op pinned input error (#41972)

* fix custom op pinned input error

* fix compile error
上级 d70104e5
......@@ -33,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/all.h"
......@@ -160,7 +161,18 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Input tensor (%s) is not initialized.", in_name));
paddle::experimental::Tensor custom_in;
custom_in.set_impl(std::make_shared<phi::DenseTensor>(*x));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (custom_in.is_gpu_pinned()) {
VLOG(3) << "Custom Operator: custom input is gpu pinned tensor";
auto gpu_place = phi::GPUPlace(platform::GetCurrentDeviceId());
auto custom_gpu_in = custom_in.copy_to(gpu_place, true);
kernel_ctx.EmplaceBackInput(std::move(custom_gpu_in));
} else {
kernel_ctx.EmplaceBackInput(std::move(custom_in));
}
#else
kernel_ctx.EmplaceBackInput(std::move(custom_in));
#endif
}
}
......
......@@ -20,6 +20,7 @@ import paddle
import paddle.static as static
import subprocess
import numpy as np
from paddle.vision.transforms import Compose, Normalize
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.fluid.framework import _test_eager_guard
......@@ -329,6 +330,33 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
"custom op dx grad: {},\n paddle api dx grad: {}".format(
dx_grad, pd_dx_grad))
def test_with_dataloader(self):
for device in self.devices:
paddle.set_device(device)
# data loader
transform = Compose(
[Normalize(
mean=[127.5], std=[127.5], data_format='CHW')])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=0)
for batch_id, (image, _) in enumerate(train_loader()):
out = self.custom_ops[0](image)
pd_out = paddle.nn.functional.relu(image)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
if batch_id == 5:
break
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册