未验证 提交 778008d7 编写于 作者: Y YuanRisheng 提交者: GitHub

[Phi]Remove InferShape and Kernel of flatten_contiguous_range op (#40638)

* remove flatten infermeta

* fix bugs when run inference ci

* fix bugs when run inference ci

* fix bugs when run ci

* support infrt

* inplace infershape code'
上级 f4075db8
......@@ -27,7 +27,6 @@ limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
namespace paddle {
......@@ -101,235 +100,197 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
const InferShapeContext& ctx_;
};
// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public phi::MetaTensor {
public:
CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
: var_(std::move(var)), is_runtime_(is_runtime) {}
CompatMetaTensor() = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor(CompatMetaTensor&&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
int64_t numel() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<Tensor>().numel();
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->ElementSize();
}
int64_t CompatMetaTensor::numel() const {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<Tensor>().numel();
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->ElementSize();
}
}
DDim dims() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
}
DDim CompatMetaTensor::dims() const {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->GetShape().empty() ? phi::make_ddim({0UL})
: phi::make_ddim(var->GetShape());
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->GetShape().empty() ? phi::make_ddim({0UL})
: phi::make_ddim(var->GetShape());
}
}
phi::DataType dtype() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows."));
}
phi::DataType CompatMetaTensor::dtype() const {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return paddle::framework::TransToPhiDataType(var->GetDataType());
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows."));
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return paddle::framework::TransToPhiDataType(var->GetDataType());
}
}
DataLayout layout() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else {
DataLayout CompatMetaTensor::layout() const {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now
return DataLayout::UNDEFINED;
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now
return DataLayout::UNDEFINED;
}
}
void set_dims(const DDim& dims) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
// inplace using on LoDTensorArray is dangerous, but the unittest
// `test_list` contains this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows."));
}
void CompatMetaTensor::set_dims(const DDim& dims) {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
// inplace using on LoDTensorArray is dangerous, but the unittest
// `test_list` contains this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetShape(vectorize(dims));
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows."));
}
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetShape(vectorize(dims));
}
}
void set_dtype(phi::DataType dtype) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows."));
}
void CompatMetaTensor::set_dtype(phi::DataType dtype) {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows."));
}
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
}
}
void set_layout(DataLayout layout) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else {
void CompatMetaTensor::set_layout(DataLayout layout) {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now
}
}
void share_lod(const MetaTensor& meta_tensor) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
} else {
// NOTE(chenweihang): do nothing
// only LoDTensor need to share lod
}
void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
.GetCompileTimeLoD());
// NOTE(chenweihang): do nothing
// only LoDTensor need to share lod
}
} else {
auto* var = BOOST_GET(VarDesc*, var_);
var->SetLoDLevel(
static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
}
}
void share_dims(const MetaTensor& meta_tensor) override {
set_dims(meta_tensor.dims());
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::SelectedRows>()) {
auto* selected_rows = var->GetMutable<phi::SelectedRows>();
auto& input_selected_rows =
static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
selected_rows->set_rows(input_selected_rows.rows());
selected_rows->set_height(input_selected_rows.height());
}
void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
set_dims(meta_tensor.dims());
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::SelectedRows>()) {
auto* selected_rows = var->GetMutable<phi::SelectedRows>();
auto& input_selected_rows =
static_cast<const CompatMetaTensor&>(meta_tensor).GetSelectedRows();
selected_rows->set_rows(input_selected_rows.rows());
selected_rows->set_height(input_selected_rows.height());
}
}
}
void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor
share_lod(meta_tensor);
}
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().lod();
}
int32_t GetCompileTimeLoD() const {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->GetLoDLevel();
}
const phi::SelectedRows& GetSelectedRows() const {
PADDLE_ENFORCE_EQ(is_runtime_, true,
platform::errors::Unavailable(
"Only can get Tensor from MetaTensor in rumtime."));
auto* var = BOOST_GET_CONST(Variable*, var_);
PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
platform::errors::Unavailable(
"The Tensor in MetaTensor is not SelectedRows."));
return var->Get<phi::SelectedRows>();
}
InferShapeVarPtr var_;
bool is_runtime_;
};
void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor
share_lod(meta_tensor);
}
phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/meta_tensor.h"
namespace phi {
class InferMetaContext;
} // namespace phi
......@@ -39,5 +39,63 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} \
}
// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public phi::MetaTensor {
public:
CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
: var_(std::move(var)), is_runtime_(is_runtime) {}
CompatMetaTensor() = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor(CompatMetaTensor&&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
int64_t numel() const override;
DDim dims() const override;
phi::DataType dtype() const override;
DataLayout layout() const override;
void set_dims(const DDim& dims) override;
void set_dtype(phi::DataType dtype) override;
void set_layout(DataLayout layout) override;
void share_lod(const MetaTensor& meta_tensor) override;
void share_dims(const MetaTensor& meta_tensor) override;
void share_meta(const MetaTensor& meta_tensor) override;
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().lod();
}
int32_t GetCompileTimeLoD() const {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return var->GetLoDLevel();
}
const phi::SelectedRows& GetSelectedRows() const {
PADDLE_ENFORCE_EQ(is_runtime_, true,
platform::errors::Unavailable(
"Only can get Tensor from MetaTensor in rumtime."));
auto* var = BOOST_GET_CONST(Variable*, var_);
PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(), true,
platform::errors::Unavailable(
"The Tensor in MetaTensor is not SelectedRows."));
return var->Get<phi::SelectedRows>();
}
InferShapeVarPtr var_;
bool is_runtime_;
};
} // namespace framework
} // namespace paddle
......@@ -17,7 +17,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -270,70 +273,24 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FlattenContiguousRange");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FlattenContiguousRange");
const auto &start_axis = ctx->Attrs().Get<int>("start_axis");
const auto &stop_axis = ctx->Attrs().Get<int>("stop_axis");
const auto &in_dims = ctx->GetInputDim("X");
int in_dims_size = in_dims.size();
int real_start_axis = start_axis, real_stop_axis = stop_axis;
if (start_axis < 0) {
real_start_axis = start_axis + in_dims_size;
}
if (stop_axis < 0) {
real_stop_axis = stop_axis + in_dims_size;
}
PADDLE_ENFORCE_GE(
real_stop_axis, real_start_axis,
platform::errors::InvalidArgument("The stop_axis should be greater"
"than or equal to start_axis."));
const auto &out_dims =
GetOutputShape(real_start_axis, real_stop_axis, in_dims);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
if (in_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
if (!ctx->HasOutput("XShape")) return;
// OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
xshape_dims[i + 1] = in_dims[i];
// Construct MetaTensor for InferMeta Func
using CompatMetaTensor = framework::CompatMetaTensor;
CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
std::unique_ptr<CompatMetaTensor> xshape(nullptr);
if (ctx->HasOutput("XShape")) {
xshape = std::move(std::unique_ptr<CompatMetaTensor>(new CompatMetaTensor(
ctx->GetOutputVarPtrs("XShape")[0], ctx->IsRuntime())));
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", "XShape");
}
static std::vector<int32_t> GetOutputShape(const int start_axis,
const int stop_axis,
const framework::DDim &in_dims) {
int64_t outer = 1;
std::vector<int32_t> out_shape;
int in_dims_size = in_dims.size();
out_shape.reserve(in_dims_size - stop_axis + start_axis);
for (int i = 0; i < start_axis; ++i) {
out_shape.push_back(in_dims[i]);
}
for (int i = start_axis; i <= stop_axis; i++) {
if (in_dims[i] == -1 || outer == -1) {
outer = -1;
} else {
outer *= in_dims[i];
}
}
out_shape.push_back(outer);
for (int i = stop_axis + 1; i < in_dims_size; i++) {
out_shape.push_back(in_dims[i]);
}
return out_shape;
phi::FlattenWithXShapeInferMeta(x, start_axis, stop_axis, &out,
xshape.get());
}
};
......@@ -487,30 +444,3 @@ REGISTER_OP_CPU_KERNEL(
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
uint8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
int64_t>);
REGISTER_OP_CPU_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
uint8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
......@@ -47,34 +47,3 @@ REGISTER_OP_CUDA_KERNEL(
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
uint8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
uint8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -119,46 +119,5 @@ class Flatten2GradKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *in = context.Input<framework::LoDTensor>("X");
auto *out = context.Output<framework::LoDTensor>("Out");
out->mutable_data(context.GetPlace(), in->type());
auto &start_axis = context.Attr<int>("start_axis");
auto &stop_axis = context.Attr<int>("stop_axis");
auto &dev_ctx = context.device_context<DeviceContext>();
// call new kernel
phi::FlattenKernel<T, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
*in, start_axis, stop_axis, out);
}
};
template <typename DeviceContext, typename T>
class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *xshape = ctx.Input<framework::LoDTensor>("XShape");
d_x->mutable_data(ctx.GetPlace(), d_out->type());
auto &dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
phi::FlattenGradKernel<T, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
*d_out, *xshape, d_x);
}
};
} // namespace operators
} // namespace paddle
......@@ -41,27 +41,4 @@ REGISTER_OP_XPU_KERNEL(
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
plat::float16>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
plat::float16>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int64_t>);
#endif
......@@ -352,6 +352,14 @@ void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out) {
FlattenWithXShapeInferMeta(x, start_axis, stop_axis, out, nullptr);
}
void FlattenWithXShapeInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out,
MetaTensor* xshape) {
auto x_dims = x.dims();
int in_dims_size = x_dims.size();
if (start_axis < 0) {
......@@ -394,6 +402,14 @@ void FlattenInferMeta(const MetaTensor& x,
// are the same.
out->share_lod(x);
}
if (xshape == nullptr) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
}
void GumbelSoftmaxInferMeta(const MetaTensor& x,
......
......@@ -86,6 +86,12 @@ void FlattenInferMeta(const MetaTensor& x,
int stop_axis,
MetaTensor* out);
void FlattenWithXShapeInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out,
MetaTensor* xshape);
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
......
......@@ -25,6 +25,7 @@ void FlattenGradKernel(const Context& dev_ctx,
const DenseTensor& xshape,
DenseTensor* x_grad) {
auto xshape_dims = xshape.dims();
dev_ctx.Alloc(x_grad, out_grad.dtype());
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
x_grad->Resize(x_dims);
......
......@@ -27,6 +27,7 @@ void FlattenKernel(const Context& dev_ctx,
int start_axis,
int stop_axis,
DenseTensor* out) {
dev_ctx.Alloc(out, x.dtype());
auto out_dims = out->dims();
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims);
......@@ -43,7 +44,6 @@ void FlattenWithXShape(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape) {
FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
funcs::SetXShape(x, xshape);
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册