未验证 提交 aec493c0 编写于 作者: T Thomas Young 提交者: GitHub

fix expand_v2 and expand_as_v2 bug (#38677)

上级 c48bd3ff
...@@ -12,6 +12,7 @@ limitations under the License. */ ...@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h" #include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded."); "X is the input to be expanded.");
AddInput("Y",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"Expand X according to the shape of Y.")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). " "The rank of Output(Out) have the same with Input(X). "
...@@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>, ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>); ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
#endif #endif
// Register an op-version checkpoint for expand_as_v2. The checkpoint note
// records the fix together with the introduction of a new input "Y"; the
// description passed to NewInput() states that X is expanded according to
// the shape of Y.
REGISTER_OP_VERSION(expand_as_v2)
.AddCheckpoint(
R"ROC(fix expand_as_v2 and add new input [Y])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"Y", "Expand X according to the shape of Y"));
\ No newline at end of file
...@@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> { ...@@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(target_shape[i], 0, PADDLE_ENFORCE_NE(target_shape[i], 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The value of target shape cannot be zero.")); "The value of target shape cannot be zero."));
if (vec_in_dims[i] != 1) { if (i < diff) {
PADDLE_ENFORCE_GT(
target_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_as_v2 op.",
target_shape[i]));
repeat_times[i] = target_shape[i];
} else if (target_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_as_v2 op.",
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
} else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i], target_shape[i], -1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match" "When the value in shape is negative for expand_as_v2 op, "
" the corresponding value (%d) in " "only -1 is supported, but the value received is %d.",
"target tensor for expand_as_v2 op.", target_shape[i]));
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1; repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
} }
} }
auto* out0 = context.Output<Tensor>("Out"); auto* out0 = context.Output<Tensor>("Out");
......
...@@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel { ...@@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel {
if (x_dims[i] == -1) { if (x_dims[i] == -1) {
out_shape[i] = -1; out_shape[i] = -1;
} else if (expand_shape[i] == -1) { } else if (expand_shape[i] == -1) {
out_shape[i] = x_dims[i]; if (static_cast<size_t>(x_dims.size()) > i) {
out_shape[i] = x_dims[i];
} else {
out_shape[i] = -1;
}
} else if (expand_shape[i] == -2) { } else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var. // We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1; out_shape[i] = -1;
......
...@@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None): ...@@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None):
"you must set its stop_gradient to be False by " "you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting " "some_var.stop_gradient = True, supporting "
"some_var as the input 'x'.") "some_var as the input 'x'.")
inputs = {"X": [x]} inputs = {"X": [x], "Y": [y]}
helper = LayerHelper('expand_as', **locals()) helper = LayerHelper('expand_as', **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册