Unverified commit 778008d7, authored by YuanRisheng, committed by GitHub

[Phi]Remove InferShape and Kernel of flatten_contiguous_range op (#40638)

* remove flatten infermeta

* fix bugs when run inference ci

* fix bugs when run inference ci

* fix bugs when run ci

* support infrt

* inplace infershape code
Parent: f4075db8
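In short, the change below does four things: the CompatMetaTensor adapter's method bodies become out-of-line definitions while the class declaration moves into the corresponding header so individual operators can use it; FlattenContiguousRangeOp::InferShape stops duplicating the shape logic and delegates to the new phi::FlattenWithXShapeInferMeta; the fluid FlattenContiguousRangeKernel / FlattenContiguousRangeGradKernel wrappers and their CPU/CUDA/XPU registrations are deleted; and the phi flatten kernels now allocate their outputs themselves.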
@@ -27,7 +27,6 @@ limitations under the License. */
 #include "paddle/phi/core/compat/op_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/core/meta_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
 
 namespace paddle {
@@ -101,235 +100,197 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
   const InferShapeContext& ctx_;
 };
 
-// TODO(chenweihang): Support TensorArray later
-class CompatMetaTensor : public phi::MetaTensor {
- public:
-  CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
-      : var_(std::move(var)), is_runtime_(is_runtime) {}
-
-  CompatMetaTensor() = default;
-  CompatMetaTensor(const CompatMetaTensor&) = default;
-  CompatMetaTensor(CompatMetaTensor&&) = default;
-  CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
-  CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
-
-  int64_t numel() const override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET_CONST(Variable*, var_);
-      return var->Get<Tensor>().numel();
-    } else {
-      auto* var = BOOST_GET_CONST(VarDesc*, var_);
-      return var->ElementSize();
-    }
-  }
+int64_t CompatMetaTensor::numel() const {
+  if (is_runtime_) {
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    return var->Get<Tensor>().numel();
+  } else {
+    auto* var = BOOST_GET_CONST(VarDesc*, var_);
+    return var->ElementSize();
+  }
+}
 
-  DDim dims() const override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET_CONST(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        return var->Get<phi::DenseTensor>().dims();
-      } else if (var->IsType<phi::SelectedRows>()) {
-        return var->Get<phi::SelectedRows>().dims();
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        // use tensor array size as dims
-        auto& tensor_array = var->Get<framework::LoDTensorArray>();
-        return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can get dims from DenseTensor or SelectedRows or "
-            "DenseTensorArray."));
-      }
-    } else {
-      auto* var = BOOST_GET_CONST(VarDesc*, var_);
-
-      return var->GetShape().empty() ? phi::make_ddim({0UL})
-                                     : phi::make_ddim(var->GetShape());
-    }
-  }
+DDim CompatMetaTensor::dims() const {
+  if (is_runtime_) {
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      return var->Get<phi::DenseTensor>().dims();
+    } else if (var->IsType<phi::SelectedRows>()) {
+      return var->Get<phi::SelectedRows>().dims();
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      // use tensor array size as dims
+      auto& tensor_array = var->Get<framework::LoDTensorArray>();
+      return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can get dims from DenseTensor or SelectedRows or "
+          "DenseTensorArray."));
+    }
+  } else {
+    auto* var = BOOST_GET_CONST(VarDesc*, var_);
+    return var->GetShape().empty() ? phi::make_ddim({0UL})
+                                   : phi::make_ddim(var->GetShape());
+  }
+}
 
-  phi::DataType dtype() const override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET_CONST(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        return var->Get<phi::DenseTensor>().dtype();
-      } else if (var->IsType<phi::SelectedRows>()) {
-        return var->Get<phi::SelectedRows>().dtype();
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        // NOTE(chenweihang): do nothing
-        // Unsupported get dtype from LoDTensorArray now
-        return phi::DataType::UNDEFINED;
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can get dtype from DenseTensor or SelectedRows."));
-      }
-    } else {
-      auto* var = BOOST_GET_CONST(VarDesc*, var_);
-      return paddle::framework::TransToPhiDataType(var->GetDataType());
-    }
-  }
+phi::DataType CompatMetaTensor::dtype() const {
+  if (is_runtime_) {
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      return var->Get<phi::DenseTensor>().dtype();
+    } else if (var->IsType<phi::SelectedRows>()) {
+      return var->Get<phi::SelectedRows>().dtype();
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      // NOTE(chenweihang): do nothing
+      // Unsupported get dtype from LoDTensorArray now
+      return phi::DataType::UNDEFINED;
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can get dtype from DenseTensor or SelectedRows."));
+    }
+  } else {
+    auto* var = BOOST_GET_CONST(VarDesc*, var_);
+    return paddle::framework::TransToPhiDataType(var->GetDataType());
+  }
+}
 
-  DataLayout layout() const override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET_CONST(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        return var->Get<phi::DenseTensor>().layout();
-      } else if (var->IsType<phi::SelectedRows>()) {
-        return var->Get<phi::SelectedRows>().layout();
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        // NOTE(chenweihang): do nothing
-        // Unsupported get layout from LoDTensorArray now
-        return phi::DataLayout::UNDEFINED;
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can get layout from DenseTensor or "
-            "SelectedRows."));
-      }
-    } else {
-      // NOTE(chenweihang): do nothing
-      // Unsupported get layout for VarDesc now
-      return DataLayout::UNDEFINED;
-    }
-  }
+DataLayout CompatMetaTensor::layout() const {
+  if (is_runtime_) {
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      return var->Get<phi::DenseTensor>().layout();
+    } else if (var->IsType<phi::SelectedRows>()) {
+      return var->Get<phi::SelectedRows>().layout();
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      // NOTE(chenweihang): do nothing
+      // Unsupported get layout from LoDTensorArray now
+      return phi::DataLayout::UNDEFINED;
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can get layout from DenseTensor or "
+          "SelectedRows."));
+    }
+  } else {
+    // NOTE(chenweihang): do nothing
+    // Unsupported get layout for VarDesc now
+    return DataLayout::UNDEFINED;
+  }
+}
 
-  void set_dims(const DDim& dims) override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        auto* tensor = var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
-      } else if (var->IsType<phi::SelectedRows>()) {
-        auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
-        // Note: Here I want enforce `tensor_array->size() == 0UL`, because
-        // inplace using on LoDTensorArray is dangerous, but the unittest
-        // `test_list` contains this behavior
-        PADDLE_ENFORCE_EQ(dims.size(), 1UL,
-                          platform::errors::InvalidArgument(
-                              "LoDTensorArray can only have one dimension."));
-        // only set the array size for LoDTensorArray input
-        tensor_array->resize(dims[0]);
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can set dims from DenseTensor or SelectedRows."));
-      }
-    } else {
-      auto* var = BOOST_GET(VarDesc*, var_);
-      var->SetShape(vectorize(dims));
-    }
-  }
+void CompatMetaTensor::set_dims(const DDim& dims) {
+  if (is_runtime_) {
+    auto* var = BOOST_GET(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      auto* tensor = var->GetMutable<phi::DenseTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
+    } else if (var->IsType<phi::SelectedRows>()) {
+      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
+      // Note: Here I want enforce `tensor_array->size() == 0UL`, because
+      // inplace using on LoDTensorArray is dangerous, but the unittest
+      // `test_list` contains this behavior
+      PADDLE_ENFORCE_EQ(dims.size(), 1UL,
+                        platform::errors::InvalidArgument(
+                            "LoDTensorArray can only have one dimension."));
+      // only set the array size for LoDTensorArray input
+      tensor_array->resize(dims[0]);
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can set dims from DenseTensor or SelectedRows."));
+    }
+  } else {
+    auto* var = BOOST_GET(VarDesc*, var_);
+    var->SetShape(vectorize(dims));
+  }
+}
 
-  void set_dtype(phi::DataType dtype) override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        auto* tensor = var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
-      } else if (var->IsType<phi::SelectedRows>()) {
-        auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        // NOTE(chenweihang): do nothing
-        // Unsupported set dtype for LoDTensorArray now
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can set dtype from DenseTensor or SelectedRows."));
-      }
-    } else {
-      auto* var = BOOST_GET(VarDesc*, var_);
-      var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
-    }
-  }
+void CompatMetaTensor::set_dtype(phi::DataType dtype) {
+  if (is_runtime_) {
+    auto* var = BOOST_GET(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      auto* tensor = var->GetMutable<phi::DenseTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
+    } else if (var->IsType<phi::SelectedRows>()) {
+      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      // NOTE(chenweihang): do nothing
+      // Unsupported set dtype for LoDTensorArray now
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can set dtype from DenseTensor or SelectedRows."));
+    }
+  } else {
+    auto* var = BOOST_GET(VarDesc*, var_);
+    var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
+  }
+}
 
-  void set_layout(DataLayout layout) override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        auto* tensor = var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
-      } else if (var->IsType<phi::SelectedRows>()) {
-        auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
-      } else if (var->IsType<framework::LoDTensorArray>()) {
-        // NOTE(chenweihang): do nothing
-        // Unsupported set dtype for LoDTensorArray now
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Currently, only can set layout from DenseTensor or "
-            "SelectedRows."));
-      }
-    } else {
-      // NOTE(chenweihang): do nothing
-      // Unsupported set layout for VarDesc now
-    }
-  }
+void CompatMetaTensor::set_layout(DataLayout layout) {
+  if (is_runtime_) {
+    auto* var = BOOST_GET(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      auto* tensor = var->GetMutable<phi::DenseTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
+    } else if (var->IsType<phi::SelectedRows>()) {
+      auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
+    } else if (var->IsType<framework::LoDTensorArray>()) {
+      // NOTE(chenweihang): do nothing
+      // Unsupported set dtype for LoDTensorArray now
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Currently, only can set layout from DenseTensor or "
+          "SelectedRows."));
+    }
+  } else {
+    // NOTE(chenweihang): do nothing
+    // Unsupported set layout for VarDesc now
+  }
+}
 
-  void share_lod(const MetaTensor& meta_tensor) override {
-    if (is_runtime_) {
-      auto* var = BOOST_GET(Variable*, var_);
-      if (var->IsType<phi::DenseTensor>()) {
-        auto* tensor = var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
-            static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
-      } else {
-        // NOTE(chenweihang): do nothing
-        // only LoDTensor need to share lod
-      }
-    } else {
-      auto* var = BOOST_GET(VarDesc*, var_);
-      var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
-                           .GetCompileTimeLoD());
-    }
-  }
+void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
+  if (is_runtime_) {
+    auto* var = BOOST_GET(Variable*, var_);
+    if (var->IsType<phi::DenseTensor>()) {
+      auto* tensor = var->GetMutable<phi::DenseTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
+          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
+    } else {
+      // NOTE(chenweihang): do nothing
+      // only LoDTensor need to share lod
+    }
+  } else {
+    auto* var = BOOST_GET(VarDesc*, var_);
+    var->SetLoDLevel(
+        static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
+  }
+}
 
-  void share_dims(const MetaTensor& meta_tensor) override {
-    set_dims(meta_tensor.dims());
-    if (is_runtime_) {
-      auto* var = BOOST_GET(Variable*, var_);
-      if (var->IsType<phi::SelectedRows>()) {
-        auto* selected_rows = var->GetMutable<phi::SelectedRows>();
-        auto& input_selected_rows =
-            static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
-        selected_rows->set_rows(input_selected_rows.rows());
-        selected_rows->set_height(input_selected_rows.height());
-      }
-    }
-  }
+void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
+  set_dims(meta_tensor.dims());
+  if (is_runtime_) {
+    auto* var = BOOST_GET(Variable*, var_);
+    if (var->IsType<phi::SelectedRows>()) {
+      auto* selected_rows = var->GetMutable<phi::SelectedRows>();
+      auto& input_selected_rows =
+          static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
+      selected_rows->set_rows(input_selected_rows.rows());
+      selected_rows->set_height(input_selected_rows.height());
+    }
+  }
+}
 
-  void share_meta(const MetaTensor& meta_tensor) override {
-    share_dims(meta_tensor);
-    set_dtype(meta_tensor.dtype());
-    set_layout(meta_tensor.layout());
-    // special case: share lod of LoDTensor
-    share_lod(meta_tensor);
-  }
-
- private:
-  const LoD& GetRuntimeLoD() const {
-    auto* var = BOOST_GET_CONST(Variable*, var_);
-    return var->Get<LoDTensor>().lod();
-  }
-
-  int32_t GetCompileTimeLoD() const {
-    auto* var = BOOST_GET_CONST(VarDesc*, var_);
-    return var->GetLoDLevel();
-  }
-
-  const phi::SelectedRows& GetSelectedRows() const {
-    PADDLE_ENFORCE_EQ(is_runtime_, true,
-                      platform::errors::Unavailable(
-                          "Only can get Tensor from MetaTensor in rumtime."));
-    auto* var = BOOST_GET_CONST(Variable*, var_);
-    PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
-                      platform::errors::Unavailable(
-                          "The Tensor in MetaTensor is not SelectedRows."));
-    return var->Get<phi::SelectedRows>();
-  }
-
-  InferShapeVarPtr var_;
-  bool is_runtime_;
-};
+void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
+  share_dims(meta_tensor);
+  set_dtype(meta_tensor.dtype());
+  set_layout(meta_tensor.layout());
+  // special case: share lod of LoDTensor
+  share_lod(meta_tensor);
+}
 
 phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
...
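The point of CompatMetaTensor, whose method bodies become out-of-line definitions here, is that a single InferMeta implementation can run in both phases: at compile time, where a variable is only a VarDesc carrying shape metadata, and at runtime, where it is a Variable holding a real tensor. A minimal, self-contained model of that idea, using simplified stand-in types rather than the real framework classes:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <variant>
#include <vector>

// Compile-time view: only shape metadata is known (stand-in for VarDesc).
struct DescLike {
  std::vector<int64_t> shape;
};
// Runtime view: a concrete tensor exists (stand-in for Variable/DenseTensor).
struct TensorLike {
  std::vector<int64_t> dims;
};

// One MetaTensor-style interface over both phases, as CompatMetaTensor does.
class CompatMeta {
 public:
  explicit CompatMeta(std::variant<DescLike*, TensorLike*> var) : var_(var) {}

  std::vector<int64_t> dims() const {
    if (auto t = std::get_if<TensorLike*>(&var_)) return (*t)->dims;
    return std::get<DescLike*>(var_)->shape;
  }

  int64_t numel() const {
    auto d = dims();
    return std::accumulate(d.begin(), d.end(), int64_t{1},
                           [](int64_t a, int64_t b) { return a * b; });
  }

 private:
  std::variant<DescLike*, TensorLike*> var_;
};

int main() {
  DescLike desc{{2, 3}};      // what shape inference sees at compile time
  TensorLike tensor{{2, 3}};  // what it sees at runtime
  // The same shape-inference code now serves both phases:
  std::cout << CompatMeta(&desc).numel() << " "      // prints 6
            << CompatMeta(&tensor).numel() << "\n";  // prints 6
}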
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/phi/core/meta_tensor.h"
 
 namespace phi {
 class InferMetaContext;
 }  // namespace phi
@@ -39,5 +39,63 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   }                                                                         \
 }
 
+// TODO(chenweihang): Support TensorArray later
+class CompatMetaTensor : public phi::MetaTensor {
+ public:
+  CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
+      : var_(std::move(var)), is_runtime_(is_runtime) {}
+
+  CompatMetaTensor() = default;
+  CompatMetaTensor(const CompatMetaTensor&) = default;
+  CompatMetaTensor(CompatMetaTensor&&) = default;
+  CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
+  CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
+
+  int64_t numel() const override;
+
+  DDim dims() const override;
+
+  phi::DataType dtype() const override;
+
+  DataLayout layout() const override;
+
+  void set_dims(const DDim& dims) override;
+
+  void set_dtype(phi::DataType dtype) override;
+
+  void set_layout(DataLayout layout) override;
+
+  void share_lod(const MetaTensor& meta_tensor) override;
+
+  void share_dims(const MetaTensor& meta_tensor) override;
+
+  void share_meta(const MetaTensor& meta_tensor) override;
+
+ private:
+  const LoD& GetRuntimeLoD() const {
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    return var->Get<LoDTensor>().lod();
+  }
+
+  int32_t GetCompileTimeLoD() const {
+    auto* var = BOOST_GET_CONST(VarDesc*, var_);
+    return var->GetLoDLevel();
+  }
+
+  const phi::SelectedRows& GetSelectedRows() const {
+    PADDLE_ENFORCE_EQ(is_runtime_, true,
+                      platform::errors::Unavailable(
+                          "Only can get Tensor from MetaTensor in rumtime."));
+    auto* var = BOOST_GET_CONST(Variable*, var_);
+    PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
+                      platform::errors::Unavailable(
+                          "The Tensor in MetaTensor is not SelectedRows."));
+    return var->Get<phi::SelectedRows>();
+  }
+
+  InferShapeVarPtr var_;
+  bool is_runtime_;
+};
+
 }  // namespace framework
 }  // namespace paddle
@@ -17,7 +17,10 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -270,70 +273,24 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
 class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext *ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                    "FlattenContiguousRange");
     const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
     const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
-    const auto &in_dims = ctx->GetInputDim("X");
-    int in_dims_size = in_dims.size();
-    int real_start_axis = start_axis, real_stop_axis = stop_axis;
-    if (start_axis < 0) {
-      real_start_axis = start_axis + in_dims_size;
-    }
-    if (stop_axis < 0) {
-      real_stop_axis = stop_axis + in_dims_size;
-    }
-    PADDLE_ENFORCE_GE(
-        real_stop_axis, real_start_axis,
-        platform::errors::InvalidArgument("The stop_axis should be greater"
                                          "than or equal to start_axis."));
-    const auto &out_dims =
-        GetOutputShape(real_start_axis, real_stop_axis, in_dims);
-    ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-    if (in_dims[0] == out_dims[0]) {
-      // Only pass LoD when the first dimension of output and Input(X)
-      // are the same.
-      ctx->ShareLoD("X", "Out");
-    }
-    if (!ctx->HasOutput("XShape")) return;
-    // OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
-    std::vector<int64_t> xshape_dims(in_dims.size() + 1);
-    xshape_dims[0] = 0;
-    for (int i = 0; i < in_dims.size(); ++i) {
-      xshape_dims[i + 1] = in_dims[i];
-    }
-    ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
-    ctx->ShareLoD("X", "XShape");
+    // Construct MetaTensor for InferMeta Func
+    using CompatMetaTensor = framework::CompatMetaTensor;
+    CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
+    CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
+    std::unique_ptr<CompatMetaTensor> xshape(nullptr);
+    if (ctx->HasOutput("XShape")) {
+      xshape = std::move(std::unique_ptr<CompatMetaTensor>(new CompatMetaTensor(
+          ctx->GetOutputVarPtrs("XShape")[0], ctx->IsRuntime())));
+    }
+    phi::FlattenWithXShapeInferMeta(x, start_axis, stop_axis, &out,
+                                    xshape.get());
   }
-
-  static std::vector<int32_t> GetOutputShape(const int start_axis,
-                                             const int stop_axis,
-                                             const framework::DDim &in_dims) {
-    int64_t outer = 1;
-    std::vector<int32_t> out_shape;
-    int in_dims_size = in_dims.size();
-    out_shape.reserve(in_dims_size - stop_axis + start_axis);
-
-    for (int i = 0; i < start_axis; ++i) {
-      out_shape.push_back(in_dims[i]);
-    }
-    for (int i = start_axis; i <= stop_axis; i++) {
-      if (in_dims[i] == -1 || outer == -1) {
-        outer = -1;
-      } else {
-        outer *= in_dims[i];
-      }
-    }
-    out_shape.push_back(outer);
-    for (int i = stop_axis + 1; i < in_dims_size; i++) {
-      out_shape.push_back(in_dims[i]);
-    }
-
-    return out_shape;
-  }
 };
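For reference, the shape rule that used to live in GetOutputShape above (and now lives behind phi::FlattenWithXShapeInferMeta) collapses dims[start_axis..stop_axis] into a single extent, propagating -1 for dynamic dimensions. A standalone sketch of that arithmetic, with illustrative dims and axes:

#include <cassert>
#include <cstdint>
#include <vector>

// Collapse dims[start..stop] into one extent; -1 marks a dynamic dim and
// is propagated, mirroring the rule in the removed GetOutputShape.
std::vector<int64_t> FlattenDims(std::vector<int64_t> in, int start, int stop) {
  const int n = static_cast<int>(in.size());
  if (start < 0) start += n;  // negative axes count from the back
  if (stop < 0) stop += n;
  assert(start <= stop && stop < n);
  std::vector<int64_t> out(in.begin(), in.begin() + start);
  int64_t prod = 1;
  for (int i = start; i <= stop; ++i) {
    prod = (in[i] == -1 || prod == -1) ? -1 : prod * in[i];
  }
  out.push_back(prod);
  out.insert(out.end(), in.begin() + stop + 1, in.end());
  return out;
}

int main() {
  // X: [6, 8, 3] with start_axis = 1, stop_axis = 2 flattens to [6, 24].
  assert((FlattenDims({6, 8, 3}, 1, 2) == std::vector<int64_t>{6, 24}));
  // A dynamic extent anywhere in the range makes the merged extent dynamic.
  assert((FlattenDims({6, -1, 3}, 0, 2) == std::vector<int64_t>{-1}));
  return 0;
}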
@@ -487,30 +444,3 @@ REGISTER_OP_CPU_KERNEL(
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
-REGISTER_OP_CPU_KERNEL(
-    flatten_contiguous_range,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
-                                      float>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
-                                      double>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
-                                      uint8_t>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
-                                      int8_t>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
-                                      int64_t>);
-REGISTER_OP_CPU_KERNEL(
-    flatten_contiguous_range_grad,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          float>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          double>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          uint8_t>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          int>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          int8_t>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
-                                          int64_t>);
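These operator-level registrations are not simply dropped: the replacement kernels live on the phi side, which has its own registration. As an orientation aid only (this registration is an assumption, not part of the diff shown here), the phi CPU registration looks roughly like:

PD_REGISTER_KERNEL(flatten, CPU, ALL_LAYOUT, phi::FlattenKernel, float,
                   double, uint8_t, int8_t, int, int64_t) {}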
@@ -47,34 +47,3 @@ REGISTER_OP_CUDA_KERNEL(
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
-REGISTER_OP_CUDA_KERNEL(
-    flatten_contiguous_range,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      float>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      plat::float16>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      double>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      uint8_t>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      int8_t>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
-                                      int64_t>);
-REGISTER_OP_CUDA_KERNEL(
-    flatten_contiguous_range_grad,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          float>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          plat::float16>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          double>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          uint8_t>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          int>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          int8_t>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t>);
@@ -119,46 +119,5 @@ class Flatten2GradKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
-class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *in = context.Input<framework::LoDTensor>("X");
-    auto *out = context.Output<framework::LoDTensor>("Out");
-    out->mutable_data(context.GetPlace(), in->type());
-    auto &start_axis = context.Attr<int>("start_axis");
-    auto &stop_axis = context.Attr<int>("stop_axis");
-    auto &dev_ctx = context.device_context<DeviceContext>();
-
-    // call new kernel
-    phi::FlattenKernel<T, typename paddle::framework::ConvertToPhiContext<
-                              DeviceContext>::TYPE>(
-        static_cast<const typename paddle::framework::ConvertToPhiContext<
-            DeviceContext>::TYPE &>(dev_ctx),
-        *in, start_axis, stop_axis, out);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto *d_out =
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto *xshape = ctx.Input<framework::LoDTensor>("XShape");
-
-    d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    auto &dev_ctx = ctx.device_context<DeviceContext>();
-
-    // call new kernel
-    phi::FlattenGradKernel<T, typename paddle::framework::ConvertToPhiContext<
-                                  DeviceContext>::TYPE>(
-        static_cast<const typename paddle::framework::ConvertToPhiContext<
-            DeviceContext>::TYPE &>(dev_ctx),
-        *d_out, *xshape, d_x);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
@@ -41,27 +41,4 @@ REGISTER_OP_XPU_KERNEL(
     ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int8_t>,
     ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int64_t>);
-REGISTER_OP_XPU_KERNEL(
-    flatten_contiguous_range,
-    ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
-                                      float>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
-                                      plat::float16>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext, int>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
-                                      int8_t>,
-    ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
-                                      int64_t>);
-REGISTER_OP_XPU_KERNEL(
-    flatten_contiguous_range_grad,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
-                                          float>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
-                                          plat::float16>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
-                                          int>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
-                                          int8_t>,
-    ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
-                                          int64_t>);
 #endif
@@ -352,6 +352,14 @@ void FlattenInferMeta(const MetaTensor& x,
                       int start_axis,
                       int stop_axis,
                       MetaTensor* out) {
+  FlattenWithXShapeInferMeta(x, start_axis, stop_axis, out, nullptr);
+}
+
+void FlattenWithXShapeInferMeta(const MetaTensor& x,
+                                int start_axis,
+                                int stop_axis,
+                                MetaTensor* out,
+                                MetaTensor* xshape) {
   auto x_dims = x.dims();
   int in_dims_size = x_dims.size();
   if (start_axis < 0) {
@@ -394,6 +402,14 @@ void FlattenInferMeta(const MetaTensor& x,
     // are the same.
     out->share_lod(x);
   }
+  if (xshape == nullptr) return;
+  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
+  xshape_dims[0] = 0;
+  for (int i = 0; i < x_dims.size(); ++i) {
+    xshape_dims[i + 1] = x_dims[i];
+  }
+  xshape->set_dims(phi::make_ddim(xshape_dims));
+  xshape->share_lod(x);
 }
 
 void GumbelSoftmaxInferMeta(const MetaTensor& x,
...
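The new tail of FlattenWithXShapeInferMeta records the input shape, prefixed with a 0, into XShape; the grad kernel later slices that prefix off to recover the input dims without keeping X alive. A self-contained sketch of the round trip (the helper names are illustrative, not Paddle APIs):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// XShape stores [0, x_dims...], as the infermeta above does.
std::vector<int64_t> MakeXShape(const std::vector<int64_t>& x_dims) {
  std::vector<int64_t> xshape(x_dims.size() + 1, 0);
  std::copy(x_dims.begin(), x_dims.end(), xshape.begin() + 1);
  return xshape;
}

// The grad kernel's phi::slice_ddim(xshape_dims, 1, n) in miniature:
// drop the leading 0 to get the original input dims back.
std::vector<int64_t> RecoverXDims(const std::vector<int64_t>& xshape) {
  return std::vector<int64_t>(xshape.begin() + 1, xshape.end());
}

int main() {
  const std::vector<int64_t> x_dims{6, 8, 3};
  assert(MakeXShape(x_dims) == (std::vector<int64_t>{0, 6, 8, 3}));
  assert(RecoverXDims(MakeXShape(x_dims)) == x_dims);
  return 0;
}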
@@ -86,6 +86,12 @@ void FlattenInferMeta(const MetaTensor& x,
                       int stop_axis,
                       MetaTensor* out);
 
+void FlattenWithXShapeInferMeta(const MetaTensor& x,
+                                int start_axis,
+                                int stop_axis,
+                                MetaTensor* out,
+                                MetaTensor* xshape);
+
 void GumbelSoftmaxInferMeta(const MetaTensor& x,
                             float temperature,
                             bool hard,
...
@@ -25,6 +25,7 @@ void FlattenGradKernel(const Context& dev_ctx,
                        const DenseTensor& xshape,
                        DenseTensor* x_grad) {
   auto xshape_dims = xshape.dims();
+  dev_ctx.Alloc(x_grad, out_grad.dtype());
   auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
   phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
   x_grad->Resize(x_dims);
...
@@ -27,6 +27,7 @@ void FlattenKernel(const Context& dev_ctx,
                    int start_axis,
                    int stop_axis,
                    DenseTensor* out) {
+  dev_ctx.Alloc(out, x.dtype());
   auto out_dims = out->dims();
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   out->Resize(out_dims);
@@ -43,7 +44,6 @@ void FlattenWithXShape(const Context& dev_ctx,
                        DenseTensor* out,
                        DenseTensor* xshape) {
   FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
-  funcs::SetXShape(x, xshape);
 }
 
 }  // namespace phi
...
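Why the two dev_ctx.Alloc additions are needed: the deleted fluid wrappers above used to call out->mutable_data(place, type) before delegating, so phi::Copy always found an allocated destination. With the wrappers gone, the phi kernels have to provide that guarantee themselves. A recap of the pattern in the forward kernel, with the reasoning as comments (lines taken from the diff above):

dev_ctx.Alloc(out, x.dtype());  // ensure the output buffer exists first
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);  // then copy the data
out->Resize(out_dims);  // Copy gave out the dims of x; restore the flattened
                        // dims that were saved before the copy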