From 778008d7fc78fd2d57ac24d9c654ea594c8f511a Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 23 Mar 2022 14:49:43 +0800 Subject: [PATCH] [Phi]Remove InferShape and Kernel of flatten_contiguous_range op (#40638) * remove flatten infermeta * fix bugs when run inference ci * fix bugs when run inference ci * fix bugs when run ci * support infrt * inplace infershape code' --- paddle/fluid/framework/infershape_utils.cc | 355 +++++++++------------ paddle/fluid/framework/infershape_utils.h | 60 +++- paddle/fluid/operators/flatten_op.cc | 96 +----- paddle/fluid/operators/flatten_op.cu.cc | 31 -- paddle/fluid/operators/flatten_op.h | 41 --- paddle/fluid/operators/flatten_op_xpu.cc | 23 -- paddle/phi/infermeta/unary.cc | 16 + paddle/phi/infermeta/unary.h | 6 + paddle/phi/kernels/flatten_grad_kernel.cc | 1 + paddle/phi/kernels/flatten_kernel.cc | 2 +- 10 files changed, 254 insertions(+), 377 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 2babecc6dd..504fadedba 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -27,7 +27,6 @@ limitations under the License. */ #include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/tensor_utils.h" namespace paddle { @@ -101,235 +100,197 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { const InferShapeContext& ctx_; }; -// TODO(chenweihang): Support TensorArray later -class CompatMetaTensor : public phi::MetaTensor { - public: - CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) - : var_(std::move(var)), is_runtime_(is_runtime) {} - - CompatMetaTensor() = default; - CompatMetaTensor(const CompatMetaTensor&) = default; - CompatMetaTensor(CompatMetaTensor&&) = default; - CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; - CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; - - int64_t numel() const override { - if (is_runtime_) { - auto* var = BOOST_GET_CONST(Variable*, var_); - return var->Get().numel(); - } else { - auto* var = BOOST_GET_CONST(VarDesc*, var_); - return var->ElementSize(); - } +int64_t CompatMetaTensor::numel() const { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().numel(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->ElementSize(); } +} - DDim dims() const override { - if (is_runtime_) { - auto* var = BOOST_GET_CONST(Variable*, var_); - if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - // use tensor array size as dims - auto& tensor_array = var->Get(); - return phi::make_ddim({static_cast(tensor_array.size())}); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can get dims from DenseTensor or SelectedRows or " - "DenseTensorArray.")); - } +DDim CompatMetaTensor::dims() const { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + // use tensor array size as dims + auto& tensor_array = var->Get(); + return phi::make_ddim({static_cast(tensor_array.size())}); } else { - auto* var = BOOST_GET_CONST(VarDesc*, var_); - - return var->GetShape().empty() ? phi::make_ddim({0UL}) - : phi::make_ddim(var->GetShape()); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can get dims from DenseTensor or SelectedRows or " + "DenseTensorArray.")); } + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + + return var->GetShape().empty() ? phi::make_ddim({0UL}) + : phi::make_ddim(var->GetShape()); } +} - phi::DataType dtype() const override { - if (is_runtime_) { - auto* var = BOOST_GET_CONST(Variable*, var_); - if (var->IsType()) { - return var->Get().dtype(); - } else if (var->IsType()) { - return var->Get().dtype(); - } else if (var->IsType()) { - // NOTE(chenweihang): do nothing - // Unsupported get dtype from LoDTensorArray now - return phi::DataType::UNDEFINED; - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can get dtype from DenseTensor or SelectedRows.")); - } +phi::DataType CompatMetaTensor::dtype() const { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + if (var->IsType()) { + return var->Get().dtype(); + } else if (var->IsType()) { + return var->Get().dtype(); + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported get dtype from LoDTensorArray now + return phi::DataType::UNDEFINED; } else { - auto* var = BOOST_GET_CONST(VarDesc*, var_); - return paddle::framework::TransToPhiDataType(var->GetDataType()); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can get dtype from DenseTensor or SelectedRows.")); } + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return paddle::framework::TransToPhiDataType(var->GetDataType()); } +} - DataLayout layout() const override { - if (is_runtime_) { - auto* var = BOOST_GET_CONST(Variable*, var_); - if (var->IsType()) { - return var->Get().layout(); - } else if (var->IsType()) { - return var->Get().layout(); - } else if (var->IsType()) { - // NOTE(chenweihang): do nothing - // Unsupported get layout from LoDTensorArray now - return phi::DataLayout::UNDEFINED; - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can get layout from DenseTensor or " - "SelectedRows.")); - } - } else { +DataLayout CompatMetaTensor::layout() const { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + if (var->IsType()) { + return var->Get().layout(); + } else if (var->IsType()) { + return var->Get().layout(); + } else if (var->IsType()) { // NOTE(chenweihang): do nothing - // Unsupported get layout for VarDesc now - return DataLayout::UNDEFINED; + // Unsupported get layout from LoDTensorArray now + return phi::DataLayout::UNDEFINED; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can get layout from DenseTensor or " + "SelectedRows.")); } + } else { + // NOTE(chenweihang): do nothing + // Unsupported get layout for VarDesc now + return DataLayout::UNDEFINED; } +} - void set_dims(const DDim& dims) override { - if (is_runtime_) { - auto* var = BOOST_GET(Variable*, var_); - if (var->IsType()) { - auto* tensor = var->GetMutable(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; - } else if (var->IsType()) { - auto* tensor = var->GetMutable()->mutable_value(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; - } else if (var->IsType()) { - auto* tensor_array = var->GetMutable(); - // Note: Here I want enforce `tensor_array->size() == 0UL`, because - // inplace using on LoDTensorArray is dangerous, but the unittest - // `test_list` contains this behavior - PADDLE_ENFORCE_EQ(dims.size(), 1UL, - platform::errors::InvalidArgument( - "LoDTensorArray can only have one dimension.")); - // only set the array size for LoDTensorArray input - tensor_array->resize(dims[0]); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can set dims from DenseTensor or SelectedRows.")); - } +void CompatMetaTensor::set_dims(const DDim& dims) { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + if (var->IsType()) { + auto* tensor = var->GetMutable(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; + } else if (var->IsType()) { + auto* tensor = var->GetMutable()->mutable_value(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; + } else if (var->IsType()) { + auto* tensor_array = var->GetMutable(); + // Note: Here I want enforce `tensor_array->size() == 0UL`, because + // inplace using on LoDTensorArray is dangerous, but the unittest + // `test_list` contains this behavior + PADDLE_ENFORCE_EQ(dims.size(), 1UL, + platform::errors::InvalidArgument( + "LoDTensorArray can only have one dimension.")); + // only set the array size for LoDTensorArray input + tensor_array->resize(dims[0]); } else { - auto* var = BOOST_GET(VarDesc*, var_); - var->SetShape(vectorize(dims)); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can set dims from DenseTensor or SelectedRows.")); } + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetShape(vectorize(dims)); } +} - void set_dtype(phi::DataType dtype) override { - if (is_runtime_) { - auto* var = BOOST_GET(Variable*, var_); - if (var->IsType()) { - auto* tensor = var->GetMutable(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; - } else if (var->IsType()) { - auto* tensor = var->GetMutable()->mutable_value(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; - } else if (var->IsType()) { - // NOTE(chenweihang): do nothing - // Unsupported set dtype for LoDTensorArray now - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can set dtype from DenseTensor or SelectedRows.")); - } +void CompatMetaTensor::set_dtype(phi::DataType dtype) { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + if (var->IsType()) { + auto* tensor = var->GetMutable(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; + } else if (var->IsType()) { + auto* tensor = var->GetMutable()->mutable_value(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported set dtype for LoDTensorArray now } else { - auto* var = BOOST_GET(VarDesc*, var_); - var->SetDataType(paddle::framework::TransToProtoVarType(dtype)); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can set dtype from DenseTensor or SelectedRows.")); } + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetDataType(paddle::framework::TransToProtoVarType(dtype)); } +} - void set_layout(DataLayout layout) override { - if (is_runtime_) { - auto* var = BOOST_GET(Variable*, var_); - if (var->IsType()) { - auto* tensor = var->GetMutable(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; - } else if (var->IsType()) { - auto* tensor = var->GetMutable()->mutable_value(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; - } else if (var->IsType()) { - // NOTE(chenweihang): do nothing - // Unsupported set dtype for LoDTensorArray now - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can set layout from DenseTensor or " - "SelectedRows.")); - } - } else { +void CompatMetaTensor::set_layout(DataLayout layout) { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + if (var->IsType()) { + auto* tensor = var->GetMutable(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; + } else if (var->IsType()) { + auto* tensor = var->GetMutable()->mutable_value(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; + } else if (var->IsType()) { // NOTE(chenweihang): do nothing - // Unsupported set layout for VarDesc now + // Unsupported set dtype for LoDTensorArray now + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can set layout from DenseTensor or " + "SelectedRows.")); } + } else { + // NOTE(chenweihang): do nothing + // Unsupported set layout for VarDesc now } +} - void share_lod(const MetaTensor& meta_tensor) override { - if (is_runtime_) { - auto* var = BOOST_GET(Variable*, var_); - if (var->IsType()) { - auto* tensor = var->GetMutable(); - phi::DenseTensorUtils::GetMutableMeta(tensor)->lod = - static_cast(meta_tensor).GetRuntimeLoD(); - } else { - // NOTE(chenweihang): do nothing - // only LoDTensor need to share lod - } +void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + if (var->IsType()) { + auto* tensor = var->GetMutable(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->lod = + static_cast(meta_tensor).GetRuntimeLoD(); } else { - auto* var = BOOST_GET(VarDesc*, var_); - var->SetLoDLevel(static_cast(meta_tensor) - .GetCompileTimeLoD()); + // NOTE(chenweihang): do nothing + // only LoDTensor need to share lod } + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetLoDLevel( + static_cast(meta_tensor).GetCompileTimeLoD()); } +} - void share_dims(const MetaTensor& meta_tensor) override { - set_dims(meta_tensor.dims()); - if (is_runtime_) { - auto* var = BOOST_GET(Variable*, var_); - if (var->IsType()) { - auto* selected_rows = var->GetMutable(); - auto& input_selected_rows = - static_cast(meta_tensor).GetSelectedRows(); - selected_rows->set_rows(input_selected_rows.rows()); - selected_rows->set_height(input_selected_rows.height()); - } +void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) { + set_dims(meta_tensor.dims()); + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + if (var->IsType()) { + auto* selected_rows = var->GetMutable(); + auto& input_selected_rows = + static_cast(meta_tensor).GetSelectedRows(); + selected_rows->set_rows(input_selected_rows.rows()); + selected_rows->set_height(input_selected_rows.height()); } } +} - void share_meta(const MetaTensor& meta_tensor) override { - share_dims(meta_tensor); - set_dtype(meta_tensor.dtype()); - set_layout(meta_tensor.layout()); - // special case: share lod of LoDTensor - share_lod(meta_tensor); - } - - private: - const LoD& GetRuntimeLoD() const { - auto* var = BOOST_GET_CONST(Variable*, var_); - return var->Get().lod(); - } - - int32_t GetCompileTimeLoD() const { - auto* var = BOOST_GET_CONST(VarDesc*, var_); - return var->GetLoDLevel(); - } - - const phi::SelectedRows& GetSelectedRows() const { - PADDLE_ENFORCE_EQ(is_runtime_, true, - platform::errors::Unavailable( - "Only can get Tensor from MetaTensor in rumtime.")); - auto* var = BOOST_GET_CONST(Variable*, var_); - PADDLE_ENFORCE_EQ(var->IsType(), true, - platform::errors::Unavailable( - "The Tensor in MetaTensor is not SelectedRows.")); - return var->Get(); - } - - InferShapeVarPtr var_; - bool is_runtime_; -}; +void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) { + share_dims(meta_tensor); + set_dtype(meta_tensor.dtype()); + set_layout(meta_tensor.layout()); + // special case: share lod of LoDTensor + share_lod(meta_tensor); +} phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, const std::string& op_type) { diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h index b692b6ffab..022f194b66 100644 --- a/paddle/fluid/framework/infershape_utils.h +++ b/paddle/fluid/framework/infershape_utils.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/shape_inference.h" - +#include "paddle/phi/core/meta_tensor.h" namespace phi { class InferMetaContext; } // namespace phi @@ -39,5 +39,63 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } \ } +// TODO(chenweihang): Support TensorArray later +class CompatMetaTensor : public phi::MetaTensor { + public: + CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) + : var_(std::move(var)), is_runtime_(is_runtime) {} + + CompatMetaTensor() = default; + CompatMetaTensor(const CompatMetaTensor&) = default; + CompatMetaTensor(CompatMetaTensor&&) = default; + CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; + CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; + + int64_t numel() const override; + + DDim dims() const override; + + phi::DataType dtype() const override; + + DataLayout layout() const override; + + void set_dims(const DDim& dims) override; + + void set_dtype(phi::DataType dtype) override; + + void set_layout(DataLayout layout) override; + + void share_lod(const MetaTensor& meta_tensor) override; + + void share_dims(const MetaTensor& meta_tensor) override; + + void share_meta(const MetaTensor& meta_tensor) override; + + private: + const LoD& GetRuntimeLoD() const { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().lod(); + } + + int32_t GetCompileTimeLoD() const { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->GetLoDLevel(); + } + + const phi::SelectedRows& GetSelectedRows() const { + PADDLE_ENFORCE_EQ(is_runtime_, true, + platform::errors::Unavailable( + "Only can get Tensor from MetaTensor in rumtime.")); + auto* var = BOOST_GET_CONST(Variable*, var_); + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::Unavailable( + "The Tensor in MetaTensor is not SelectedRows.")); + return var->Get(); + } + + InferShapeVarPtr var_; + bool is_runtime_; +}; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index dd172d53ef..b0a7007755 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -17,7 +17,10 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -270,70 +273,24 @@ class Flatten2GradOp : public framework::OperatorWithKernel { class FlattenContiguousRangeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FlattenContiguousRange"); const auto &start_axis = ctx->Attrs().Get("start_axis"); const auto &stop_axis = ctx->Attrs().Get("stop_axis"); - const auto &in_dims = ctx->GetInputDim("X"); - int in_dims_size = in_dims.size(); - int real_start_axis = start_axis, real_stop_axis = stop_axis; - if (start_axis < 0) { - real_start_axis = start_axis + in_dims_size; - } - if (stop_axis < 0) { - real_stop_axis = stop_axis + in_dims_size; - } - PADDLE_ENFORCE_GE( - real_stop_axis, real_start_axis, - platform::errors::InvalidArgument("The stop_axis should be greater" - "than or equal to start_axis.")); - const auto &out_dims = - GetOutputShape(real_start_axis, real_stop_axis, in_dims); - ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); - if (in_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - ctx->ShareLoD("X", "Out"); - } - if (!ctx->HasOutput("XShape")) return; - // OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2"); - std::vector xshape_dims(in_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < in_dims.size(); ++i) { - xshape_dims[i + 1] = in_dims[i]; + // Construct MetaTensor for InferMeta Func + using CompatMetaTensor = framework::CompatMetaTensor; + CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime()); + CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime()); + std::unique_ptr xshape(nullptr); + if (ctx->HasOutput("XShape")) { + xshape = std::move(std::unique_ptr(new CompatMetaTensor( + ctx->GetOutputVarPtrs("XShape")[0], ctx->IsRuntime()))); } - ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims)); - ctx->ShareLoD("X", "XShape"); - } - - static std::vector GetOutputShape(const int start_axis, - const int stop_axis, - const framework::DDim &in_dims) { - int64_t outer = 1; - std::vector out_shape; - int in_dims_size = in_dims.size(); - out_shape.reserve(in_dims_size - stop_axis + start_axis); - - for (int i = 0; i < start_axis; ++i) { - out_shape.push_back(in_dims[i]); - } - for (int i = start_axis; i <= stop_axis; i++) { - if (in_dims[i] == -1 || outer == -1) { - outer = -1; - } else { - outer *= in_dims[i]; - } - } - out_shape.push_back(outer); - for (int i = stop_axis + 1; i < in_dims_size; i++) { - out_shape.push_back(in_dims[i]); - } - - return out_shape; + phi::FlattenWithXShapeInferMeta(x, start_axis, stop_axis, &out, + xshape.get()); } }; @@ -487,30 +444,3 @@ REGISTER_OP_CPU_KERNEL( ops::Flatten2GradKernel, ops::Flatten2GradKernel, ops::Flatten2GradKernel); -REGISTER_OP_CPU_KERNEL( - flatten_contiguous_range, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel); -REGISTER_OP_CPU_KERNEL( - flatten_contiguous_range_grad, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel); diff --git a/paddle/fluid/operators/flatten_op.cu.cc b/paddle/fluid/operators/flatten_op.cu.cc index e0987288ab..4796bff5e2 100644 --- a/paddle/fluid/operators/flatten_op.cu.cc +++ b/paddle/fluid/operators/flatten_op.cu.cc @@ -47,34 +47,3 @@ REGISTER_OP_CUDA_KERNEL( ops::Flatten2GradKernel, ops::Flatten2GradKernel, ops::Flatten2GradKernel); -REGISTER_OP_CUDA_KERNEL( - flatten_contiguous_range, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel); -REGISTER_OP_CUDA_KERNEL( - flatten_contiguous_range_grad, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel); diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index feae954e35..cacd30cad8 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -119,46 +119,5 @@ class Flatten2GradKernel : public framework::OpKernel { } }; -template -class FlattenContiguousRangeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *in = context.Input("X"); - auto *out = context.Output("Out"); - out->mutable_data(context.GetPlace(), in->type()); - auto &start_axis = context.Attr("start_axis"); - auto &stop_axis = context.Attr("stop_axis"); - auto &dev_ctx = context.device_context(); - - // call new kernel - phi::FlattenKernel::TYPE>( - static_cast::TYPE &>(dev_ctx), - *in, start_axis, stop_axis, out); - } -}; - -template -class FlattenContiguousRangeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_out = - ctx.Input(framework::GradVarName("Out")); - auto *xshape = ctx.Input("XShape"); - - d_x->mutable_data(ctx.GetPlace(), d_out->type()); - auto &dev_ctx = ctx.device_context(); - - // call new kernel - phi::FlattenGradKernel::TYPE>( - static_cast::TYPE &>(dev_ctx), - *d_out, *xshape, d_x); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/flatten_op_xpu.cc b/paddle/fluid/operators/flatten_op_xpu.cc index 53c0c688fd..cc2f65bba6 100644 --- a/paddle/fluid/operators/flatten_op_xpu.cc +++ b/paddle/fluid/operators/flatten_op_xpu.cc @@ -41,27 +41,4 @@ REGISTER_OP_XPU_KERNEL( ops::Flatten2GradKernel, ops::Flatten2GradKernel, ops::Flatten2GradKernel); -REGISTER_OP_XPU_KERNEL( - flatten_contiguous_range, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel, - ops::FlattenContiguousRangeKernel); -REGISTER_OP_XPU_KERNEL( - flatten_contiguous_range_grad, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel, - ops::FlattenContiguousRangeGradKernel); #endif diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index e44032285a..160e8ef56f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -352,6 +352,14 @@ void FlattenInferMeta(const MetaTensor& x, int start_axis, int stop_axis, MetaTensor* out) { + FlattenWithXShapeInferMeta(x, start_axis, stop_axis, out, nullptr); +} + +void FlattenWithXShapeInferMeta(const MetaTensor& x, + int start_axis, + int stop_axis, + MetaTensor* out, + MetaTensor* xshape) { auto x_dims = x.dims(); int in_dims_size = x_dims.size(); if (start_axis < 0) { @@ -394,6 +402,14 @@ void FlattenInferMeta(const MetaTensor& x, // are the same. out->share_lod(x); } + if (xshape == nullptr) return; + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + xshape->set_dims(phi::make_ddim(xshape_dims)); + xshape->share_lod(x); } void GumbelSoftmaxInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index f623f14a70..6187c49de1 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -86,6 +86,12 @@ void FlattenInferMeta(const MetaTensor& x, int stop_axis, MetaTensor* out); +void FlattenWithXShapeInferMeta(const MetaTensor& x, + int start_axis, + int stop_axis, + MetaTensor* out, + MetaTensor* xshape); + void GumbelSoftmaxInferMeta(const MetaTensor& x, float temperature, bool hard, diff --git a/paddle/phi/kernels/flatten_grad_kernel.cc b/paddle/phi/kernels/flatten_grad_kernel.cc index f6ba272500..b7b45e46cf 100644 --- a/paddle/phi/kernels/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/flatten_grad_kernel.cc @@ -25,6 +25,7 @@ void FlattenGradKernel(const Context& dev_ctx, const DenseTensor& xshape, DenseTensor* x_grad) { auto xshape_dims = xshape.dims(); + dev_ctx.Alloc(x_grad, out_grad.dtype()); auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size()); phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); x_grad->Resize(x_dims); diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index 78ac9eaa78..f304e7706a 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -27,6 +27,7 @@ void FlattenKernel(const Context& dev_ctx, int start_axis, int stop_axis, DenseTensor* out) { + dev_ctx.Alloc(out, x.dtype()); auto out_dims = out->dims(); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); @@ -43,7 +44,6 @@ void FlattenWithXShape(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape) { FlattenKernel(dev_ctx, x, start_axis, stop_axis, out); - funcs::SetXShape(x, xshape); } } // namespace phi -- GitLab