Unverified commit 667bd962, authored by Chen Weihang, committed by GitHub

[PTen] Move grad GetExpectedPtenKernelArgs into pten (#39418)

* move grad get expected pten kernel args

* fix reduce sum error

* fix elementwise_sub_grad failure

* revert kernel judge change
Parent 22c67d14
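In short, this change deletes each fluid operator's virtual GetExpectedPtenKernelArgs override and registers an equivalent argument-mapping function on the pten side instead. A minimal sketch of that pattern follows, using a hypothetical my_op_grad operator; the KernelSignature API, GradVarName helper, and PT_REGISTER_ARG_MAPPING_FN macro are the ones used by the real mappings in the diff below.

// Sketch only: "my_op_grad" and its argument names are hypothetical;
// see the real mappings (digamma_grad, dot_grad, add_grad, ...) below.
#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

// A KernelSignature names the pten kernel and lists the fluid op's
// inputs, attributes, and outputs in the order the kernel expects.
KernelSignature MyOpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "my_op_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
}

}  // namespace pten

PT_REGISTER_ARG_MAPPING_FN(my_op_grad, pten::MyOpGradOpArgumentMapping);

Because the mapping lives in a standalone registered function, selecting a pten kernel no longer requires per-op virtual dispatch on OperatorWithKernel, which is what the first two hunks below simplify.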
@@ -1171,8 +1171,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   std::string pt_kernel_name;
   if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
     if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
-      pt_kernel_signature_.reset(new KernelSignature(
-          std::move(this->GetExpectedPtenKernelArgs(exe_ctx))));
+      pt_kernel_signature_.reset(
+          new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx))));
       VLOG(6) << *pt_kernel_signature_.get();
       kernel_type_.reset(
@@ -1359,7 +1359,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
 pten::KernelKey OperatorWithKernel::ChoosePtenKernel(
     const ExecutionContext& ctx) const {
   pt_kernel_signature_.reset(
-      new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
+      new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx))));
   VLOG(6) << *pt_kernel_signature_.get();
   kernel_type_.reset(
...
@@ -606,7 +606,7 @@ class OperatorWithKernel : public OperatorBase {
   * When selecting Kernel during Op execution, select the arguments of the
   * original Op according to the GetExpectedPtenKernelArgs returned arguments.
   */
-  virtual pten::KernelSignature GetExpectedPtenKernelArgs(
+  pten::KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;

   /* member functions for adapting to pten lib */
...
@@ -64,13 +64,6 @@ class DigammaGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
     ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("digamma_grad",
-                                      {framework::GradVarName("Out"), "X"}, {},
-                                      {framework::GradVarName("X")});
-  }
 };

 template <typename T>
...
@@ -117,13 +117,6 @@ class DotGradOp : public framework::OperatorWithKernel {
                                        ctx, framework::GradVarName("Out")),
                                    ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };

 template <typename T>
...
@@ -353,18 +353,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                                    tensor.place(), tensor.layout());
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    if (Type() == "elementwise_add_grad") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature(
-            "add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"},
-            {framework::GradVarName("X"), framework::GradVarName("Y")});
-      }
-    }
-    return framework::KernelSignature("None", {"X"}, {}, {"Out"});
-  }
 };

 class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
...
@@ -421,12 +421,6 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                                        ctx, framework::GradVarName("Out")),
                                    ctx.device_context());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("flatten_grad",
-                                      {framework::GradVarName("Out"), "XShape"},
-                                      {}, {framework::GradVarName("X")});
-  }
 };

 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
...
@@ -389,14 +389,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
                                    tensor.place(), tensor.layout());
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_grad", {"X", "Y", framework::GradVarName("Out")},
-        {"trans_x", "trans_y"},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };

 template <typename T>
@@ -439,13 +431,6 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
       context->ShareDim("DOut", "DDOut");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
-        {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
-  }
 };

 template <typename T>
@@ -515,15 +500,6 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
       context->ShareDim("Y", "D_DDY_out");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_triple_grad",
-        {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
-        {"trans_x", "trans_y"},
-        {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
-  }
 };

 template <typename T>
...
@@ -547,34 +547,6 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    bool reduce_all = ctx.Attr<bool>("reduce_all");
-    if (Type() == "reduce_sum") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature(
-              "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
-        }
-        return framework::KernelSignature(
-            "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
-            {"Out"});
-      }
-    }
-    if (Type() == "reduce_mean") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
-                                            {"Out"});
-        }
-        return framework::KernelSignature(
-            "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
-      }
-    }
-    // TODO(chentianyu03): support other cases after selected rows added
-    return framework::KernelSignature("reduce.unregistered", {}, {}, {});
-  }
 };

 class ReduceOpUseInputPlace : public ReduceOp {
...
@@ -579,13 +579,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_grad",
-                                      {framework::GradVarName("Out")}, {},
-                                      {framework::GradVarName("X")});
-  }
 };

 class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
@@ -622,11 +615,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_double_grad", {"DDX"}, {},
-                                      {"DDOut"});
-  }
 };

 DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
...
New file (digamma grad argument mapping):

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature DigammaGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")});
}

}  // namespace pten

PT_REGISTER_ARG_MAPPING_FN(digamma_grad, pten::DigammaGradOpArgumentMapping);
New file (dot grad argument mapping):

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("dot_grad",
                         {"X", "Y", GradVarName("Out")},
                         {},
                         {GradVarName("X"), GradVarName("Y")});
}

}  // namespace pten

PT_REGISTER_ARG_MAPPING_FN(dot_grad, pten::DotGradOpArgumentMapping);
@@ -64,6 +64,17 @@ KernelSignature ElementwiseDivOpArgumentMapping(
   return KernelSignature("unregistered", {}, {}, {});
 }

+KernelSignature ElementwiseAddGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorInput("X")) {
+    return KernelSignature("add_grad",
+                           {"X", "Y", GradVarName("Out")},
+                           {"axis"},
+                           {GradVarName("X"), GradVarName("Y")});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
@@ -71,7 +82,6 @@ PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
-PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);

 PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
                            pten::ElementwiseAddOpArgumentMapping);
@@ -81,3 +91,5 @@ PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                            pten::ElementwiseMulOpArgumentMapping);
 PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
                            pten::ElementwiseDivOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
+                           pten::ElementwiseAddGradOpArgumentMapping);
@@ -28,6 +28,12 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }

+KernelSignature FlattenGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
@@ -35,3 +41,5 @@ PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);
 PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
                            pten::FlattenOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range_grad,
+                           pten::FlattenGradOpArgumentMapping);
@@ -14,9 +14,41 @@ limitations under the License. */
 #include "paddle/pten/core/compat/op_utils.h"

-namespace pten {}  // namespace pten
+namespace pten {
+
+KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"trans_x", "trans_y"},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+KernelSignature MatmulDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_double_grad",
+                         {"X", "Y", "DOut", "DDX", "DDY"},
+                         {"trans_x", "trans_y"},
+                         {"DX", "DY", "DDOut"});
+}
+
+KernelSignature MatmulTripleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "matmul_triple_grad",
+      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
+      {"trans_x", "trans_y"},
+      {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
+}
+
+}  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);
+
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, pten::MatmulGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad,
+                           pten::MatmulDoubleGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad,
+                           pten::MatmulTripleGradOpArgumentMapping);
@@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorInput("X")) {
     if (!reduce_all) {
       return KernelSignature(
-          "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+          "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
     }
     return KernelSignature("sum_raw",
                            {"X"},
...
@@ -26,6 +26,17 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }

+KernelSignature ReshapeGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
+}
+
+KernelSignature ReshapeDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"});
+}
+
 }  // namespace pten

 PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
@@ -33,3 +44,6 @@ PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
 PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);

 PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad, pten::ReshapeGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad_grad,
+                           pten::ReshapeDoubleGradOpArgumentMapping);