未验证 提交 75923a32 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Update all forward argument maping fns (#39252)

* update forward argument mapping

* fix compile failed

* fix test failed
上级 9a001c09
...@@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS; return var_types[0] == proto::VarType::SELECTED_ROWS;
} }
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR;
}
bool IsSelectedRowsOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::SELECTED_ROWS;
}
private: private:
const InferShapeContext& ctx_; const InferShapeContext& ctx_;
}; };
......
...@@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
return ctx_.InputSize(name); return ctx_.MultiInputVar(name).size();
} }
size_t OutputSize(const std::string& name) const override { size_t OutputSize(const std::string& name) const override {
return ctx_.OutputSize(name); return ctx_.MultiOutputVar(name).size();
} }
bool IsDenseTensorInput(const std::string& name) const override { bool IsDenseTensorInput(const std::string& name) const override {
...@@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.InputVar(name)->IsType<pten::SelectedRows>(); return ctx_.InputVar(name)->IsType<pten::SelectedRows>();
} }
bool IsDenseTensorOutput(const std::string& name) const override {
return ctx_.OutputVar(name)->IsType<framework::LoDTensor>();
}
bool IsSelectedRowsOutput(const std::string& name) const override {
return ctx_.OutputVar(name)->IsType<pten::SelectedRows>();
}
private: private:
const ExecutionContext& ctx_; const ExecutionContext& ctx_;
}; };
......
...@@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel {
#endif #endif
return framework::OpKernelType(tensor->type(), tensor_place); return framework::OpKernelType(tensor->type(), tensor_place);
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"});
}
}; };
} // namespace operators } // namespace operators
......
...@@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasInput("AxisTensor")) {
return framework::KernelSignature("concat", {"X"}, {"AxisTensor"},
{"Out"});
}
return framework::KernelSignature("concat", {"X"}, {"axis"}, {"Out"});
}
}; };
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
int axis = ctx.Attr<int>("axis");
if (Type() == "elementwise_add") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_sub") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("subtract", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_div") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_mul") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("multiply", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
}; };
class ElementwiseOpInferVarType class ElementwiseOpInferVarType
......
...@@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel { ...@@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel {
framework::proto::VarType::Type(context.Attr<int>("dtype")), framework::proto::VarType::Type(context.Attr<int>("dtype")),
context.GetPlace()); context.GetPlace());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
std::string shape;
if (ctx.HasInput("ShapeTensor")) {
shape = "ShapeTensor";
} else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
shape = "ShapeTensorList";
} else {
shape = "shape";
}
return framework::KernelSignature("empty", {}, {shape}, {"Out"});
}
}; };
class EmptyOpVarTypeInference : public framework::VarTypeInference { class EmptyOpVarTypeInference : public framework::VarTypeInference {
......
...@@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { ...@@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
expected_kernel_type.place_, expected_kernel_type.place_,
tensor.layout()); tensor.layout());
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("full_like", {}, {"value"}, {"Out"});
}
}; };
class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker { class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
return kt; return kt;
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
std::string shape;
if (ctx.HasInput("ShapeTensor")) {
shape = "ShapeTensor";
} else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
shape = "ShapeTensorList";
} else {
shape = "shape";
}
std::string value;
if (ctx.HasInput("ValueTensor")) {
value = "ValueTensor";
} else {
const auto& str_value = ctx.Attr<std::string>("str_value");
value = str_value.empty() ? "value" : "str_value";
}
if (!ctx.OutputVar("Out")->IsType<pten::SelectedRows>()) {
return framework::KernelSignature("full", {}, {shape, value}, {"Out"});
}
return framework::KernelSignature("fill_constant.unregistered", {}, {}, {});
}
}; };
class FillConstantOpVarTypeInference : public framework::VarTypeInference { class FillConstantOpVarTypeInference : public framework::VarTypeInference {
......
...@@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { ...@@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
return out_shape; return out_shape;
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasOutput("XShape")) {
return framework::KernelSignature("flatten_with_xshape", {"X"},
{"start_axis", "stop_axis"},
{"Out", "XShape"});
} else {
return framework::KernelSignature("flatten", {"X"},
{"start_axis", "stop_axis"}, {"Out"});
}
}
}; };
class FlattenContiguousRangeOpMaker : public FlattenOpMaker { class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
......
...@@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp { ...@@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp {
ReshapeOp::InferShape(ctx); ReshapeOp::InferShape(ctx);
} }
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
std::string shape;
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
shape = "ShapeTensor";
} else if (ctx.HasInput("Shape")) {
shape = "Shape";
} else {
shape = "shape";
}
return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"});
}
}; };
class Reshape2OpMaker : public ReshapeOpMaker { class Reshape2OpMaker : public ReshapeOpMaker {
......
...@@ -75,6 +75,9 @@ class ArgumentMappingContext { ...@@ -75,6 +75,9 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInput(const std::string& name) const = 0; virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
}; };
} // namespace pten } // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(cast, pten::CastOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"});
}
return KernelSignature("concat", {"X"}, {"axis"}, {"Out"});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(concat, pten::ConcatOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature ElementwiseAddOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseSubOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseMulOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("multiply", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseDivOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
pten::ElementwiseAddOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_sub,
pten::ElementwiseSubOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
pten::ElementwiseMulOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
pten::ElementwiseDivOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("ShapeTensor")) {
return KernelSignature("empty", {}, {"ShapeTensor"}, {"Out"});
} else if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("empty", {}, {"ShapeTensorList"}, {"Out"});
} else {
return KernelSignature("empty", {}, {"shape"}, {"Out"});
}
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(empty, pten::EmptyOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature FillAnyLikeOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("full_like", {}, {"value"}, {"Out"});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
// we have to return every specific KernelSignature for infrt now
KernelSignature FillConstantOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput("Out")) {
if (ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full", {}, {"ShapeTensor", "ValueTensor"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"});
} else {
return KernelSignature(
"full", {}, {"ShapeTensor", "str_value"}, {"Out"});
}
}
} else if (ctx.InputSize("ShapeTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full", {}, {"ShapeTensorList", "value"}, {"Out"});
} else {
return KernelSignature(
"full", {}, {"ShapeTensorList", "str_value"}, {"Out"});
}
}
} else {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("full", {}, {"shape", "ValueTensor"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature("full", {}, {"shape", "value"}, {"Out"});
} else {
return KernelSignature("full", {}, {"shape", "str_value"}, {"Out"});
}
}
}
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasOutput("XShape")) {
return KernelSignature("flatten_with_xshape",
{"X"},
{"start_axis", "stop_axis"},
{"Out", "XShape"});
} else {
return KernelSignature(
"flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"});
}
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
pten::FlattenOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) {
return KernelSignature(
"sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
}
return KernelSignature("sum_raw",
{"X"},
{"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) {
return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
} else if (ctx.HasInput("Shape")) {
return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
} else {
return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
}
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册