Unverified commit 604b7a53 authored by Jiabin Yang, committed by GitHub

support relu custom vjp (#51742)

Parent 702fc894
...@@ -39,3 +39,4 @@
- put_along_axis
- greater_than
- less_equal
- where
...@@ -30,6 +30,18 @@ using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    auto condition = greater_than<T>(
        out, full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
    auto res = where<T>(condition,
                        out_grad,
                        full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
    set_output<T>(res, x_grad);
  }
}

template <typename T>
void softmax_grad(const Tensor& out,
                  const Tensor& out_grad,
...
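The new relu_grad composite rule above expresses the relu VJP in terms of primitive ops: the upstream gradient passes through wherever the forward output is positive and is zeroed elsewhere. The NumPy sketch below is not part of the commit; it is only meant to illustrate the formula the composite rule computes.

import numpy as np

def relu_vjp_reference(out, out_grad):
    # Mirrors the composite rule: x_grad = where(out > 0, out_grad, 0)
    condition = out > 0.0                      # greater_than(out, full(shape, 0.0))
    return np.where(condition, out_grad, 0.0)  # where(condition, out_grad, full(shape, 0.0))

out = np.array([0.0, 0.5, 2.0], dtype="float32")       # relu forward outputs
out_grad = np.array([1.0, 1.0, 1.0], dtype="float32")  # upstream gradient
print(relu_vjp_reference(out, out_grad))                # [0. 1. 1.]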
...@@ -1142,6 +1142,7 @@
  kernel :
    func : relu_grad
    backward: relu_double_grad
  composite: relu_grad(out, out_grad, x_grad)
  inplace : (out_grad -> x_grad)

- backward_op : renorm_grad
...
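The composite: line registers the C++ rule above as the composite (prim) backward for relu_grad in the ops YAML. As a rough way to see the decomposition take effect, the sketch below is not part of the commit; it simply mirrors the added unit test, and the exact primitive op names mentioned in the comment are an assumption.

import paddle
import paddle.nn.functional as F
from paddle.fluid import core

paddle.enable_static()
core._set_prim_all_enabled(True)  # enable composite forward + backward rules

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    x.stop_gradient = False
    y = F.relu(x)
    paddle.static.gradients([y], x)
    paddle.incubate.autograd.primapi.to_prim(main.blocks)

# With the composite vjp registered, the backward should be built from
# primitive ops (e.g. greater_than / where / fill_constant) rather than a
# single relu_grad kernel op. Op names here are assumptions for illustration.
print([op.type for op in main.block(0).ops])

core._set_prim_all_enabled(False)
paddle.disable_static()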
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class Attr:
    def __init__(self) -> None:
        self.dtype = None
        self.shape = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
        return

    def set_shape(self, shape) -> None:
        self.shape = shape
        return

    def get_rtol(self, flag):
        rtol = TOLERANCE[self.dtype][flag].get("rtol")
        return rtol

    def get_atol(self, flag):
        atol = TOLERANCE[self.dtype][flag].get("atol")
        return atol


attrs = Attr()


def fn(x):
    return F.relu(x)


def expect_grad(inputs):
    paddle.disable_static()
    inputs.stop_gradient = False
    res = fn(inputs)
    gradients = paddle.grad(res, inputs)
    return gradients


class TestCompositeReluPrimBackward(unittest.TestCase):
    "test composite relu and prim backward"
    def setUp(self):
        core._set_prim_backward_enabled(True)
        self.dtypes = ["float16", "float32", "float64"]
        self.shapes = [[2, 3, 4], [2, 3]]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks
            z = paddle.static.gradients([y], x)
            paddle.incubate.autograd.primapi.to_prim(blocks)
        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
        np_data = generate_data(attrs.shape)
        tensor_data = paddle.to_tensor(np_data)
        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]

        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("prim_backward"),
            atol=attrs.get_atol("prim_backward"),
        )
    def test_prim_backward(self):
        for j in self.dtypes:
            for t in self.shapes:
                attrs.set_dtype(j)
                attrs.set_shape(t)
                self.compare_backward()


if __name__ == '__main__':
    unittest.main()
...@@ -97,6 +97,7 @@ def composite_batchnorm(
    batch_mean = zeros(run_mean.shape, run_mean.dtype)
    batch_var = zeros(run_var.shape, run_var.dtype)
    if not use_run_stat:
        batch_mean = mean(x, reduce_axes, keepdim=True)
        temp = mean(x * x, reduce_axes, keepdim=True)
        batch_var = temp - batch_mean * batch_mean
...