diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc old mode 100644 new mode 100755 index 5296a144f6247db18fc866febac39779d4a317b3..cc293a5aaa0b278dbf857cedc8b9e074f641f0e4 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -12,6 +12,7 @@ limitations under the License. */ #include "paddle/fluid/operators/expand_as_v2_op.h" #include #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "X is the input to be expanded."); + AddInput("Y", + "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "Expand X according to the shape of Y.") + .AsDispensable(); AddOutput("Out", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "The rank of Output(Out) have the same with Input(X). " @@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ExpandAsV2GradKernel, ops::ExpandAsV2GradKernel); #endif + +REGISTER_OP_VERSION(expand_as_v2) + .AddCheckpoint( + R"ROC(fix expand_as_v2 and add new input [Y])ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "Y", "Expand X according to the shape of Y")); \ No newline at end of file diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h old mode 100644 new mode 100755 index 3e8f7d15880bcd16ed040637f3e80c43b4d287b7..9e683a792c61f96de6174aea5e4be60f1ee87257 --- a/paddle/fluid/operators/expand_as_v2_op.h +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel { PADDLE_ENFORCE_NE(target_shape[i], 0, platform::errors::InvalidArgument( "The value of target shape cannot be zero.")); - if (vec_in_dims[i] != 1) { + if (i < diff) { + PADDLE_ENFORCE_GT( + target_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_as_v2 op.", + target_shape[i])); + repeat_times[i] = target_shape[i]; + } else if (target_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], target_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_as_v2 op.", + vec_in_dims[i], target_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = target_shape[i]; + } + } else { PADDLE_ENFORCE_EQ( - vec_in_dims[i], target_shape[i], + target_shape[i], -1, platform::errors::InvalidArgument( - "The value (%d) of the non-singleton dimension does not match" - " the corresponding value (%d) in " - "target tensor for expand_as_v2 op.", - vec_in_dims[i], target_shape[i])); + "When the value in shape is negative for expand_as_v2 op, " + "only -1 is supported, but the value received is %d.", + target_shape[i])); repeat_times[i] = 1; - } else { - repeat_times[i] = target_shape[i]; } } auto* out0 = context.Output("Out"); diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc old mode 100644 new mode 100755 index dc6da979671e598605f6904b1b26602a5f44071a..6d803c500d90f9464746ce4879713a57a5855984 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel { if (x_dims[i] == -1) { out_shape[i] = -1; } else if (expand_shape[i] == -1) { - out_shape[i] = x_dims[i]; + if (static_cast(x_dims.size()) > i) { + out_shape[i] = x_dims[i]; + } else { + out_shape[i] = -1; + } } else if (expand_shape[i] == -2) { // We use -2 to represent the element in expand_shape is a var. out_shape[i] = -1; diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py old mode 100644 new mode 100755 index b54c3596a26a9827e325463cd99448ceeab5f4ec..a15c1af391f9f25cb95fd9d9465818f46e78fbbf --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None): "you must set its stop_gradient to be False by " "some_var.stop_gradient = True, supporting " "some_var as the input 'x'.") - inputs = {"X": [x]} + inputs = {"X": [x], "Y": [y]} helper = LayerHelper('expand_as', **locals()) dtype = helper.input_dtype(input_param_name='x')