Unverified · Commit 8cd8cd53 authored by Leo Chen, committed by GitHub

fix shape of tile_grad op (#29289) (#29324)

Parent ec57656e
@@ -167,6 +167,7 @@ class TileGradOp : public framework::OperatorWithKernel {
framework::GradVarName("Out"), "TileGrad");
auto x_dims = ctx->GetInputDim("X");
std::vector<int> repeat_times =
ctx->Attrs().Get<std::vector<int>>("repeat_times");
if (repeat_times.size() == 0) {
......
@@ -186,9 +186,9 @@ template <typename DeviceContext, typename T>
 class TileGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in0 = context.Input<Tensor>("X");
+    auto* x = context.Input<Tensor>("X");
     auto repeat_times = get_repeat_times(context);
-    auto x_dims = in0->dims();
+    auto x_dims = x->dims();
     auto vec_in_dims = framework::vectorize<int>(x_dims);
     if (repeat_times.size() < vec_in_dims.size()) {
       int diff = vec_in_dims.size() - repeat_times.size();
@@ -220,11 +220,13 @@ class TileGradKernel : public framework::OpKernel<T> {
     }
     // no need reduce, just copy
     if (just_copy) {
-      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
-      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
-      out0->mutable_data<T>(context.GetPlace());
-      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
-                            out0);
+      auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+      dx->mutable_data<T>(context.GetPlace());
+      framework::TensorCopy(*dout, context.GetPlace(), context.device_context(),
+                            dx);
+      // TensorCopy may change the dims of dx
+      dx->Resize(x_dims);
     } else {
       PADDLE_ENFORCE_GE(dims, 1,
                         platform::errors::InvalidArgument(
@@ -261,6 +263,7 @@ class TileGradKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
......
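The copy path in the hunk above fires when no reduction over repeated axes is needed, for example when every entry of repeat_times is 1. If repeat_times has more entries than the rank of X, the forward output (and hence Out@GRAD) is promoted with leading dimensions of size 1, so a bare TensorCopy would leave X@GRAD with that promoted shape; the added dx->Resize(x_dims) restores X's own shape. Below is a minimal sketch of the scenario, assuming Paddle's 2.0 dygraph Python API; the pre-fix behaviour described in the comments is inferred from the diff, not quoted from the PR.

```python
import paddle

# x has rank 1; repeat_times has two entries, so the tiled output is
# promoted to rank 2 with shape [1, 3] even though nothing is repeated.
x = paddle.rand([3])
x.stop_gradient = False

y = paddle.tile(x, repeat_times=[1, 1])
y.sum().backward()

# The gradient w.r.t. x must match x's own shape (a single dimension of
# size 3). Inferred from the diff: before the fix, the just-copy path
# could leave x.grad with the promoted shape [1, 3]; the added
# dx->Resize(x_dims) keeps it consistent with x.
print(x.grad.shape)
```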