未验证 提交 384062fa 编写于 作者: Z zyfncg 提交者: GitHub

Add intermediate config for some api in yaml (#42824)

* add intermediate for some api

* fix bug

* fix fluid.layer
上级 efaaf239
...@@ -2163,8 +2163,8 @@ void SplitInferMeta(const MetaTensor& x, ...@@ -2163,8 +2163,8 @@ void SplitInferMeta(const MetaTensor& x,
void SqueezeInferMeta(const MetaTensor& x, void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
MetaTensor* xshape, MetaTensor* out,
MetaTensor* out) { MetaTensor* xshape) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit. // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), PADDLE_ENFORCE_LE(x_dims.size(),
...@@ -2964,8 +2964,8 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -2964,8 +2964,8 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes, const IntArray& axes,
MetaTensor* xshape,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config) { MetaConfig config) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
......
...@@ -306,8 +306,8 @@ void SplitInferMeta(const MetaTensor& x_meta, ...@@ -306,8 +306,8 @@ void SplitInferMeta(const MetaTensor& x_meta,
void SqueezeInferMeta(const MetaTensor& x, void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
MetaTensor* xshape, MetaTensor* out,
MetaTensor* out); MetaTensor* xshape);
void StridedSliceRawInferMeta(const MetaTensor& x, void StridedSliceRawInferMeta(const MetaTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
...@@ -425,8 +425,8 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -425,8 +425,8 @@ void UniqueRawInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const IntArray& axes, const IntArray& axes,
MetaTensor* xshape,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void UnStackInferMeta(const MetaTensor& x, void UnStackInferMeta(const MetaTensor& x,
......
...@@ -22,8 +22,8 @@ template <typename T, typename Context> ...@@ -22,8 +22,8 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx, void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out) { DenseTensor* xshape) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true); auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);
......
...@@ -22,8 +22,8 @@ template <typename T, typename Context> ...@@ -22,8 +22,8 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx, void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& axes, const IntArray& axes,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out) { DenseTensor* xshape) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto out_dims = out->dims(); auto out_dims = out->dims();
if (axes.FromTensor()) { if (axes.FromTensor()) {
......
...@@ -23,6 +23,6 @@ template <typename T, typename Context> ...@@ -23,6 +23,6 @@ template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx, void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int>& axes, const std::vector<int>& axes,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out); DenseTensor* xshape);
} // namespace phi } // namespace phi
...@@ -24,6 +24,6 @@ template <typename T, typename Context> ...@@ -24,6 +24,6 @@ template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx, void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& axes, const IntArray& axes,
DenseTensor* xshape, DenseTensor* out,
DenseTensor* out); DenseTensor* xshape);
} // namespace phi } // namespace phi
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace phi { namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"XShape", "Out"}); return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out", "XShape"});
} }
KernelSignature SqueezeGradOpArgumentMapping( KernelSignature SqueezeGradOpArgumentMapping(
......
...@@ -21,14 +21,14 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,14 +21,14 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) { if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList"; VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature( return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensorList"}, {"XShape", "Out"}); "unsqueeze", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) { } else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor"; VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature( return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensor"}, {"XShape", "Out"}); "unsqueeze", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else { } else {
VLOG(2) << "unsqueeze2 in axes"; VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"XShape", "Out"}); return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out", "XShape"});
} }
} }
......
...@@ -6533,7 +6533,7 @@ def squeeze(input, axes, name=None): ...@@ -6533,7 +6533,7 @@ def squeeze(input, axes, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_squeeze(input, axes)[1] return _C_ops.final_state_squeeze(input, axes)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out, _ = _C_ops.squeeze2(input, 'axes', axes) out, _ = _C_ops.squeeze2(input, 'axes', axes)
return out return out
...@@ -6598,7 +6598,7 @@ def unsqueeze(input, axes, name=None): ...@@ -6598,7 +6598,7 @@ def unsqueeze(input, axes, name=None):
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out, _ = _C_ops.unsqueeze2(input, 'axes', axes) out, _ = _C_ops.unsqueeze2(input, 'axes', axes)
return out return out
return _C_ops.final_state_unsqueeze(input, axes)[1] return _C_ops.final_state_unsqueeze(input, axes)
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze') check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
check_variable_and_dtype(input, 'input', [ check_variable_and_dtype(input, 'input', [
......
...@@ -1427,8 +1427,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): ...@@ -1427,8 +1427,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
raise ValueError("The stop_axis should be larger than stat_axis") raise ValueError("The stop_axis should be larger than stat_axis")
if in_dygraph_mode(): if in_dygraph_mode():
dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis) return _C_ops.final_state_flatten(x, start_axis, stop_axis)
return dy_out
if _in_legacy_dygraph(): if _in_legacy_dygraph():
dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis, dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
...@@ -1936,7 +1935,7 @@ def squeeze(x, axis=None, name=None): ...@@ -1936,7 +1935,7 @@ def squeeze(x, axis=None, name=None):
input = x input = x
axes = axis axes = axis
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_squeeze(input, axes)[1] return _C_ops.final_state_squeeze(input, axes)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out, _ = _C_ops.squeeze2(input, 'axes', axes) out, _ = _C_ops.squeeze2(input, 'axes', axes)
return out return out
...@@ -2271,7 +2270,7 @@ def unsqueeze(x, axis, name=None): ...@@ -2271,7 +2270,7 @@ def unsqueeze(x, axis, name=None):
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out, _ = _C_ops.unsqueeze2(input, 'axes', axes) out, _ = _C_ops.unsqueeze2(input, 'axes', axes)
return out return out
return _C_ops.final_state_unsqueeze(input, axes)[1] return _C_ops.final_state_unsqueeze(input, axes)
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze') check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
check_variable_and_dtype(input, 'input', [ check_variable_and_dtype(input, 'input', [
......
...@@ -714,7 +714,7 @@ ...@@ -714,7 +714,7 @@
backend : x backend : x
inplace : (x -> out) inplace : (x -> out)
view : (x -> out) view : (x -> out)
# intermediate : xshape intermediate : xshape
backward : flatten_grad backward : flatten_grad
# flip # flip
...@@ -1984,12 +1984,13 @@ ...@@ -1984,12 +1984,13 @@
- api : squeeze - api : squeeze
args : (Tensor x, int[] axes) args : (Tensor x, int[] axes)
output : Tensor(xshape), Tensor(out) output : Tensor(out), Tensor(xshape)
infer_meta : infer_meta :
func : SqueezeInferMeta func : SqueezeInferMeta
kernel : kernel :
func : squeeze func : squeeze
view: (x -> out) view: (x -> out)
intermediate : xshape
backward : squeeze_grad backward : squeeze_grad
- api : stack - api : stack
...@@ -2213,12 +2214,13 @@ ...@@ -2213,12 +2214,13 @@
- api : unsqueeze - api : unsqueeze
args : (Tensor x, IntArray axis) args : (Tensor x, IntArray axis)
output : Tensor(xshape), Tensor(out) output : Tensor(out), Tensor(xshape)
infer_meta : infer_meta :
func : UnsqueezeInferMeta func : UnsqueezeInferMeta
kernel : kernel :
func : unsqueeze func : unsqueeze
view: (x -> out) view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad backward : unsqueeze_grad
# viterbi_decode # viterbi_decode
......
...@@ -1750,7 +1750,7 @@ ...@@ -1750,7 +1750,7 @@
func : square_grad func : square_grad
- backward_api : squeeze_grad - backward_api : squeeze_grad
forward : squeeze(Tensor x, int[] axes) -> Tensor(xshape), Tensor(out) forward : squeeze(Tensor x, int[] axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, int[] axes) args : (Tensor xshape, Tensor out_grad, int[] axes)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
...@@ -2021,7 +2021,7 @@ ...@@ -2021,7 +2021,7 @@
no_need_buffer : x no_need_buffer : x
- backward_api : unsqueeze_grad - backward_api : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(xshape), Tensor(out) forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad) args : (Tensor xshape, Tensor out_grad)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册