未验证 提交 384062fa 编写于 作者: Z zyfncg 提交者: GitHub

Add intermediate config for some api in yaml (#42824)

* add intermediate for some api

* fix bug

* fix fluid.layer
上级 efaaf239
......@@ -2163,8 +2163,8 @@ void SplitInferMeta(const MetaTensor& x,
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* xshape,
MetaTensor* out) {
MetaTensor* out,
MetaTensor* xshape) {
const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(),
......@@ -2964,8 +2964,8 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* xshape,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
......
......@@ -306,8 +306,8 @@ void SplitInferMeta(const MetaTensor& x_meta,
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* xshape,
MetaTensor* out);
MetaTensor* out,
MetaTensor* xshape);
void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
......@@ -425,8 +425,8 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes,
MetaTensor* xshape,
MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig());
void UnStackInferMeta(const MetaTensor& x,
......
......@@ -22,8 +22,8 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* xshape,
DenseTensor* out) {
DenseTensor* out,
DenseTensor* xshape) {
auto x_dims = x.dims();
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);
......
......@@ -22,8 +22,8 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* xshape,
DenseTensor* out) {
DenseTensor* out,
DenseTensor* xshape) {
auto x_dims = x.dims();
auto out_dims = out->dims();
if (axes.FromTensor()) {
......
......@@ -23,6 +23,6 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* xshape,
DenseTensor* out);
DenseTensor* out,
DenseTensor* xshape);
} // namespace phi
......@@ -24,6 +24,6 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* xshape,
DenseTensor* out);
DenseTensor* out,
DenseTensor* xshape);
} // namespace phi
......@@ -18,7 +18,7 @@
namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"XShape", "Out"});
return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out", "XShape"});
}
KernelSignature SqueezeGradOpArgumentMapping(
......
......@@ -21,14 +21,14 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensorList"}, {"XShape", "Out"});
"unsqueeze", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensor"}, {"XShape", "Out"});
"unsqueeze", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"XShape", "Out"});
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out", "XShape"});
}
}
......
......@@ -6533,7 +6533,7 @@ def squeeze(input, axes, name=None):
"""
if in_dygraph_mode():
return _C_ops.final_state_squeeze(input, axes)[1]
return _C_ops.final_state_squeeze(input, axes)
if _in_legacy_dygraph():
out, _ = _C_ops.squeeze2(input, 'axes', axes)
return out
......@@ -6598,7 +6598,7 @@ def unsqueeze(input, axes, name=None):
if _in_legacy_dygraph():
out, _ = _C_ops.unsqueeze2(input, 'axes', axes)
return out
return _C_ops.final_state_unsqueeze(input, axes)[1]
return _C_ops.final_state_unsqueeze(input, axes)
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
check_variable_and_dtype(input, 'input', [
......
......@@ -1427,8 +1427,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
raise ValueError("The stop_axis should be larger than stat_axis")
if in_dygraph_mode():
dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis)
return dy_out
return _C_ops.final_state_flatten(x, start_axis, stop_axis)
if _in_legacy_dygraph():
dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
......@@ -1936,7 +1935,7 @@ def squeeze(x, axis=None, name=None):
input = x
axes = axis
if in_dygraph_mode():
return _C_ops.final_state_squeeze(input, axes)[1]
return _C_ops.final_state_squeeze(input, axes)
if _in_legacy_dygraph():
out, _ = _C_ops.squeeze2(input, 'axes', axes)
return out
......@@ -2271,7 +2270,7 @@ def unsqueeze(x, axis, name=None):
if _in_legacy_dygraph():
out, _ = _C_ops.unsqueeze2(input, 'axes', axes)
return out
return _C_ops.final_state_unsqueeze(input, axes)[1]
return _C_ops.final_state_unsqueeze(input, axes)
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
check_variable_and_dtype(input, 'input', [
......
......@@ -714,7 +714,7 @@
backend : x
inplace : (x -> out)
view : (x -> out)
# intermediate : xshape
intermediate : xshape
backward : flatten_grad
# flip
......@@ -1984,12 +1984,13 @@
- api : squeeze
args : (Tensor x, int[] axes)
output : Tensor(xshape), Tensor(out)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeInferMeta
kernel :
func : squeeze
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
- api : stack
......@@ -2213,12 +2214,13 @@
- api : unsqueeze
args : (Tensor x, IntArray axis)
output : Tensor(xshape), Tensor(out)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeInferMeta
kernel :
func : unsqueeze
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
# viterbi_decode
......
......@@ -1750,7 +1750,7 @@
func : square_grad
- backward_api : squeeze_grad
forward : squeeze(Tensor x, int[] axes) -> Tensor(xshape), Tensor(out)
forward : squeeze(Tensor x, int[] axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, int[] axes)
output : Tensor(x_grad)
infer_meta :
......@@ -2021,7 +2021,7 @@
no_need_buffer : x
- backward_api : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(xshape), Tensor(out)
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册