Unverified commit 789aac8a, authored by xiaoguoguo626807, committed by GitHub

【prim】change layernorm_grad rules (#51879)

* support layer_norm prim and cinn test

* enable cinn test

* fix merge conflict

* polish input for check_output_with_place

* fix merge conflict

* add more test case

* fix merge conflict

* polish test case

* polish op_test

* change ln_g rules

* modify scale is none case

* modify scale is none case

* add public_python_api for check prim

* modify setoutputgrad and fp64bug

* add todo & delete log

* delete Single***varname

* delete get varname

* modify FP64 bug

* delete op test

* recover

* fix conflict

---------
Co-authored-by: Weilong Wu <veyron_wu@163.com>
Parent b3efc923
......@@ -44,7 +44,6 @@
- sin
- cos
- where
- reshape
- split
- erf
- tanh
......@@ -32,5 +32,12 @@ template <>
Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
}
template <>
Tensor reshape<Tensor>(const Tensor& x, const IntArray& shape) {
VLOG(4) << "Eager Prim API reshape_ad_func call";
return ::reshape_ad_func(x, shape);
}
} // namespace prim
} // namespace paddle
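The composite layer_norm backward is why `reshape` is promoted to a prim API here: the fused kernel keeps `Scale`/`Bias` flattened, so the decomposed gradient has to reshape them before broadcasting over the normalized axes. Below is a minimal NumPy reference of that decomposition for a 2-D input; it is a sketch of the standard layer_norm gradient formula, not a copy of Paddle's composite rule, and the function name and `eps` default are illustrative.

```python
import numpy as np


def layer_norm_grad_ref(x, scale, dy, eps=1e-5):
    """Reference gradients of y = scale * (x - mean) / sqrt(var + eps) + bias.

    x, dy: [N, D]; scale: flat array of D elements, or None.
    """
    d = x.shape[-1]
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in layer_norm
    rstd = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mean) * rstd

    # the flattened scale must be reshaped before it can broadcast against x
    g = dy if scale is None else dy * scale.reshape(1, d)

    d_x = rstd * (
        g
        - g.mean(axis=-1, keepdims=True)
        - x_hat * (g * x_hat).mean(axis=-1, keepdims=True)
    )
    d_scale = None if scale is None else (dy * x_hat).sum(axis=0)
    d_bias = dy.sum(axis=0)
    return d_x, d_scale, d_bias
```

With `scale=None` the rule degenerates to the x gradient only, which is the optional-input case handled further down in this commit.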
......@@ -38,5 +38,8 @@ Tensor full(const IntArray& shape,
template <typename T>
Tensor cast(const Tensor& x, DataType dtype);
template <typename T>
Tensor reshape(const Tensor& x, const IntArray& shape);
} // namespace prim
} // namespace paddle
......@@ -126,5 +126,23 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
op->InferShape(*block);
return out;
}
template <>
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("reshape2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
} // namespace prim
} // namespace paddle
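On the static side, `reshape<DescTensor>` above lowers to a `reshape2` op appended to the current block. A hedged way to see the lowered program is to build a small static graph the same way the tests in this commit do; the internal `core._set_prim_all_enabled` / `primapi.to_prim` calls are taken from those tests, and which primitive ops appear may vary by build, so the sketch only prints them.

```python
import paddle
from paddle.fluid import core
from paddle.incubate.autograd import primapi

paddle.enable_static()
core._set_prim_all_enabled(True)

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[3, 4], dtype='float32')
    w = paddle.static.data('w', shape=[4], dtype='float32')
    b = paddle.static.data('b', shape=[4], dtype='float32')
    for t in (x, w, b):
        t.stop_gradient = False
    y = paddle.nn.functional.layer_norm(x, [4], weight=w, bias=b)
    grads = paddle.static.gradients([y], [x, w, b])  # composite grad when prim is on
    primapi.to_prim(main.blocks)  # decompose the forward layer_norm

# Expect layer_norm/layer_norm_grad to be replaced by primitives; with a
# scale present, reshape2 should be among them.
print([op.type for op in main.blocks[0].ops])

paddle.disable_static()
core._set_prim_all_enabled(False)
```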
......@@ -103,6 +103,8 @@ class CompositeGradOpMakerBase {
return output_grad;
}
// TODO(Ruting): modify name to GetNullableSingleInputGrad after Large-scale
// development
paddle::Tensor GetSingleInputGrad(const std::string& name) {
framework::VarDesc* input_grad_desc = this->SingleInputGrad(name);
if (!input_grad_desc) return paddle::Tensor();
......@@ -320,7 +322,11 @@ class CompositeGradOpMakerBase {
framework::VarDesc* SingleInputGrad(const std::string& name,
bool drop_empty_grad = true) const {
auto var_name = this->SingleForwardInputVarName(name);
auto* var = this->SingleForwardInput(name);
if (!var) {
return nullptr;
}
auto var_name = var->Name();
auto grad_var_name = framework::GradVarName(var_name);
if (no_grad_set_.empty() || !no_grad_set_.count(grad_var_name)) {
(*this->grad_to_var_)[grad_var_name] = var_name;
......@@ -342,7 +348,14 @@ class CompositeGradOpMakerBase {
}
framework::VarDesc* SingleOutputGrad(const std::string& name) const {
auto var_name = this->SingleForwardOutputVarName(name);
auto* var = this->SingleForwardOutput(name);
if (!var) {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetSingleOutputGrad for %s_grad faild, if it is Optional input,"
"please use GetOptionalSingleOutputGrad replaced. ",
name));
}
auto var_name = var->Name();
auto grad_var_name = framework::GradVarName(var_name);
(*this->grad_to_var_)[grad_var_name] = var_name;
VLOG(8) << "Valid gradients: " << grad_var_name;
......@@ -553,14 +566,6 @@ class CompositeGradOpMakerBase {
}
}
std::string SingleForwardInputVarName(const std::string& name) const {
return fwd_op_.Input(name).at(0);
}
std::string SingleForwardOutputVarName(const std::string& name) const {
return fwd_op_.Output(name).at(0);
}
std::vector<std::string> MultiForwardOutputVarName(
const std::string& name) const {
return fwd_op_.Output(name);
......
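The nullable handling in `SingleInputGrad`/`SingleOutputGrad` above exists because `Scale` and `Bias` are optional inputs of `layer_norm`: when they are absent there is no forward variable to look up, so the maker must return `nullptr` (or raise a clear error) rather than dereference a missing name. A hedged dygraph sketch of that optional-input path, mirroring the `withNone` tests added below (the shapes and the `_set_prim_backward_enabled` toggle are taken from them):

```python
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core

core._set_prim_backward_enabled(True)  # use the composite layer_norm backward

x = paddle.to_tensor(
    np.random.rand(3, 4).astype("float32"), stop_gradient=False
)
# weight/bias left as None: only the x gradient exists in the backward
y = F.layer_norm(x, normalized_shape=[4], weight=None, bias=None)
(x_grad,) = paddle.grad(y, x, paddle.ones_like(y))
print(x_grad.shape)  # [3, 4]

core._set_prim_backward_enabled(False)
```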
......@@ -30,7 +30,8 @@ TOLERANCE_NUMPY = {
}
TOLERANCE_COMP_GRAD = {
"float32": {"rtol": 1e-3, "atol": 1e-3},
"float64": {"rtol": 1e-13, "atol": 1e-13},
"float32": {"rtol": 1e-5, "atol": 1e-5},
"float16": {"rtol": 1e-3, "atol": 1e-3}, # amp
}
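The table above is consumed per dtype when the composite gradient is compared with the fused kernel; a minimal, self-contained illustration of that lookup (values copied from the table, the arrays are made up):

```python
import numpy as np

TOLERANCE_COMP_GRAD = {
    "float64": {"rtol": 1e-13, "atol": 1e-13},
    "float32": {"rtol": 1e-5, "atol": 1e-5},
    "float16": {"rtol": 1e-3, "atol": 1e-3},  # amp
}

dtype = "float32"
expect = np.linspace(0.0, 1.0, 8, dtype=dtype)
actual = expect + np.float32(1e-7)  # deviation well inside the fp32 tolerance
np.testing.assert_allclose(
    expect,
    actual,
    rtol=TOLERANCE_COMP_GRAD[dtype]["rtol"],
    atol=TOLERANCE_COMP_GRAD[dtype]["atol"],
)
```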
......@@ -348,6 +349,49 @@ class TestCompositelayer_norm(unittest.TestCase):
core._set_prim_all_enabled(False)
return res
def static_comp_forward_and_backward_withNone(
self, inputs, norm_shape, weight, bias, y_g
):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y_grad = paddle.static.data(
'y_grad', shape=y_g.shape, dtype=str(y_g.dtype)
)
y = fn(x, norm_shape, weight, bias)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that layer_norm is in the original block
self.assertTrue('layer_norm' in fwd_ops)
z = paddle.static.gradients([y], [x], y_grad)
primapi.to_prim(blocks)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x': inputs,
'y_grad': y_g,
},
fetch_list=z,
)
paddle.disable_static()
core._set_prim_all_enabled(False)
return res
def compare_comp_forward(self):
x, w, b, y_g = generate_data(
attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
......@@ -379,12 +423,26 @@ class TestCompositelayer_norm(unittest.TestCase):
atol=TOLERANCE_COMP_GRAD[attrs.dtype]['atol'],
)
def compare_comp_forward_withNone(self):
x, w, b, y_g = generate_data(
attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype
)
n_shape = attrs.n_shape
x_p = paddle.to_tensor(x)
w_p = paddle.to_tensor(w)
b_p = paddle.to_tensor(b)
y_g_p = paddle.to_tensor(y_g)
expect_2 = dygraph_fused_backward_withNone(
x_p, n_shape, None, None, y_g_p
)[0].numpy()
actual_2 = self.static_comp_forward_withNone(
x, n_shape, None, None, y_g
)[0]
actual_all_2 = self.static_comp_forward_and_backward_withNone(
x, n_shape, None, None, y_g
)[0]
assert expect_2.dtype == actual_2.dtype
np.testing.assert_allclose(
expect_2,
......@@ -393,6 +451,13 @@ class TestCompositelayer_norm(unittest.TestCase):
atol=attrs.get_atol("backward"),
)
np.testing.assert_allclose(
expect_2,
actual_all_2,
rtol=TOLERANCE_COMP_GRAD[attrs.dtype]['rtol'],
atol=TOLERANCE_COMP_GRAD[attrs.dtype]['atol'],
)
def test_backward(self):
for j in self.dtypes:
if paddle.device.get_device() == "cpu":
......@@ -408,11 +473,26 @@ class TestCompositelayer_norm(unittest.TestCase):
)
self.compare_comp_forward()
def test_backward_withNone(self):
for t in range(0, len(self.shape1s)):
if paddle.device.get_device() == "cpu":
print("need pass this case")
continue
attrs.set_dtype("float32")
attrs.set_shape(
self.n_shape[t],
self.shape1s[t],
self.shape2s[t],
self.shape3s[t],
)
self.compare_comp_forward_withNone()
class TestCompositelayer_normPrimBackward(unittest.TestCase):
def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float32"]
self.dtypes = ["float16", "float32"]
self.n_shape = [[4], [64, 128], [64]]
self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]]
self.shape2s = [[4], [64 * 128], [64]]
......