Unverified commit f424162c, authored by ccrrong, committed by GitHub

add tile_grad composite rule (#53141)

* add tile_grad composite rule
Parent 27159c42
...@@ -18,6 +18,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
...@@ -160,6 +163,25 @@ class TileGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -160,6 +163,25 @@ class TileGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
class TileCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
    auto dx_ptr = this->GetOutputPtr(&x_grad);
    std::string dx_name = this->GetOutputName(x_grad);
    auto repeat_times = this->Attr<std::vector<int>>("repeat_times");
    VLOG(6) << "Running tile_grad composite func";
    prim::tile_grad<prim::DescTensor>(
        x, out_grad, paddle::experimental::IntArray(repeat_times), dx_ptr);
    this->RecoverOutputName(x_grad, dx_name);
  }
};
template <typename T>
class TileDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
...@@ -196,6 +218,7 @@ REGISTER_OPERATOR(tile,
ops::TileOpMaker,
ops::TileGradOpMaker<paddle::framework::OpDesc>,
ops::TileGradOpMaker<paddle::imperative::OpBase>,
ops::TileCompositeGradOpMaker,
TileInferMetaFunctor);
REGISTER_OPERATOR(tile_grad,
ops::TileGradOp,
......
...@@ -1769,6 +1769,44 @@ void gelu_grad(const Tensor& x,
  }
}
template <typename T>
void tile_grad(const Tensor& x,
               const Tensor& out_grad,
               const IntArray& repeat_times,
               Tensor* x_grad) {
  if (x_grad) {
    auto repeat_times_data = repeat_times.GetData();
    auto out_grad_shape = phi::vectorize<int>(out_grad.dims());
    auto x_shape = phi::vectorize<int>(x.dims());

    // Align the ranks of repeat_times and x by left-padding the shorter one
    // with 1s, so every output axis has a matching (x_dim, repeat) pair.
    if (repeat_times_data.size() < x_shape.size()) {
      int diff = x_shape.size() - repeat_times_data.size();
      repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
    } else {
      int diff = repeat_times_data.size() - x_shape.size();
      x_shape.insert(x_shape.begin(), diff, 1);
    }
    // Resolve any dynamic (-1) output dims from x_shape * repeat_times.
    for (int i = 0; i < static_cast<int>(out_grad_shape.size()); i++) {
      if (out_grad_shape[i] == -1) {
        out_grad_shape[i] = x_shape[i] * repeat_times_data[i];
      }
    }
    auto result = reshape<T>(out_grad, out_grad_shape);
    // For each axis, split the gradient into repeat_times[i] equal blocks and
    // accumulate them: every tiled copy of x contributes to x's gradient.
    for (int i = 0; i < static_cast<int>(repeat_times_data.size()); i++) {
      int size = out_grad_shape[i] / repeat_times_data[i];
      std::vector<int> sections(repeat_times_data[i], size);
      auto split_arr = split<T>(result, IntArray(sections), i);
      result = full<T>(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype());
      for (int j = 0; j < static_cast<int>(split_arr.size()); j++) {
        result = split_arr[j] + result;
      }
    }
    // Drop any padded leading 1s to recover x's original shape.
    result = reshape<T>(result, x.shape());
    set_output<T>(result, x_grad);
  }
}
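For reference, a minimal NumPy sketch of the same decomposition (align ranks, split each tiled axis into equal blocks, sum the blocks, reshape back). The function name tile_grad_ref and its arguments are illustrative only and are not part of the Paddle API:

import numpy as np

def tile_grad_ref(x_shape, repeat_times, out_grad):
    """Composite-style gradient of tile: sum out_grad over every repeated block."""
    orig_shape = list(x_shape)
    repeat_times = list(repeat_times)
    x_shape = list(x_shape)
    # Align ranks, mirroring the branch in the C++ rule above.
    if len(repeat_times) < len(x_shape):
        repeat_times = [1] * (len(x_shape) - len(repeat_times)) + repeat_times
    else:
        x_shape = [1] * (len(repeat_times) - len(x_shape)) + x_shape
    result = out_grad
    for axis, rep in enumerate(repeat_times):
        # Split this axis into `rep` equal blocks and accumulate them,
        # mirroring the split + add loop above.
        result = sum(np.split(result, rep, axis=axis))
    return result.reshape(orig_shape)

# Example: x of shape (2, 3) tiled with repeat_times (2, 2) gives out of shape (4, 6).
# With an all-ones upstream gradient, each input element receives 2 * 2 = 4.
dx = tile_grad_ref((2, 3), (2, 2), np.ones((4, 6)))
print(dx)  # [[4. 4. 4.], [4. 4. 4.]]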
template <typename T>
void roll_grad(const Tensor& x,
               const Tensor& out_grad,
......
...@@ -1030,6 +1030,7 @@
  kernel :
    func : tile_grad
  no_need_buffer : x
  composite : tile_grad(x, out_grad, repeat_times, x_grad)
  backward : tile_double_grad

- backward_op : transpose_double_grad
......
...@@ -1114,6 +1114,7 @@ set(TEST_CINN_OPS
test_cast_op
test_dropout_op
test_group_norm_op
test_tile_op
test_roll_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
......
...@@ -67,7 +67,8 @@ class TestMeshgridOp(OpTest):
return [100, 200]
def if_enable_cinn(self):
# Decomposing tile_grad makes the CINN run time out
self.enable_cinn = False
class TestMeshgridOp2(TestMeshgridOp):
......
...@@ -29,6 +29,9 @@ class TestTileOpRank1(OpTest):
def setUp(self):
self.op_type = "tile"
self.python_api = paddle.tile
self.prim_op_type = "prim"
self.enable_cinn = True
self.public_python_api = paddle.tile
self.init_data()
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
...@@ -44,23 +47,26 @@ class TestTileOpRank1(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class TestTileOpRank_ZeroDim1(TestTileOpRank1):
def init_data(self):
self.enable_cinn = False
self.ori_shape = []
self.repeat_times = []
class TestTileOpRank_ZeroDim2(TestTileOpRank1):
def init_data(self):
self.enable_cinn = False
self.ori_shape = []
self.repeat_times = [2]
class TestTileOpRank_ZeroDim3(TestTileOpRank1):
def init_data(self):
self.enable_cinn = False
self.ori_shape = []
self.repeat_times = [2, 3]
...@@ -201,6 +207,9 @@ class TestTileFP16OP(OpTest):
self.op_type = "tile"
self.dtype = np.float16
self.python_api = paddle.tile
self.prim_op_type = "prim"
self.enable_cinn = True
self.public_python_api = paddle.tile
self.init_data()
x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype)
output = np.tile(x, self.repeat_times)
...@@ -217,7 +226,7 @@ class TestTileFP16OP(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
@unittest.skipIf(
...@@ -230,6 +239,9 @@ class TestTileBF16OP(OpTest):
self.op_type = 'tile'
self.__class__.op_type = self.op_type
self.python_api = paddle.tile
self.prim_op_type = "prim"
self.enable_cinn = False
self.public_python_api = paddle.tile
self.init_data()
x = np.random.uniform(10, size=self.ori_shape).astype(np.float32)
output = np.tile(x, self.repeat_times)
...@@ -248,7 +260,7 @@ class TestTileBF16OP(OpTest):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
# Situation 5: input x is Bool
......
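As a quick sanity check of the gradient values themselves (not of the composite lowering, which the check_prim=True cases above exercise in the static graph), a minimal dygraph sketch; paddle.tile, paddle.grad, paddle.to_tensor, and paddle.ones_like are public APIs, the concrete shapes are illustrative:

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(2, 3), dtype='float64', stop_gradient=False)
out = paddle.tile(x, repeat_times=[2, 2])
# Each element of x appears 2 * 2 = 4 times in out, so with an all-ones
# upstream gradient the gradient w.r.t. x should be all 4s.
(dx,) = paddle.grad([out], [x], grad_outputs=[paddle.ones_like(out)])
print(dx.numpy())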