未验证 提交 aec493c0 编写于 作者: T Thomas Young 提交者: GitHub

fix expand_v2 and expand_as_v2 bug (#38677)

上级 c48bd3ff
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddInput("Y",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"Expand X according to the shape of Y.")
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). "
......@@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
#endif
REGISTER_OP_VERSION(expand_as_v2)
.AddCheckpoint(
R"ROC(fix expand_as_v2 and add new input [Y])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"Y", "Expand X according to the shape of Y"));
\ No newline at end of file
......@@ -91,18 +91,35 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(target_shape[i], 0,
platform::errors::InvalidArgument(
"The value of target shape cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
target_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_as_v2 op.",
target_shape[i]));
repeat_times[i] = target_shape[i];
} else if (target_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in "
"target tensor for expand_as_v2 op.",
" the corresponding value (%d) in shape for expand_as_v2 op.",
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
target_shape[i], -1,
platform::errors::InvalidArgument(
"When the value in shape is negative for expand_as_v2 op, "
"only -1 is supported, but the value received is %d.",
target_shape[i]));
repeat_times[i] = 1;
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
......
......@@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel {
if (x_dims[i] == -1) {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
if (static_cast<size_t>(x_dims.size()) > i) {
out_shape[i] = x_dims[i];
} else {
out_shape[i] = -1;
}
} else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1;
......
......@@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None):
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input 'x'.")
inputs = {"X": [x]}
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper('expand_as', **locals())
dtype = helper.input_dtype(input_param_name='x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册