diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc
index cbd58592ef5617f732c1a8db6a3b206f6c56a9b7..b8aedcce3e3fa5931331747bd3faa2d0e758fd84 100644
--- a/paddle/fluid/framework/pten_utils.cc
+++ b/paddle/fluid/framework/pten_utils.cc
@@ -98,7 +98,8 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
     for (const auto& pair : OpInfoMap::Instance().map()) {
       const auto& op_type = pair.first;
       const auto* op_proto = pair.second.proto_;
-      if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
+      if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
+          op_proto != nullptr) {
         KernelArgsNameMakerByOpProto maker(op_proto);
         VLOG(10) << "Register kernel signature for " << op_type;
         auto success = kernel_signature_map_->map_
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index a1b8dd6bae4945a7f1ea934792e8886278512b26..6b1ee00b55d62aa0ab3a5093aec329f7fcb10fd1 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -431,6 +431,12 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                                        ctx, framework::GradVarName("Out")),
                                    ctx.device_context());
   }
+  framework::KernelSignature GetExpectedPtenKernelArgs(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::KernelSignature("flatten_grad",
+                                      {framework::GradVarName("Out"), "XShape"},
+                                      {}, {framework::GradVarName("X")});
+  }
 };
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h
index fa116d9516ecdb6d41c9a0097af7b47ee910ed8c..ef42619bfe4ff651d13c18c8fbfc203929c356e1 100644
--- a/paddle/fluid/operators/flatten_op.h
+++ b/paddle/fluid/operators/flatten_op.h
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/pten/include/core.h"
+#include "paddle/pten/kernels/empty_kernel.h"
+#include "paddle/pten/kernels/flatten_grad_kernel.h"
 #include "paddle/pten/kernels/flatten_kernel.h"
 
 namespace paddle {
@@ -146,15 +148,25 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
     auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
     auto *d_out =
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-
-    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
-    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+    auto *xshape = ctx.Input<framework::LoDTensor>("XShape");
 
     d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    framework::TensorCopy(
-        *d_out, ctx.GetPlace(),
-        ctx.template device_context<platform::DeviceContext>(), d_x);
-    d_x->Resize(x_dims);
+    auto &dev_ctx = ctx.device_context<DeviceContext>();
+
+    auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
+    auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out);
+
+    // Because the holder of xshape may be nullptr, we can't use
+    // MakePtenDenseTensor.
+    // So, we create a new DenseTensor to save the dims of xshape.
+    pten::DenseTensorMeta xshape_meta{pten::TransToPtenDataType(d_x->type()),
+                                      xshape->dims(), d_x->layout()};
+    auto pt_xshape =
+        pten::Empty<T, DeviceContext>(dev_ctx, std::move(xshape_meta));
+
+    // call new kernel
+    pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *pt_d_out.get(),
+                                              pt_xshape, pt_d_x.get());
   }
 };
 
diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h
index 3b8347dec772e93cfe533f4263c3937979025878..56f7eea7ea802dd94d4c5aecf82732dae27d3b8b 100644
--- a/paddle/pten/core/kernel_alias_name.h
+++ b/paddle/pten/core/kernel_alias_name.h
@@ -27,12 +27,14 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
     {"fill_any_like", "full_like"},
     {"fill_constant", "full"},
     {"flatten_contiguous_range", "flatten"},
+    {"flatten_contiguous_range_grad", "flatten_grad"},
     {"matmul_v2", "matmul"},
     {"reduce_mean", "mean"},
     {"reduce_sum", "sum"},
     {"reshape2", "reshape"},
     // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
     {"flatten", "deprecated"},
+    {"flatten_grad", "deprecated"},
     {"matmul", "deprecated"},
     {"mean", "deprecated"},
     {"reshape", "deprecated"},
diff --git a/paddle/pten/kernels/flatten_grad_kernel.cc b/paddle/pten/kernels/flatten_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d6aea31748d6cf7bbcea9e5c839fcadbe67c9b05
--- /dev/null
+++ b/paddle/pten/kernels/flatten_grad_kernel.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/flatten_grad_kernel.h"
+#include "paddle/pten/backends/all_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/copy_kernel.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void FlattenGradKernel(const Context& dev_ctx,
+                       const DenseTensor& out_grad,
+                       const DenseTensor& xshape,
+                       DenseTensor* x_grad) {
+  auto xshape_dims = xshape.dims();
+  auto x_dims =
+      paddle::framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+  pten::Copy(dev_ctx, out_grad, false, x_grad);
+  x_grad->Resize(x_dims);
+}
+
+}  // namespace pten
+
+PT_REGISTER_CTX_KERNEL(flatten_grad,
+                       CPU,
+                       ALL_LAYOUT,
+                       pten::FlattenGradKernel,
+                       float,
+                       double,
+                       uint8_t,
+                       int8_t,
+                       int,
+                       int64_t) {}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_REGISTER_CTX_KERNEL(flatten_grad,
+                       GPU,
+                       ALL_LAYOUT,
+                       pten::FlattenGradKernel,
+                       float,
+                       paddle::platform::float16,
+                       double,
+                       uint8_t,
+                       int8_t,
+                       int,
+                       int64_t) {}
+
+#endif
+
+#ifdef PADDLE_WITH_XPU
+PT_REGISTER_CTX_KERNEL(flatten_grad,
+                       XPU,
+                       ALL_LAYOUT,
+                       pten::FlattenGradKernel,
+                       float,
+                       paddle::platform::float16,
+                       int8_t,
+                       int,
+                       int64_t) {}
+
+#endif
diff --git a/paddle/pten/kernels/flatten_grad_kernel.h b/paddle/pten/kernels/flatten_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..91d9aa7c3060971c9b943b96a560f0bdda6d90ae
--- /dev/null
+++ b/paddle/pten/kernels/flatten_grad_kernel.h
@@ -0,0 +1,27 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void FlattenGradKernel(const Context& dev_ctx,
+                       const DenseTensor& out_grad,
+                       const DenseTensor& xshape,
+                       DenseTensor* x_grad);
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/flatten_kernel.cc b/paddle/pten/kernels/flatten_kernel.cc
index 37d4d88ccb40eeddec795c0f96393fbe752c71b6..b284d3690830f74b12cf52deb9a46457382004c8 100644
--- a/paddle/pten/kernels/flatten_kernel.cc
+++ b/paddle/pten/kernels/flatten_kernel.cc
@@ -103,8 +103,6 @@ PT_REGISTER_CTX_KERNEL(flatten,
                        pten::FlattenKernel,
                        float,
                        paddle::platform::float16,
-                       double,
-                       uint8_t,
                        int8_t,
                        int,
                        int64_t) {}
@@ -115,8 +113,6 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
                        pten::FlattenWithXShape,
                        float,
                        paddle::platform::float16,
-                       double,
-                       uint8_t,
                        int8_t,
                        int,
                        int64_t) {}
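
Note (not part of the patch): the new flatten_grad kernel never touches XShape's allocation, which may be null; it only reads xshape.dims(), copies Out@GRAD into X@GRAD, and resizes the result to the dims recorded in XShape with the leading entry dropped. Below is a minimal standalone C++ sketch of that shape bookkeeping, with no Paddle dependencies; all names in it are illustrative, and the [0, d1, d2, ...] layout of XShape is assumed from Paddle's XShape convention.

// Illustrative sketch only: mirrors the shape handling in FlattenGradKernel
// using plain std::vector instead of paddle::framework::DDim.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Equivalent of slice_ddim(xshape_dims, 1, xshape_dims.size()): drop the
// leading entry that XShape carries and keep the original dims of X.
std::vector<int64_t> XGradDims(const std::vector<int64_t>& xshape_dims) {
  assert(!xshape_dims.empty());
  return {xshape_dims.begin() + 1, xshape_dims.end()};
}

int main() {
  // Forward flatten_contiguous_range on a [2, 3, 4] input records
  // XShape = [0, 2, 3, 4]; backward copies Out@GRAD and resizes it to [2, 3, 4].
  const std::vector<int64_t> xshape_dims = {0, 2, 3, 4};
  for (int64_t d : XGradDims(xshape_dims)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 3 4
  return 0;
}

This is also why the fluid wrapper in flatten_op.h builds a fresh DenseTensor from a DenseTensorMeta instead of calling MakePtenDenseTensor on XShape: only the dims are needed, not the (possibly null) holder.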