Unverified · Commit 13a22a37 · Authored by: Leo Chen · Committed by: GitHub

fix shape of tile_grad op (#29289)

Parent: be3777a5
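
Why the output shape can diverge from `x`'s: `tile` pads the shorter of the input shape and `repeat_times` with leading 1s (see the padding logic in the kernel diff below), so even an all-ones `repeat_times` can raise the rank of the output, and hence of `Out@GRAD`. Below is a minimal sketch of that shape rule, assuming the left-padding behavior shown in the diff; `TileOutDims` is a hypothetical helper for illustration, not part of Paddle:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Assumed shape rule for tile: left-pad the shorter of x_dims and
// repeat_times with 1s, then multiply element-wise. TileOutDims is a
// hypothetical helper for illustration only.
std::vector<int> TileOutDims(std::vector<int> x_dims,
                             std::vector<int> repeat_times) {
  size_t rank = std::max(x_dims.size(), repeat_times.size());
  x_dims.insert(x_dims.begin(), rank - x_dims.size(), 1);
  repeat_times.insert(repeat_times.begin(), rank - repeat_times.size(), 1);
  std::vector<int> out(rank);
  for (size_t i = 0; i < rank; ++i) out[i] = x_dims[i] * repeat_times[i];
  return out;
}

int main() {
  // x: [2, 3], repeat_times: [1, 1, 1] -> output (and dOut): [1, 2, 3].
  // All repeats are 1, so the grad kernel takes its "just copy" path,
  // but dX must still come out as [2, 3], not [1, 2, 3].
  for (int d : TileOutDims({2, 3}, {1, 1, 1})) std::printf("%d ", d);
  std::printf("\n");
}
```

With `x` of shape `[2, 3]` and `repeat_times = [1, 1, 1]`, the kernel takes the just-copy branch, yet `X@GRAD` must still have shape `[2, 3]`, which is what the `dx->Resize(x_dims)` added below guarantees.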
...
@@ -167,6 +167,7 @@ class TileGradOp : public framework::OperatorWithKernel {
                    framework::GradVarName("Out"), "TileGrad");
     auto x_dims = ctx->GetInputDim("X");
     std::vector<int> repeat_times =
         ctx->Attrs().Get<std::vector<int>>("repeat_times");
     if (repeat_times.size() == 0) {
......
...
@@ -186,9 +186,9 @@ template <typename DeviceContext, typename T>
 class TileGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in0 = context.Input<Tensor>("X");
+    auto* x = context.Input<Tensor>("X");
     auto repeat_times = get_repeat_times(context);
-    auto x_dims = in0->dims();
+    auto x_dims = x->dims();
     auto vec_in_dims = framework::vectorize<int>(x_dims);
     if (repeat_times.size() < vec_in_dims.size()) {
       int diff = vec_in_dims.size() - repeat_times.size();
...
@@ -220,11 +220,13 @@ class TileGradKernel : public framework::OpKernel<T> {
     }
     // no need reduce, just copy
     if (just_copy) {
-      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
-      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
-      out0->mutable_data<T>(context.GetPlace());
-      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
-                            out0);
+      auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+      dx->mutable_data<T>(context.GetPlace());
+      framework::TensorCopy(*dout, context.GetPlace(), context.device_context(),
+                            dx);
+      // TensorCopy may change the dims of dx
+      dx->Resize(x_dims);
     } else {
       PADDLE_ENFORCE_GE(dims, 1,
                         platform::errors::InvalidArgument(
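
The two added lines above are the core of the fix: `framework::TensorCopy` adopts the source tensor's dims (as the new comment notes), so after copying `dout` into `dx`, `dx` carries `dout`'s shape and must be resized back to `x_dims`. A toy model of that behavior follows; `Tensor` and `Copy` here are stand-ins for illustration, not Paddle's real types:

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins for the Paddle types; not the real API. Copy() mimics the
// behavior described in the diff: TensorCopy changes the dst's dims too.
struct Tensor {
  std::vector<int> dims;
  std::vector<float> data;
  void Resize(const std::vector<int>& d) { dims = d; }
};

void Copy(const Tensor& src, Tensor* dst) {
  dst->dims = src.dims;  // shape metadata follows the source
  dst->data = src.data;
}

int main() {
  Tensor dout{{1, 2, 3}, std::vector<float>(6, 1.f)};
  std::vector<int> x_dims{2, 3};
  Tensor dx;
  Copy(dout, &dx);
  assert(dx.dims == std::vector<int>({1, 2, 3}));  // wrong shape for dX
  dx.Resize(x_dims);  // the added line: restore x's shape after the copy
  assert(dx.dims == x_dims);
  return 0;
}
```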
...
@@ -261,6 +263,7 @@ class TileGradKernel : public framework::OpKernel<T> {
       for (size_t i = 0; i < reduce_size; ++i) {
         reduce_dims[i] = reduce_dims_vec[i];
       }
       auto out_grad = EigenVector<T>::Flatten(*in0);
       x_grad.device(
           *context.template device_context<DeviceContext>().eigen_device()) =
......
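
For the non-copy branch, the kernel folds the tiled copies of `Out@GRAD` back onto `x` by reshaping each dimension into a `(repeat, size)` pair and summing over the repeat axes; the tail of that Eigen expression is elided above. A plain-loop sketch of the 1-D case, intended as an equivalent in spirit rather than the kernel's actual code:

```cpp
#include <cstdio>
#include <vector>

// Plain-loop sketch of the 1-D tile gradient: tiling repeats the whole
// vector r times, so dX[j] accumulates dOut[k*n + j] over all r copies.
// This mirrors reshaping out_grad to (r, n) and summing over the r axis.
std::vector<float> TileGrad1D(const std::vector<float>& dout, int n) {
  int r = static_cast<int>(dout.size()) / n;  // repeat count
  std::vector<float> dx(n, 0.f);
  for (int k = 0; k < r; ++k)
    for (int j = 0; j < n; ++j) dx[j] += dout[k * n + j];
  return dx;
}

int main() {
  // x of length 3 tiled twice: dOut has 6 entries; dX[j] = dOut[j] + dOut[j+3].
  std::vector<float> dout{1, 2, 3, 4, 5, 6};
  for (float v : TileGrad1D(dout, 3)) std::printf("%g ", v);  // 5 7 9
  std::printf("\n");
}
```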