Unverified commit 39eb77a6, authored by HongyuJia and committed by GitHub

[Cherry-pick] Fix custom operator backward=None (#48656) (#48715)

* [Release2.4] Revert python link prs (#48573)

* Revert "Fix mac link python (#48017)"

This reverts commit 3fa7a736.

* Revert "[Cherry-pick] Fix python link error (#47811)"

This reverts commit ff642c68.

* Update config.go

* fix custom operator backward=None (#48656)

* [Custom Extension] Fix custom double_grad backward=None (#49224)

* fix custom double_grad backward=None

* fix custom_relu.cu bug && polish testcase of double_grad

* remove old dynamic graph test

* add import fluid

* add import fluid
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent 5d29a5bf
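Context for reviewers: the dynamic-graph double-grad path this cherry-pick fixes is exercised by the reworked `custom_relu_double_grad_dynamic` helper in the test diff below. The following is a condensed, self-contained sketch of that call pattern, lifted from the new test code, with `paddle.nn.functional.relu` standing in for the compiled custom operator (the test itself does the same when `use_func=False`). It is illustrative only and not part of this commit:

```python
import numpy as np
import paddle
from paddle import fluid

def relu_double_grad(func, np_x, dtype='float32'):
    # Mirrors custom_relu_double_grad_dynamic in the updated test below;
    # `func` may be the compiled custom_relu op or paddle.nn.functional.relu.
    t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
    out = func(t)
    # First-order gradient; keep the graph so it can be differentiated again.
    dx = paddle.grad(
        outputs=out,
        inputs=t,
        grad_outputs=paddle.ones_like(t),
        create_graph=True,
        retain_graph=True,
    )
    # Second-order gradient; before this fix, grad slots left as None by the
    # custom backward could break this step.
    ddout = paddle.grad(
        outputs=dx[0],
        inputs=out.grad,
        grad_outputs=paddle.ones_like(t),
        create_graph=False,
    )
    return dx[0].numpy(), ddout[0].numpy()

if __name__ == '__main__':
    # The test wraps the call with this flag so intermediate grads are kept.
    fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
    x = np.random.uniform(-1, 1, [4, 8]).astype('float32')
    dx, ddout = relu_double_grad(paddle.nn.functional.relu, x)
    fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
    print(dx.shape, ddout.shape)
```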
@@ -217,18 +217,20 @@ RunCustomOpNode::operator()(
   VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size();
   for (size_t i = 0; i < OutputMeta().size(); i++) {
     if (map[0][0].find(i) != map[0][0].end()) {
+      int grad_output_idx = map[0][0][i];
       VLOG(7) << "Insert grad outputs: " << i
-              << " with size: " << OutputMeta()[i].size()
-              << " to tmp_outputs: " << map[0][0][i];
-      for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
-        outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
-                             std::make_shared<phi::DenseTensor>(
-                                 phi::DataType::UNDEFINED),
-                             egr::Controller::Instance().GenerateUniqueName(
-                                 "custom_tmp_grad"));
-        egr::EagerUtils::autograd_meta(&(outs[i][j]));
+              << " with size: " << OutputMeta()[grad_output_idx].size()
+              << " to tmp_outputs: " << grad_output_idx;
+      for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
+        outs[grad_output_idx]
+            .emplace_back(/* init it incase of copy nullptr of shared_ptr */
+                          std::make_shared<phi::DenseTensor>(
+                              phi::DataType::UNDEFINED),
+                          egr::Controller::Instance().GenerateUniqueName(
+                              "custom_tmp_grad"));
+        egr::EagerUtils::autograd_meta(&(outs[grad_output_idx][j]));
       }
-      tmp_outs[map[0][0][i]] = outs[i];
+      tmp_outs[grad_output_idx] = outs[grad_output_idx];
     }
   }
   for (size_t i = 0; i < tmp_outs.size(); i++) {
@@ -408,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(
   for (size_t i = 0; i < OutputMeta().size(); i++) {
     if (map[1][0].find(i) != map[1][0].end()) {
+      int grad_output_idx = map[1][0][i];
       VLOG(7) << "Insert grad outputs: " << i
-              << " with size: " << OutputMeta()[i].size()
-              << " to tmp_outputs: " << map[1][0][i];
-      for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
-        outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
-                             std::make_shared<phi::DenseTensor>(
-                                 phi::DataType::UNDEFINED),
-                             egr::Controller::Instance().GenerateUniqueName(
-                                 "custom_tmp_grad"));
+              << " with size: " << OutputMeta()[grad_output_idx].size()
+              << " to tmp_outputs: " << grad_output_idx;
+      for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
+        outs[grad_output_idx]
+            .emplace_back(/* init it incase of copy nullptr of shared_ptr */
+                          std::make_shared<phi::DenseTensor>(
+                              phi::DataType::UNDEFINED),
+                          egr::Controller::Instance().GenerateUniqueName(
+                              "custom_tmp_grad"));
       }
-      tmp_outs[map[1][0][i]] = outs[i];
+      tmp_outs[grad_output_idx] = outs[grad_output_idx];
     }
   }
   for (size_t i = 0; i < tmp_outs.size(); i++) {
......
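The gist of the two C++ hunks above is that the temporary grad-output buffers are now sized and addressed through the mapped slot index (`grad_output_idx`) rather than the raw loop index `i`. A toy, Paddle-free Python sketch of why that distinction matters once a backward op leaves some grad outputs unset; the mapping and sizes below are made up for illustration and do not model the real node's data structures:

```python
# Toy illustration (not Paddle code): when some grad outputs are skipped
# (the backward returned None for them), the surviving outputs must be
# placed and sized by their mapped slot index, not by the iteration index.
grad_slot_of = {0: 0, 3: 1, 4: 2}   # forward positions 1 and 2 produce no grad
sizes_by_slot = [2, 1, 3]           # how many tensors each surviving slot holds

tmp_outs = {}
for pos in range(5):                 # iterate over all forward positions
    if pos not in grad_slot_of:
        continue                     # nothing to allocate for this position
    slot = grad_slot_of[pos]
    # Using `pos` here instead of `slot` would read the wrong size (or run
    # past the end of sizes_by_slot), the class of bug fixed above.
    tmp_outs[slot] = ["tmp_grad"] * sizes_by_slot[slot]

print(tmp_outs)
# {0: ['tmp_grad', 'tmp_grad'], 1: ['tmp_grad'], 2: ['tmp_grad', 'tmp_grad', 'tmp_grad']}
```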
@@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
                                                  data_t* ddout_data,
                                                  int64_t num) {
   int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
                                        ? static_cast<data_t>(1.)
                                        : static_cast<data_t>(0.));
......
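The single-token change above is the real bug fix in the CUDA kernel: the grid-stride loop previously started at `num` instead of the thread's global id `gid`, so the loop condition failed immediately, the body never ran, and `ddout_data` was never written. A plain-Python sketch of the indexing (illustrative only, not Paddle code):

```python
# CPU-side sketch of the CUDA grid-stride loop used by the kernel above.
def grid_stride_indices(block_idx, thread_idx, block_dim, grid_dim, num):
    gid = block_idx * block_dim + thread_idx                 # global thread id
    return list(range(gid, num, block_dim * grid_dim))       # i = gid; i < num; i += stride

# Fixed loop: thread (block 0, thread 1) with 4 blocks of 32 threads covers
# elements 1, 129, 257, ... up to num.
print(grid_stride_indices(0, 1, 32, 4, 300))   # [1, 129, 257]

# Buggy loop started at `num` instead of `gid`, so every thread saw an
# empty range and nothing was computed:
print(list(range(300, 300, 32 * 4)))           # []
```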
@@ -21,6 +21,7 @@ import paddle.static as static
 import tempfile
 import subprocess
 import numpy as np
+from paddle import fluid
 from paddle.vision.transforms import Compose, Normalize
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 from paddle.fluid.framework import _test_eager_guard
@@ -43,12 +44,9 @@ def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
     return out.numpy(), t.grad.numpy()


-def custom_relu_static(func,
-                       device,
-                       dtype,
-                       np_x,
-                       use_func=True,
-                       test_infer=False):
+def custom_relu_static(
+    func, device, dtype, np_x, use_func=True, test_infer=False
+):
     paddle.enable_static()
     paddle.set_device(device)
@@ -62,9 +60,11 @@ def custom_relu_static(func,
             exe = static.Executor()
             exe.run(static.default_startup_program())
             # in static mode, x data has been covered by out
-            out_v = exe.run(static.default_main_program(),
-                            feed={'X': np_x},
-                            fetch_list=[out.name])
+            out_v = exe.run(
+                static.default_main_program(),
+                feed={'X': np_x},
+                fetch_list=[out.name],
+            )

     paddle.disable_static()
     return out_v
@@ -87,11 +87,11 @@ def custom_relu_static_pe(func, device, dtype, np_x, use_func=True):
             # in static mode, x data has been covered by out
             compiled_prog = static.CompiledProgram(
-                static.default_main_program()).with_data_parallel(
-                    loss_name=out.name, places=places)
-            out_v = exe.run(compiled_prog,
-                            feed={'X': np_x},
-                            fetch_list=[out.name])
+                static.default_main_program()
+            ).with_data_parallel(loss_name=out.name, places=places)
+            out_v = exe.run(
+                compiled_prog, feed={'X': np_x}, fetch_list=[out.name]
+            )

     paddle.disable_static()
     return out_v
@@ -103,9 +103,9 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
    with static.scope_guard(static.Scope()):
        with static.program_guard(static.Program()):
            # simple module
-            data = static.data(name='data',
-                               shape=[None, 1, 28, 28],
-                               dtype='float32')
+            data = static.data(
+                name='data', shape=[None, 1, 28, 28], dtype='float32'
+            )
            label = static.data(name='label', shape=[None, 1], dtype='int64')

            hidden = static.nn.fc(data, size=128)
@@ -124,23 +124,21 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
            # train
            for i in range(4):
-                avg_loss_v = exe.run(static.default_main_program(),
-                                     feed={
-                                         'data': np_data,
-                                         'label': np_label
-                                     },
-                                     fetch_list=[avg_loss])
+                avg_loss_v = exe.run(
+                    static.default_main_program(),
+                    feed={'data': np_data, 'label': np_label},
+                    fetch_list=[avg_loss],
+                )

            # save inference model
            static.save_inference_model(path_prefix, [data], [predict], exe)

            # get train predict value
-            predict_v = exe.run(static.default_main_program(),
-                                feed={
-                                    'data': np_data,
-                                    'label': np_label
-                                },
-                                fetch_list=[predict])
+            predict_v = exe.run(
+                static.default_main_program(),
+                feed={'data': np_data, 'label': np_label},
+                fetch_list=[predict],
+            )

    return predict_v
@@ -151,30 +149,37 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
    t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
    out = func(t) if use_func else paddle.nn.functional.relu(t)
-    out.stop_gradient = False
-
-    dx = paddle.grad(outputs=[out],
-                     inputs=[t],
-                     create_graph=True,
-                     retain_graph=True)
-
-    dx[0].backward()
-
-    assert dx[0].grad is not None
-    return dx[0].numpy(), dx[0].grad.numpy()
+    dx = paddle.grad(
+        outputs=out,
+        inputs=t,
+        grad_outputs=paddle.ones_like(t),
+        create_graph=True,
+        retain_graph=True,
+    )
+
+    ddout = paddle.grad(
+        outputs=dx[0],
+        inputs=out.grad,
+        grad_outputs=paddle.ones_like(t),
+        create_graph=False,
+    )
+
+    assert ddout[0].numpy() is not None
+    return dx[0].numpy(), ddout[0].numpy()

class TestNewCustomOpSetUpInstall(unittest.TestCase):
    def setUp(self):
        cur_dir = os.path.dirname(os.path.abspath(__file__))
        # compile, install the custom op egg into site-packages under background
        if os.name == 'nt':
            cmd = 'cd /d {} && python custom_relu_setup.py install'.format(
-                cur_dir)
+                cur_dir
+            )
        else:
            cmd = 'cd {} && {} custom_relu_setup.py install'.format(
-                cur_dir, sys.executable)
+                cur_dir, sys.executable
+            )
        run_cmd(cmd)

        # NOTE(Aurelius84): Normally, it's no need to add following codes for users.
@@ -190,16 +195,18 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
        custom_egg_path = [
            x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x
        ]
-        assert len(custom_egg_path
-                   ) == 1, "Matched egg number is %d." % len(custom_egg_path)
+        assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
+            custom_egg_path
+        )
        sys.path.append(os.path.join(site_dir, custom_egg_path[0]))

        # usage: import the package directly
        import custom_relu_module_setup

        # `custom_relu_dup` is same as `custom_relu_dup`
        self.custom_ops = [
            custom_relu_module_setup.custom_relu,
-            custom_relu_module_setup.custom_relu_dup
+            custom_relu_module_setup.custom_relu_dup,
        ]

        self.dtypes = ['float32', 'float64']
@@ -222,13 +229,16 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out = custom_relu_static(custom_op, device, dtype, x)
-                    pd_out = custom_relu_static(custom_op, device, dtype, x,
-                                                False)
+                    pd_out = custom_relu_static(
+                        custom_op, device, dtype, x, False
+                    )
                    np.testing.assert_array_equal(
                        out,
                        pd_out,
-                        err_msg='custom op out: {},\n paddle api out: {}'.
-                        format(out, pd_out))
+                        err_msg='custom op out: {},\n paddle api out: {}'.format(
+                            out, pd_out
+                        ),
+                    )

    def test_static_pe(self):
        for device in self.devices:
@@ -238,13 +248,16 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out = custom_relu_static_pe(custom_op, device, dtype, x)
-                    pd_out = custom_relu_static_pe(custom_op, device, dtype, x,
-                                                   False)
+                    pd_out = custom_relu_static_pe(
+                        custom_op, device, dtype, x, False
+                    )
                    np.testing.assert_array_equal(
                        out,
                        pd_out,
-                        err_msg='custom op out: {},\n paddle api out: {}'.
-                        format(out, pd_out))
+                        err_msg='custom op out: {},\n paddle api out: {}'.format(
+                            out, pd_out
+                        ),
+                    )

    def func_dynamic(self):
        for device in self.devices:
@@ -253,20 +266,26 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                    continue
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
-                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
-                                                      x)
+                    out, x_grad = custom_relu_dynamic(
+                        custom_op, device, dtype, x
+                    )
                    pd_out, pd_x_grad = custom_relu_dynamic(
-                        custom_op, device, dtype, x, False)
+                        custom_op, device, dtype, x, False
+                    )
                    np.testing.assert_array_equal(
                        out,
                        pd_out,
-                        err_msg='custom op out: {},\n paddle api out: {}'.
-                        format(out, pd_out))
+                        err_msg='custom op out: {},\n paddle api out: {}'.format(
+                            out, pd_out
+                        ),
+                    )
                    np.testing.assert_array_equal(
                        x_grad,
                        pd_x_grad,
-                        err_msg='custom op x grad: {},\n paddle api x grad: {}'.
-                        format(x_grad, pd_x_grad))
+                        err_msg='custom op x grad: {},\n paddle api x grad: {}'.format(
+                            x_grad, pd_x_grad
+                        ),
+                    )

    def test_dynamic(self):
        with _test_eager_guard():
@@ -279,22 +298,29 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
        np_label = np.random.random((1, 1)).astype("int64")
        path_prefix = "custom_op_inference/custom_relu"
        for device in self.devices:
-            predict = custom_relu_static_inference(self.custom_ops[0], device,
-                                                   np_data, np_label,
-                                                   path_prefix)
+            predict = custom_relu_static_inference(
+                self.custom_ops[0], device, np_data, np_label, path_prefix
+            )
            # load inference model
            with static.scope_guard(static.Scope()):
                exe = static.Executor()
-                [inference_program, feed_target_names,
-                 fetch_targets] = static.load_inference_model(path_prefix, exe)
-                predict_infer = exe.run(inference_program,
-                                        feed={feed_target_names[0]: np_data},
-                                        fetch_list=fetch_targets)
+                [
+                    inference_program,
+                    feed_target_names,
+                    fetch_targets,
+                ] = static.load_inference_model(path_prefix, exe)
+                predict_infer = exe.run(
+                    inference_program,
+                    feed={feed_target_names[0]: np_data},
+                    fetch_list=fetch_targets,
+                )
                np.testing.assert_array_equal(
                    predict,
                    predict_infer,
-                    err_msg='custom op predict: {},\n custom op infer predict: {}'
-                    .format(predict, predict_infer))
+                    err_msg='custom op predict: {},\n custom op infer predict: {}'.format(
+                        predict, predict_infer
+                    ),
+                )
        paddle.disable_static()

    def test_static_save_and_run_inference_predictor(self):
@@ -304,62 +330,80 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
        path_prefix = "custom_op_inference/custom_relu"
        from paddle.inference import Config
        from paddle.inference import create_predictor

        for device in self.devices:
-            predict = custom_relu_static_inference(self.custom_ops[0], device,
-                                                   np_data, np_label,
-                                                   path_prefix)
+            predict = custom_relu_static_inference(
+                self.custom_ops[0], device, np_data, np_label, path_prefix
+            )
            # load inference model
-            config = Config(path_prefix + ".pdmodel",
-                            path_prefix + ".pdiparams")
+            config = Config(
+                path_prefix + ".pdmodel", path_prefix + ".pdiparams"
+            )
            predictor = create_predictor(config)
            input_tensor = predictor.get_input_handle(
-                predictor.get_input_names()[0])
+                predictor.get_input_names()[0]
+            )
            input_tensor.reshape(np_data.shape)
            input_tensor.copy_from_cpu(np_data.copy())
            predictor.run()
            output_tensor = predictor.get_output_handle(
-                predictor.get_output_names()[0])
+                predictor.get_output_names()[0]
+            )
            predict_infer = output_tensor.copy_to_cpu()
            self.assertTrue(
                np.isclose(predict, predict_infer, rtol=5e-5).any(),
                "custom op predict: {},\n custom op infer predict: {}".format(
-                    predict, predict_infer))
+                    predict, predict_infer
+                ),
+            )
        paddle.disable_static()

-    def test_func_double_grad_dynamic(self):
+    def test_double_grad_dynamic(self):
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
        for device in self.devices:
            for dtype in self.dtypes:
                if device == 'cpu' and dtype == 'float16':
                    continue
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                out, dx_grad = custom_relu_double_grad_dynamic(
-                    self.custom_ops[0], device, dtype, x)
+                    self.custom_ops[0], device, dtype, x
+                )
                pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
-                    self.custom_ops[0], device, dtype, x, False)
+                    self.custom_ops[0], device, dtype, x, False
+                )
                np.testing.assert_array_equal(
                    out,
                    pd_out,
                    err_msg='custom op out: {},\n paddle api out: {}'.format(
-                        out, pd_out))
+                        out, pd_out
+                    ),
+                )
                np.testing.assert_array_equal(
                    dx_grad,
                    pd_dx_grad,
-                    err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.
-                    format(dx_grad, pd_dx_grad))
+                    err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format(
+                        dx_grad, pd_dx_grad
+                    ),
+                )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

    def test_with_dataloader(self):
        for device in self.devices:
            paddle.set_device(device)
            # data loader
            transform = Compose(
-                [Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
-            train_dataset = paddle.vision.datasets.MNIST(mode='train',
-                                                         transform=transform)
-            train_loader = paddle.io.DataLoader(train_dataset,
-                                                batch_size=64,
-                                                shuffle=True,
-                                                drop_last=True,
-                                                num_workers=0)
+                [Normalize(mean=[127.5], std=[127.5], data_format='CHW')]
+            )
+            train_dataset = paddle.vision.datasets.MNIST(
+                mode='train', transform=transform
+            )
+            train_loader = paddle.io.DataLoader(
+                train_dataset,
+                batch_size=64,
+                shuffle=True,
+                drop_last=True,
+                num_workers=0,
+            )

            for batch_id, (image, _) in enumerate(train_loader()):
                out = self.custom_ops[0](image)
@@ -368,7 +412,9 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                    out,
                    pd_out,
                    err_msg='custom op out: {},\n paddle api out: {}'.format(
-                        out, pd_out))
+                        out, pd_out
+                    ),
+                )

                if batch_id == 5:
                    break
......