未验证 提交 292f3f77 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】vjp for reduce sum (#49736)

上级 e70af91d
......@@ -17,6 +17,7 @@
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -63,6 +64,35 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class ReduceSumCompositeGradOpMaker : public prim::GradCompositeOpMakerBase {
public:
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
void Apply() override {
// get inputs
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
// get attr
std::vector<int> axis = this->Attr<std::vector<int>>("dim");
bool keep_dim = this->Attr<bool>("keep_dim");
bool reduce_all = this->Attr<bool>("reduce_all");
// get output
paddle::experimental::Tensor x_grad_t = this->GetSingleInputGrad("X");
// get output ptr
paddle::experimental::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
// get output orginal name
std::string x_grad_name = this->GetOutputName(x_grad_t);
// call composite backward func
prim::sum_grad<prim::DescTensor>(
x, out_grad, axis, keep_dim, reduce_all, x_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
}
};
template <typename T>
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -114,6 +144,7 @@ REGISTER_OPERATOR(reduce_sum,
ops::ReduceSumVarTypeInference,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
ops::ReduceSumCompositeGradOpMaker,
ReduceSumInferShapeFunctor);
REGISTER_OPERATOR(reduce_sum_grad,
ops::ReduceGradOp,
......
......@@ -15,11 +15,17 @@
#pragma once
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace prim {
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
using Tensor = paddle::experimental::Tensor;
using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// using IntArray = paddle::experimental::IntArray;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0);
......@@ -94,6 +100,44 @@ void add_grad(const Tensor& x,
}
}
template <typename T>
void sum_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (!x_grad) {
return;
}
std::vector<int> x_dim = phi::vectorize<int>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
} else {
reduce_all = false;
}
auto x_grad_tmp = Tensor();
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, x_dim);
} else {
x_grad_tmp = expand<T>(out_grad, x_dim);
}
x_grad->set_impl(x_grad_tmp.impl());
}
template <typename T>
void divide_grad(const Tensor& x,
const Tensor& y,
......
......@@ -36,6 +36,16 @@ Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
template <>
Tensor expand<Tensor>(const Tensor& x, const IntArray& shape) {
return ::expand_ad_func(x, shape);
}
template <>
Tensor unsqueeze<Tensor>(const Tensor& x, const IntArray& axis) {
return ::unsqueeze_ad_func(x, axis);
}
template <>
Tensor divide<Tensor>(const Tensor& x, const Tensor& y) {
return ::divide_ad_func(x, y);
......
......@@ -21,18 +21,25 @@ namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray = paddle::experimental::IntArray;
using Scalar = paddle::experimental::Scalar;
template <typename T>
Tensor pow(const Tensor& x, const paddle::experimental::Scalar& y);
Tensor pow(const Tensor& x, const Scalar& y);
template <typename T>
Tensor scale(const Tensor& X,
const paddle::experimental::Scalar& scale,
const Scalar& scale,
float bias,
bool bias_after_scale);
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor unsqueeze(const Tensor& x, const IntArray& axis);
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y);
......
......@@ -94,6 +94,23 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
return out;
}
template <>
Tensor expand<DescTensor>(const Tensor& x, const IntArray& shape) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("expand_v2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
op->SetAttr("shape", new_shape);
op->CheckAttrs();
op->InferVarType(block);
return out;
}
template <>
Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
// Grad infershape
......@@ -113,6 +130,23 @@ Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
return out;
}
template <>
Tensor unsqueeze<DescTensor>(const Tensor& x, const IntArray& axis) {
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("unsqueeze2");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
std::vector<int> new_shape(axis.GetData().begin(), axis.GetData().end());
op->SetAttr("axes", new_shape);
op->CheckAttrs();
op->InferVarType(block);
return out;
}
template <>
Tensor full<DescTensor>(paddle::experimental::IntArray shape,
paddle::experimental::Scalar value,
......@@ -141,6 +175,7 @@ Tensor full<DescTensor>(paddle::experimental::IntArray shape,
op->InferShape(*block);
return out;
}
template <>
Tensor sum<DescTensor>(Tensor x,
paddle::experimental::IntArray axis,
......
......@@ -1356,6 +1356,7 @@
param : [x]
kernel :
func : sum_grad
composite : sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad)
no_need_buffer : x
backward : sum_double_grad
......
......@@ -8,5 +8,4 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
add_subdirectory(comp)
add_subdirectory(prim)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
......@@ -11,5 +11,6 @@ endforeach()
set_tests_properties(test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sum_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
def actual(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(False)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
x_cotangent = paddle.grad(y, x, v, create_graph=True, retain_graph=True)
return x_cotangent[0]
def desired(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(True)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
x_cotangent = paddle.grad(y, x, v, create_graph=True, retain_graph=True)
return x_cotangent[0]
class TestSumGradComp(unittest.TestCase):
def test_sum_grad_comp_1(self):
self.primal = np.random.rand(10, 10)
self.cotangent = np.random.rand(1)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [], False),
desired=desired(self.primal, self.cotangent, [], False),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_2(self):
self.primal = np.random.rand(4, 3, 2)
self.cotangent = np.random.rand(4, 2)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, 1, False),
desired=desired(self.primal, self.cotangent, 1, False),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_3(self):
self.primal = np.random.rand(4, 3, 2)
self.cotangent = np.random.rand(4, 1, 2)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, 1, True),
desired=desired(self.primal, self.cotangent, 1, True),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_4(self):
self.primal = np.random.rand(4, 3, 2, 5)
self.cotangent = np.random.rand(4, 1, 2, 1)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [1, 3], True),
desired=desired(self.primal, self.cotangent, [1, 3], True),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_5(self):
self.primal = np.random.rand(4, 3, 2, 5)
self.cotangent = np.random.rand(4, 2)
paddle.disable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [1, 3], False),
desired=desired(self.primal, self.cotangent, [1, 3], False),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
......@@ -11,6 +11,7 @@ endforeach()
set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
def actual(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data('cotangent', cotangent.shape, cotangent.dtype)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
x_cotangent = paddle.static.gradients(y, x, None)
exe = paddle.static.Executor()
exe.run(sp)
result = exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent],
)[0]
return result
def desired(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data('cotangent', cotangent.shape, cotangent.dtype)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
x_cotangent = paddle.static.gradients(y, x, None)
exe = paddle.static.Executor()
exe.run(sp)
result = exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent],
)[0]
return result
class TestSumGradComp(unittest.TestCase):
def test_sum_grad_comp_1(self):
self.primal = np.random.rand(10, 10)
self.cotangent = np.random.rand(1, 1)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [], True),
desired=desired(self.primal, self.cotangent, [], True),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_2(self):
self.primal = np.random.rand(4, 3, 2)
self.cotangent = np.random.rand(4, 2)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, 1, False),
desired=desired(self.primal, self.cotangent, 1, False),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_3(self):
self.primal = np.random.rand(4, 3, 2)
self.cotangent = np.random.rand(4, 1, 2)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, 1, True),
desired=desired(self.primal, self.cotangent, 1, True),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_4(self):
self.primal = np.random.rand(4, 3, 2, 5)
self.cotangent = np.random.rand(4, 1, 2, 1)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [1, 3], True),
desired=desired(self.primal, self.cotangent, [1, 3], True),
rtol=1e-6,
atol=0,
)
def test_sum_grad_comp_5(self):
self.primal = np.random.rand(4, 3, 2, 5)
self.cotangent = np.random.rand(4, 2)
paddle.enable_static()
np.testing.assert_allclose(
actual=actual(self.primal, self.cotangent, [1, 3], False),
desired=desired(self.primal, self.cotangent, [1, 3], False),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册