未验证 提交 f124c86f 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Custom softmax grad (#51474)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
Co-authored-by: Njiangcheng <thisjiang@qq.com>
Co-authored-by: Ncxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
Co-authored-by: Njiangcheng <thisjiang@qq.com>
Co-authored-by: Ncxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
Co-authored-by: Njiangcheng <thisjiang@qq.com>
Co-authored-by: Ncxxly <chenxx_id@163.com>

* [Prim] enable whitelist and blacklist for custom_vjp

* support softmax grad

* remove additional code

* add test back

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
Co-authored-by: Njiangcheng <thisjiang@qq.com>
Co-authored-by: Ncxxly <chenxx_id@163.com>
Co-authored-by: Nxiongkun <807377414@qq.com>
上级 50df0170
......@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -156,6 +159,23 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class SoftmaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::Tensor out = this->GetSingleForwardOutput("Out");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(6) << "Runing softmax_grad composite func";
prim::softmax_grad<prim::DescTensor>(out, out_grad, axis, dx_ptr);
this->RecoverOutputName(dx, dx_name);
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
} // namespace operators
......@@ -172,6 +192,7 @@ REGISTER_OPERATOR(softmax,
ops::SoftmaxOpInferVarType,
ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxCompositeGradOpMaker,
ops::SoftmaxInplaceInferer,
SoftmaxInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad,
......
......@@ -30,6 +30,35 @@ using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void softmax_grad(const Tensor& out,
const Tensor& out_grad,
int axis,
Tensor* x_grad) {
if (x_grad) {
if (out_grad.dims().size() > 0) {
if (axis >= 0) {
auto new_out_grad = out_grad * out;
auto tmp_x_grad = new_out_grad -
out * sum<T>(new_out_grad, {axis}, out.dtype(), true);
set_output<T>(tmp_x_grad, x_grad);
} else {
auto new_out_grad = out_grad * out;
auto tmp_x_grad =
new_out_grad - out * sum<T>(new_out_grad,
{out.dims().size() + axis},
out.dtype(),
true);
set_output<T>(tmp_x_grad, x_grad);
}
} else {
set_output<T>(
full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()),
x_grad);
}
}
}
template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
if (x_grad) {
......
......@@ -1144,6 +1144,7 @@
param : [out]
kernel :
func : softmax_grad
composite : softmax_grad(out, out_grad, axis, x_grad)
- backward_op : spectral_norm_grad
forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data
class Attr:
def __init__(self) -> None:
self.dtype = None
self.axis = -1
self.shape = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_axis(self, axis) -> None:
self.axis = axis
return
def set_shape(self, shape) -> None:
self.shape = shape
return
def get_rtol(self, flag):
rtol = TOLERANCE[self.dtype][flag].get("rtol")
return rtol
def get_atol(self, flag):
atol = TOLERANCE[self.dtype][flag].get("atol")
return atol
attrs = Attr()
def fn(x):
return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)
def expect_grad(inputs):
paddle.disable_static()
inputs.stop_gradient = False
res = fn(inputs)
gradients = paddle.grad(res, inputs)
return gradients
class TestCompositeSoftmax(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs):
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that softmax in original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is splitted into small ops
self.assertTrue('softmax' not in fwd_ops_new)
z = paddle.static.gradients([y], x)
fwd_ops_grad = [op.type for op in blocks[0].ops]
# Ensure that softmax_grad not in grad block
self.assertTrue('softmax_grad' not in fwd_ops_grad)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
def compare_backward(self):
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("backward"),
atol=attrs.get_atol("backward"),
)
def test_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()
class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
"test composite softmax and prim backward"
def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float32", "float64"]
self.shapes = [[], [2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
z = paddle.static.gradients([y], x)
paddle.incubate.autograd.primapi.to_prim(blocks)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
core._set_prim_all_enabled(False)
return res
def compare_backward(self):
if not attrs.shape and attrs.axis not in [-1, 0]:
# op softmax does not support both case
return
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("prim_backward"),
atol=attrs.get_rtol("prim_backward"),
)
def test_prim_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册