Unverified commit 39eb77a6 authored by HongyuJia, committed by GitHub

[Cherry-pick] Fix custom operator backward=None (#48656) (#48715)

* [Release2.4] Revert python link prs (#48573)

* Revert "Fix mac link python (#48017)"

This reverts commit 3fa7a736.

* Revert "[Cherry-pick] Fix python link error (#47811)"

This reverts commit ff642c68.

* Update config.go

* fix custom operator backward=None (#48656)

* [Custom Extension] Fix custom double_grad backward=None (#49224)

* fix custom double_grad backward=None

* fix custom_relu.cu bug && polish testcase of double_grad

* remove old dynamic graph test

* add import fluid

* add import fluid
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent 5d29a5bf
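For context, the double-grad path that this commit repairs is exercised by calling `paddle.grad` twice, exactly as the updated `custom_relu_double_grad_dynamic` helper in the test diff below does. The following is a minimal sketch of that flow; `paddle.nn.functional.relu` stands in for the compiled `custom_relu` operator so the snippet runs without building the extension.

```python
import numpy as np
import paddle
from paddle import fluid

# Needed so that `out.grad` is retained for the intermediate tensor,
# as in the updated test_double_grad_dynamic below.
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})

# `paddle.nn.functional.relu` stands in for the compiled `custom_relu` op.
func = paddle.nn.functional.relu

np_x = np.random.uniform(-1, 1, [4, 8]).astype('float32')
t = paddle.to_tensor(np_x, dtype='float32', stop_gradient=False)
out = func(t)
out.stop_gradient = False

# First-order gradient, keeping the graph so it can be differentiated again.
dx = paddle.grad(
    outputs=out,
    inputs=t,
    grad_outputs=paddle.ones_like(t),
    create_graph=True,
    retain_graph=True,
)
dx[0].backward()

# Second-order (double) gradient; before this fix, the custom-op grad node
# could mis-index its temporary grad outputs and yield a null tensor here.
ddout = paddle.grad(
    outputs=dx[0],
    inputs=out.grad,
    grad_outputs=paddle.ones_like(t),
    create_graph=False,
)
print(dx[0].numpy(), ddout[0].numpy())

fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
```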
......@@ -217,18 +217,20 @@ RunCustomOpNode::operator()(
VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size();
for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[0][0].find(i) != map[0][0].end()) {
int grad_output_idx = map[0][0][i];
VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size()
<< " to tmp_outputs: " << map[0][0][i];
for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
egr::EagerUtils::autograd_meta(&(outs[i][j]));
<< " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[grad_output_idx]
.emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
egr::EagerUtils::autograd_meta(&(outs[grad_output_idx][j]));
}
tmp_outs[map[0][0][i]] = outs[i];
tmp_outs[grad_output_idx] = outs[grad_output_idx];
}
}
for (size_t i = 0; i < tmp_outs.size(); i++) {
......@@ -408,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(
for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[1][0].find(i) != map[1][0].end()) {
int grad_output_idx = map[1][0][i];
VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size()
<< " to tmp_outputs: " << map[1][0][i];
for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
<< " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[grad_output_idx]
.emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
}
tmp_outs[map[1][0][i]] = outs[i];
tmp_outs[grad_output_idx] = outs[grad_output_idx];
}
}
for (size_t i = 0; i < tmp_outs.size(); i++) {
......
......@@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
data_t* ddout_data,
int64_t num) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
? static_cast<data_t>(1.)
: static_cast<data_t>(0.));
......
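The kernel change above fixes a grid-stride loop that previously started at `num` instead of the thread's global index `gid`, so the loop body never executed and `ddout` was left unwritten. Per element, the kernel simply gates the incoming `ddx` by the sign of the forward output; a small NumPy sketch of that reference computation (illustrative only):

```python
import numpy as np

def relu_double_backward_ref(out: np.ndarray, ddx: np.ndarray) -> np.ndarray:
    # ddout[i] = ddx[i] * (1 if out[i] > 0 else 0), matching the CUDA kernel.
    return ddx * (out > 0).astype(ddx.dtype)

# Example: the perturbation ddx passes through only where relu was active.
x = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
out = np.maximum(x, 0.0)          # forward relu output
ddx = np.ones_like(x)             # incoming second-order perturbation
print(relu_double_backward_ref(out, ddx))  # -> [0. 1. 1.]
```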
......@@ -21,6 +21,7 @@ import paddle.static as static
import tempfile
import subprocess
import numpy as np
from paddle import fluid
from paddle.vision.transforms import Compose, Normalize
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.fluid.framework import _test_eager_guard
......@@ -43,12 +44,9 @@ def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
return out.numpy(), t.grad.numpy()
def custom_relu_static(func,
device,
dtype,
np_x,
use_func=True,
test_infer=False):
def custom_relu_static(
func, device, dtype, np_x, use_func=True, test_infer=False
):
paddle.enable_static()
paddle.set_device(device)
......@@ -62,9 +60,11 @@ def custom_relu_static(func,
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x data has been covered by out
out_v = exe.run(static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name])
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
......@@ -87,11 +87,11 @@ def custom_relu_static_pe(func, device, dtype, np_x, use_func=True):
# in static mode, x data has been covered by out
compiled_prog = static.CompiledProgram(
static.default_main_program()).with_data_parallel(
loss_name=out.name, places=places)
out_v = exe.run(compiled_prog,
feed={'X': np_x},
fetch_list=[out.name])
static.default_main_program()
).with_data_parallel(loss_name=out.name, places=places)
out_v = exe.run(
compiled_prog, feed={'X': np_x}, fetch_list=[out.name]
)
paddle.disable_static()
return out_v
......@@ -103,9 +103,9 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
# simple module
data = static.data(name='data',
shape=[None, 1, 28, 28],
dtype='float32')
data = static.data(
name='data', shape=[None, 1, 28, 28], dtype='float32'
)
label = static.data(name='label', shape=[None, 1], dtype='int64')
hidden = static.nn.fc(data, size=128)
......@@ -124,23 +124,21 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
# train
for i in range(4):
avg_loss_v = exe.run(static.default_main_program(),
feed={
'data': np_data,
'label': np_label
},
fetch_list=[avg_loss])
avg_loss_v = exe.run(
static.default_main_program(),
feed={'data': np_data, 'label': np_label},
fetch_list=[avg_loss],
)
# save inference model
static.save_inference_model(path_prefix, [data], [predict], exe)
# get train predict value
predict_v = exe.run(static.default_main_program(),
feed={
'data': np_data,
'label': np_label
},
fetch_list=[predict])
predict_v = exe.run(
static.default_main_program(),
feed={'data': np_data, 'label': np_label},
fetch_list=[predict],
)
return predict_v
......@@ -151,30 +149,37 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
dx = paddle.grad(outputs=[out],
inputs=[t],
create_graph=True,
retain_graph=True)
dx = paddle.grad(
outputs=out,
inputs=t,
grad_outputs=paddle.ones_like(t),
create_graph=True,
retain_graph=True,
)
dx[0].backward()
ddout = paddle.grad(
outputs=dx[0],
inputs=out.grad,
grad_outputs=paddle.ones_like(t),
create_graph=False,
)
assert dx[0].grad is not None
return dx[0].numpy(), dx[0].grad.numpy()
assert ddout[0].numpy() is not None
return dx[0].numpy(), ddout[0].numpy()
class TestNewCustomOpSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile, install the custom op egg into site-packages under background
if os.name == 'nt':
cmd = 'cd /d {} && python custom_relu_setup.py install'.format(
cur_dir)
cur_dir
)
else:
cmd = 'cd {} && {} custom_relu_setup.py install'.format(
cur_dir, sys.executable)
cur_dir, sys.executable
)
run_cmd(cmd)
# NOTE(Aurelius84): Normally, it's no need to add following codes for users.
......@@ -190,16 +195,18 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
custom_egg_path = [
x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x
]
assert len(custom_egg_path
) == 1, "Matched egg number is %d." % len(custom_egg_path)
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path
)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
# usage: import the package directly
import custom_relu_module_setup
# `custom_relu_dup` is same as `custom_relu`
self.custom_ops = [
custom_relu_module_setup.custom_relu,
custom_relu_module_setup.custom_relu_dup
custom_relu_module_setup.custom_relu_dup,
]
self.dtypes = ['float32', 'float64']
......@@ -222,13 +229,16 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops:
out = custom_relu_static(custom_op, device, dtype, x)
pd_out = custom_relu_static(custom_op, device, dtype, x,
False)
pd_out = custom_relu_static(
custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.
format(out, pd_out))
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
def test_static_pe(self):
for device in self.devices:
......@@ -238,13 +248,16 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops:
out = custom_relu_static_pe(custom_op, device, dtype, x)
pd_out = custom_relu_static_pe(custom_op, device, dtype, x,
False)
pd_out = custom_relu_static_pe(
custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.
format(out, pd_out))
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
def func_dynamic(self):
for device in self.devices:
......@@ -253,20 +266,26 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
for custom_op in self.custom_ops:
out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
x)
out, x_grad = custom_relu_dynamic(
custom_op, device, dtype, x
)
pd_out, pd_x_grad = custom_relu_dynamic(
custom_op, device, dtype, x, False)
custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.
format(out, pd_out))
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
np.testing.assert_array_equal(
x_grad,
pd_x_grad,
err_msg='custom op x grad: {},\n paddle api x grad: {}'.
format(x_grad, pd_x_grad))
err_msg='custom op x grad: {},\n paddle api x grad: {}'.format(
x_grad, pd_x_grad
),
)
def test_dynamic(self):
with _test_eager_guard():
......@@ -279,22 +298,29 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
np_label = np.random.random((1, 1)).astype("int64")
path_prefix = "custom_op_inference/custom_relu"
for device in self.devices:
predict = custom_relu_static_inference(self.custom_ops[0], device,
np_data, np_label,
path_prefix)
predict = custom_relu_static_inference(
self.custom_ops[0], device, np_data, np_label, path_prefix
)
# load inference model
with static.scope_guard(static.Scope()):
exe = static.Executor()
[inference_program, feed_target_names,
fetch_targets] = static.load_inference_model(path_prefix, exe)
predict_infer = exe.run(inference_program,
feed={feed_target_names[0]: np_data},
fetch_list=fetch_targets)
[
inference_program,
feed_target_names,
fetch_targets,
] = static.load_inference_model(path_prefix, exe)
predict_infer = exe.run(
inference_program,
feed={feed_target_names[0]: np_data},
fetch_list=fetch_targets,
)
np.testing.assert_array_equal(
predict,
predict_infer,
err_msg='custom op predict: {},\n custom op infer predict: {}'
.format(predict, predict_infer))
err_msg='custom op predict: {},\n custom op infer predict: {}'.format(
predict, predict_infer
),
)
paddle.disable_static()
def test_static_save_and_run_inference_predictor(self):
......@@ -304,62 +330,80 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
path_prefix = "custom_op_inference/custom_relu"
from paddle.inference import Config
from paddle.inference import create_predictor
for device in self.devices:
predict = custom_relu_static_inference(self.custom_ops[0], device,
np_data, np_label,
path_prefix)
predict = custom_relu_static_inference(
self.custom_ops[0], device, np_data, np_label, path_prefix
)
# load inference model
config = Config(path_prefix + ".pdmodel",
path_prefix + ".pdiparams")
config = Config(
path_prefix + ".pdmodel", path_prefix + ".pdiparams"
)
predictor = create_predictor(config)
input_tensor = predictor.get_input_handle(
predictor.get_input_names()[0])
predictor.get_input_names()[0]
)
input_tensor.reshape(np_data.shape)
input_tensor.copy_from_cpu(np_data.copy())
predictor.run()
output_tensor = predictor.get_output_handle(
predictor.get_output_names()[0])
predictor.get_output_names()[0]
)
predict_infer = output_tensor.copy_to_cpu()
self.assertTrue(
np.isclose(predict, predict_infer, rtol=5e-5).any(),
"custom op predict: {},\n custom op infer predict: {}".format(
predict, predict_infer))
predict, predict_infer
),
)
paddle.disable_static()
def test_func_double_grad_dynamic(self):
def test_double_grad_dynamic(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
continue
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, dx_grad = custom_relu_double_grad_dynamic(
self.custom_ops[0], device, dtype, x)
self.custom_ops[0], device, dtype, x
)
pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
self.custom_ops[0], device, dtype, x, False)
self.custom_ops[0], device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out))
out, pd_out
),
)
np.testing.assert_array_equal(
dx_grad,
pd_dx_grad,
err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.
format(dx_grad, pd_dx_grad))
err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format(
dx_grad, pd_dx_grad
),
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_with_dataloader(self):
for device in self.devices:
paddle.set_device(device)
# data loader
transform = Compose(
[Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
train_dataset = paddle.vision.datasets.MNIST(mode='train',
transform=transform)
train_loader = paddle.io.DataLoader(train_dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=0)
[Normalize(mean=[127.5], std=[127.5], data_format='CHW')]
)
train_dataset = paddle.vision.datasets.MNIST(
mode='train', transform=transform
)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=0,
)
for batch_id, (image, _) in enumerate(train_loader()):
out = self.custom_ops[0](image)
......@@ -368,7 +412,9 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out))
out, pd_out
),
)
if batch_id == 5:
break
......