Unverified commit dda6b9d5 authored by ccrrong, committed by GitHub

update tile_grad composite rule (#53261)

Parent be1b3fc3
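For context: tile repeats a tensor repeat_times[i] times along axis i, so every input element fans out to several output positions, and the gradient must sum the corresponding slices of out_grad back together. A minimal standalone C++ sketch of the 1-D case (illustrative only; this is not code from the commit):

#include <cstdio>
#include <vector>

// 1-D tile: out has x.size() * repeat elements, with out[k * n + j] = x[j].
// The gradient therefore accumulates: dx[j] = sum over k of dout[k * n + j].
std::vector<float> tile_grad_1d(const std::vector<float>& dout, int repeat) {
  int n = static_cast<int>(dout.size()) / repeat;
  std::vector<float> dx(n, 0.0f);
  for (int k = 0; k < repeat; ++k) {
    for (int j = 0; j < n; ++j) {
      dx[j] += dout[k * n + j];
    }
  }
  return dx;
}

int main() {
  // x = [a, b], repeat_times = 3 -> out = [a, b, a, b, a, b].
  std::vector<float> dout = {1, 2, 3, 4, 5, 6};
  std::vector<float> dx = tile_grad_1d(dout, 3);
  std::printf("%g %g\n", dx[0], dx[1]);  // prints: 9 12
  return 0;
}

The composite rule changed below implements the n-D version of this by splitting out_grad along each axis and adding the sections.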
@@ -171,14 +171,25 @@ class TileCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
     paddle::Tensor x = this->GetSingleForwardInput("X");
     paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
     paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    paddle::optional<paddle::Tensor> tensor_repeat_times =
+        this->GetOptionalSingleForwardInput("RepeatTimes");
+    paddle::optional<paddle::Tensor> tensor_repeat_times_attr =
+        this->GetOptionalSingleForwardInput("repeat_times_tensor");
     auto dx_ptr = this->GetOutputPtr(&x_grad);
     std::string dx_name = this->GetOutputName(x_grad);
     auto repeat_times = this->Attr<std::vector<int>>("repeat_times");
-    VLOG(6) << "Running tile_grad composite func";
-    prim::tile_grad<prim::DescTensor>(
-        x, out_grad, paddle::experimental::IntArray(repeat_times), dx_ptr);
-    this->RecoverOutputName(x_grad, dx_name);
+    if (tensor_repeat_times.is_initialized() ||
+        tensor_repeat_times_attr.is_initialized()) {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "We don't support RepeatTimes from tensor or repeat_times_tensor "
+          "for tile composite grad for now."));
+    } else {
+      VLOG(6) << "Running tile_grad composite func";
+      prim::tile_grad<prim::DescTensor>(
+          x, out_grad, paddle::experimental::IntArray(repeat_times), dx_ptr);
+      this->RecoverOutputName(x_grad, dx_name);
+    }
   }
 };
......
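The hunk above guards the rule: repeat_times can also arrive at runtime, as the RepeatTimes tensor input or the repeat_times_tensor list, and the composite rule needs concrete values while the static graph is being built, so those cases now raise Unimplemented instead of silently using the attribute. A hedged sketch of the same fallback pattern with std::optional (hypothetical names, not Paddle API):

#include <cstdio>
#include <optional>
#include <stdexcept>
#include <vector>

// Sketch: a build-time rewrite can only run when the value is a static
// attribute; a tensor-provided value forces a fallback.
std::vector<int> resolve_repeat_times(
    const std::optional<std::vector<int>>& from_tensor,
    const std::vector<int>& static_attr) {
  if (from_tensor.has_value()) {
    throw std::runtime_error(
        "repeat_times comes from a tensor; composite rule unsupported");
  }
  return static_attr;  // safe to expand the rule with these constants
}

int main() {
  auto r = resolve_repeat_times(std::nullopt, {2, 3});
  std::printf("%d %d\n", r[0], r[1]);  // prints: 2 3
  return 0;
}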
@@ -1774,22 +1774,7 @@ void tile_grad(const Tensor& x,
   if (x_grad) {
     auto repeat_times_data = repeat_times.GetData();
     auto out_grad_shape = phi::vectorize<int>(out_grad.dims());
-    auto x_shape = phi::vectorize<int>(x.dims());
-    if (repeat_times_data.size() < x_shape.size()) {
-      int diff = x_shape.size() - repeat_times_data.size();
-      repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
-    } else {
-      int diff = repeat_times_data.size() - x_shape.size();
-      x_shape.insert(x_shape.begin(), diff, 1);
-    }
-    for (int i = 0; i < static_cast<int>(out_grad_shape.size()); i++) {
-      if (out_grad_shape[i] == -1) {
-        out_grad_shape[i] = x_shape[i] * repeat_times_data[i];
-      }
-    }
-    auto result = reshape<T>(out_grad, out_grad_shape);
+    auto result = out_grad;
     for (int i = 0; i < static_cast<int>(repeat_times_data.size()); i++) {
       int size = out_grad_shape[i] / repeat_times_data[i];
       std::vector<int> sections(repeat_times_data[i], size);
......
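This hunk removes the shape normalization from tile_grad: rank alignment and -1 repair now live in TileInferMeta (below), so the rule starts directly from out_grad instead of reshaping it first. The retained loop splits each axis i of out_grad into repeat_times[i] equal sections and sums them; a standalone C++ sketch of one such step (illustrative, assuming row-major layout):

#include <cstdio>
#include <vector>

// One step of the loop above for a row-major 2-D tensor: split `g`
// (shape [rows, cols]) into `repeat` equal sections along axis 0 and
// sum them, mirroring the split + add in the composite rule.
std::vector<float> split_sum_axis0(const std::vector<float>& g,
                                   int rows, int cols, int repeat) {
  int section_rows = rows / repeat;  // "size" in the diff above
  std::vector<float> acc(section_rows * cols, 0.0f);
  for (int s = 0; s < repeat; ++s) {
    for (int i = 0; i < section_rows * cols; ++i) {
      acc[i] += g[s * section_rows * cols + i];
    }
  }
  return acc;
}

int main() {
  // An out_grad of shape [4, 2], produced by tiling a [2, 2] input twice
  // along axis 0, reduces to a [2, 2] gradient.
  std::vector<float> g = {1, 1, 2, 2, 3, 3, 4, 4};
  auto acc = split_sum_axis0(g, 4, 2, 2);
  std::printf("%g %g %g %g\n", acc[0], acc[1], acc[2], acc[3]);  // 4 4 6 6
  return 0;
}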
@@ -4220,7 +4220,7 @@ void TileInferMeta(const MetaTensor& x,
   auto repeat_times_data = repeat_times.GetData();
   auto x_dims = x.dims();
   if (repeat_times_data.size() == 0) {
-    repeat_times_data = std::vector<int64_t>(x_dims.size(), -1);
+    repeat_times_data = std::vector<int64_t>(x_dims.size(), 1);
   }
   PADDLE_ENFORCE_LE(
@@ -4253,10 +4253,10 @@ void TileInferMeta(const MetaTensor& x,
   auto x_dim_vec = phi::vectorize<int>(x_dims);
   if (x_dim_vec.size() > repeat_times_data.size()) {
     auto diff = x_dim_vec.size() - repeat_times_data.size();
-    repeat_times_data.insert(repeat_times_data.begin(), diff, -1);
+    repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
   } else {
     auto diff = repeat_times_data.size() - x_dim_vec.size();
-    x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
+    x_dim_vec.insert(x_dim_vec.begin(), diff, 1);
   }
   for (size_t i = 0; i < repeat_times_data.size(); ++i) {
     if (x_dim_vec[i] == -1 || repeat_times_data[i] == -1) {
......
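The two TileInferMeta hunks change the padding value from -1 to 1 when aligning the ranks of x_dims and repeat_times. 1 is the identity for tiling (a size-1 dimension, repeated once), whereas -1 means "unknown at compile time", so padding with -1 wrongly marked every padded output dimension as dynamic. A standalone sketch of the resulting inference rule (a hypothetical helper mirroring the loop above):

#include <cstdio>
#include <vector>

// Mirrors TileInferMeta's core rule after the fix: align ranks by padding
// leading 1s, then out[i] = x[i] * repeat[i], with -1 (unknown) propagated.
std::vector<int> infer_tile_shape(std::vector<int> x_dims,
                                  std::vector<int> repeat) {
  if (x_dims.size() > repeat.size()) {
    repeat.insert(repeat.begin(), x_dims.size() - repeat.size(), 1);
  } else {
    x_dims.insert(x_dims.begin(), repeat.size() - x_dims.size(), 1);
  }
  std::vector<int> out(x_dims.size());
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = (x_dims[i] == -1 || repeat[i] == -1) ? -1 : x_dims[i] * repeat[i];
  }
  return out;
}

int main() {
  // x of shape [-1, 3] tiled with repeat_times = [2, 2, 2]:
  // x is padded to [1, -1, 3], giving an output shape of [2, -1, 6].
  auto out = infer_tile_shape({-1, 3}, {2, 2, 2});
  std::printf("%d %d %d\n", out[0], out[1], out[2]);  // prints: 2 -1 6
  return 0;
}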