未验证 提交 2a905f6b 编写于 作者: C Chen Weihang 提交者: GitHub

infershape func to infermeta (#37524)

上级 171da2ce
...@@ -52,8 +52,8 @@ PD_DLL_DECL Tensor full(const ScalarArray& shape, ...@@ -52,8 +52,8 @@ PD_DLL_DECL Tensor full(const ScalarArray& shape,
kernel_context.EmplaceBackAttr(pten::ScalarArray(shape)); kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));
kernel_context.EmplaceBackAttr(pten::Scalar(value)); kernel_context.EmplaceBackAttr(pten::Scalar(value));
// 4. InferShape // 4. InferMeta
auto out_meta = pten::FullInferShape(shape, dtype, layout); auto out_meta = pten::FullInferMeta(shape, dtype, layout);
// 5. Prepare outputs // 5. Prepare outputs
const auto allocator = const auto allocator =
...@@ -97,8 +97,8 @@ PD_DLL_DECL Tensor full_like(const Tensor& x, ...@@ -97,8 +97,8 @@ PD_DLL_DECL Tensor full_like(const Tensor& x,
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl()); auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackAttr(pten::Scalar(value)); kernel_context.EmplaceBackAttr(pten::Scalar(value));
// 4. InferShape // 4. InferMeta
auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout); auto out_meta = FullLikeInferMeta(dense_x->meta(), dtype, layout);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
......
...@@ -55,8 +55,8 @@ PD_DLL_DECL Tensor dot(const Tensor& x, const Tensor& y) { ...@@ -55,8 +55,8 @@ PD_DLL_DECL Tensor dot(const Tensor& x, const Tensor& y) {
kernel_context.EmplaceBackInput(dense_y); kernel_context.EmplaceBackInput(dense_y);
// TODO(chenweihang): add transform impl // TODO(chenweihang): add transform impl
// 4. InferShape // 4. InferMeta
auto out_meta = DotInferShape(dense_x->meta(), dense_y->meta()); auto out_meta = DotInferMeta(dense_x->meta(), dense_y->meta());
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -95,8 +95,8 @@ PD_DLL_DECL Tensor matmul(const Tensor& x, ...@@ -95,8 +95,8 @@ PD_DLL_DECL Tensor matmul(const Tensor& x,
kernel_context.EmplaceBackAttr(transpose_y); kernel_context.EmplaceBackAttr(transpose_y);
// TODO(chenweihang): add transform impl // TODO(chenweihang): add transform impl
// 4. InferShape // 4. InferMeta
auto out_meta = MatmulInferShape( auto out_meta = MatmulInferMeta(
dense_x->meta(), dense_y->meta(), transpose_x, transpose_y); dense_x->meta(), dense_y->meta(), transpose_x, transpose_y);
// 5. Prepare outputs // 5. Prepare outputs
......
...@@ -50,8 +50,8 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { ...@@ -50,8 +50,8 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) {
kernel_context.EmplaceBackAttr(start_axis); kernel_context.EmplaceBackAttr(start_axis);
kernel_context.EmplaceBackAttr(stop_axis); kernel_context.EmplaceBackAttr(stop_axis);
// 4. InferShape // 4. InferMeta
auto out_meta = FlattenInferShape(dense_x->meta(), start_axis, stop_axis); auto out_meta = FlattenInferMeta(dense_x->meta(), start_axis, stop_axis);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -84,7 +84,7 @@ PD_DLL_DECL Tensor cast(const Tensor& x, DataType out_dtype) { ...@@ -84,7 +84,7 @@ PD_DLL_DECL Tensor cast(const Tensor& x, DataType out_dtype) {
kernel_context.EmplaceBackAttr(out_dtype); kernel_context.EmplaceBackAttr(out_dtype);
kernel_context.EmplaceBackAttr(dense_x->meta().dtype); kernel_context.EmplaceBackAttr(dense_x->meta().dtype);
// 4. InferShape // 4. InferMeta
auto out_meta = CastInferMeta(dense_x->meta(), out_dtype); auto out_meta = CastInferMeta(dense_x->meta(), out_dtype);
// 5. Prepare outputs // 5. Prepare outputs
...@@ -117,8 +117,8 @@ PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) { ...@@ -117,8 +117,8 @@ PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector<int64_t>& shape) {
kernel_context.EmplaceBackInput(dense_x); kernel_context.EmplaceBackInput(dense_x);
kernel_context.EmplaceBackAttr(shape); kernel_context.EmplaceBackAttr(shape);
// 4. InferShape // 4. InferMeta
auto out_meta = InferShapeFromVecValue(dense_x->meta(), shape); auto out_meta = InferMetaFromVecValue(dense_x->meta(), shape);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
......
...@@ -50,8 +50,8 @@ PD_DLL_DECL Tensor mean(const Tensor& x) { ...@@ -50,8 +50,8 @@ PD_DLL_DECL Tensor mean(const Tensor& x) {
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl()); auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x); kernel_context.EmplaceBackInput(dense_x);
// 4. InferShape // 4. InferMeta
auto out_meta = ReductionInferShape(dense_x->meta()); auto out_meta = ReductionInferMeta(dense_x->meta());
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -86,8 +86,8 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) { ...@@ -86,8 +86,8 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) {
kernel_context.EmplaceBackInput(dense_y); kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1); kernel_context.EmplaceBackAttr(-1);
// 4. InferShape // 4. InferMeta
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1); auto out_meta = ElementwiseInferMeta(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -121,8 +121,8 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) { ...@@ -121,8 +121,8 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) {
kernel_context.EmplaceBackInput(dense_y); kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1); kernel_context.EmplaceBackAttr(-1);
// 4. InferShape // 4. InferMeta
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1); auto out_meta = ElementwiseInferMeta(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -156,8 +156,8 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) { ...@@ -156,8 +156,8 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) {
kernel_context.EmplaceBackInput(dense_y); kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1); kernel_context.EmplaceBackAttr(-1);
// 4. InferShape // 4. InferMeta
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1); auto out_meta = ElementwiseInferMeta(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
...@@ -191,8 +191,8 @@ PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y) { ...@@ -191,8 +191,8 @@ PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y) {
kernel_context.EmplaceBackInput(dense_y); kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1); kernel_context.EmplaceBackAttr(-1);
// 4. InferShape // 4. InferMeta
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1); auto out_meta = ElementwiseInferMeta(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs // 5. Prepare outputs
Tensor out; Tensor out;
......
...@@ -55,7 +55,7 @@ PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking) { ...@@ -55,7 +55,7 @@ PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
kernel_context.EmplaceBackAttr(blocking); kernel_context.EmplaceBackAttr(blocking);
// 4. InferMeta // 4. InferMeta
auto out_meta = UnchangedInferShape(dense_x->meta()); auto out_meta = UnchangedInferMeta(dense_x->meta());
// 5. Prepare outputs // 5. Prepare outputs
const auto allocator = const auto allocator =
......
...@@ -31,7 +31,7 @@ DenseTensor FillAnyLike( ...@@ -31,7 +31,7 @@ DenseTensor FillAnyLike(
DataType dtype = DataType::UNDEFINED, DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED, // Is backend needed here? Backend backend = Backend::UNDEFINED, // Is backend needed here?
DataLayout layout = DataLayout::UNDEFINED) { DataLayout layout = DataLayout::UNDEFINED) {
auto out_meta = FullLikeInferShape(x.meta(), dtype, layout); auto out_meta = FullLikeInferMeta(x.meta(), dtype, layout);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
......
...@@ -26,7 +26,7 @@ template <typename T, typename ContextT> ...@@ -26,7 +26,7 @@ template <typename T, typename ContextT>
DenseTensor Dot(const ContextT& dev_ctx, DenseTensor Dot(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y) { const DenseTensor& y) {
auto out_meta = DotInferShape(x.meta(), y.meta()); auto out_meta = DotInferMeta(x.meta(), y.meta());
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
......
...@@ -28,7 +28,7 @@ DenseTensor Flatten(const ContextT& dev_ctx, ...@@ -28,7 +28,7 @@ DenseTensor Flatten(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis) { int stop_axis) {
auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis); auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -55,7 +55,7 @@ template <typename T, typename ContextT> ...@@ -55,7 +55,7 @@ template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx, DenseTensor Reshape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int64_t>& shape) { const std::vector<int64_t>& shape) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferMetaFromVecValue(x.meta(), shape);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
......
...@@ -24,7 +24,7 @@ namespace pten { ...@@ -24,7 +24,7 @@ namespace pten {
template <typename T, typename ContextT> template <typename T, typename ContextT>
DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) { DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferShape(x.meta()); auto out_meta = UnchangedInferMeta(x.meta());
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -35,7 +35,7 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) { ...@@ -35,7 +35,7 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
template <typename T, typename ContextT> template <typename T, typename ContextT>
DenseTensor Mean(const ContextT& dev_ctx, const DenseTensor& x) { DenseTensor Mean(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = ReductionInferShape(x.meta()); auto out_meta = ReductionInferMeta(x.meta());
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -50,7 +50,7 @@ DenseTensor Scale(const ContextT& dev_ctx, ...@@ -50,7 +50,7 @@ DenseTensor Scale(const ContextT& dev_ctx,
float scale, float scale,
float bias, float bias,
bool bias_after_scale) { bool bias_after_scale) {
auto out_meta = UnchangedInferShape(x.meta()); auto out_meta = UnchangedInferMeta(x.meta());
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -65,7 +65,7 @@ DenseTensor Scale(const ContextT& dev_ctx, ...@@ -65,7 +65,7 @@ DenseTensor Scale(const ContextT& dev_ctx,
const DenseTensor& scale, const DenseTensor& scale,
float bias, float bias,
bool bias_after_scale) { bool bias_after_scale) {
auto out_meta = UnchangedInferShape(x.meta()); auto out_meta = UnchangedInferMeta(x.meta());
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -79,7 +79,7 @@ DenseTensor Add(const ContextT& dev_ctx, ...@@ -79,7 +79,7 @@ DenseTensor Add(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis) { int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -93,7 +93,7 @@ DenseTensor Subtract(const ContextT& dev_ctx, ...@@ -93,7 +93,7 @@ DenseTensor Subtract(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis) { int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -107,7 +107,7 @@ DenseTensor Divide(const ContextT& dev_ctx, ...@@ -107,7 +107,7 @@ DenseTensor Divide(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis) { int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
...@@ -121,7 +121,7 @@ DenseTensor Multiply(const ContextT& dev_ctx, ...@@ -121,7 +121,7 @@ DenseTensor Multiply(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis) { int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis); auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis);
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace()); dev_ctx.GetPlace());
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace pten { namespace pten {
DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta) { const DenseTensorMeta& y_meta) {
auto x_dims = x_meta.dims; auto x_dims = x_meta.dims;
auto x_rank = static_cast<size_t>(x_dims.size()); auto x_rank = static_cast<size_t>(x_dims.size());
...@@ -60,7 +60,7 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, ...@@ -60,7 +60,7 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
return return_meta; return return_meta;
} }
DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta MatmulInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta, const DenseTensorMeta& y_meta,
bool trans_x, bool trans_x,
bool trans_y) { bool trans_y) {
...@@ -130,7 +130,7 @@ DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, ...@@ -130,7 +130,7 @@ DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
return {x_meta.dtype, ddim_out, x_meta.layout}; return {x_meta.dtype, ddim_out, x_meta.layout};
} }
DenseTensorMeta ElementwiseInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta ElementwiseInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta, const DenseTensorMeta& y_meta,
int axis) { int axis) {
DenseTensorMeta return_meta(x_meta.dtype, x_meta.dims, x_meta.layout); DenseTensorMeta return_meta(x_meta.dtype, x_meta.dims, x_meta.layout);
......
...@@ -19,29 +19,29 @@ limitations under the License. */ ...@@ -19,29 +19,29 @@ limitations under the License. */
namespace pten { namespace pten {
// Common InferShape Functions for binary operators, The format like: // Common InferMeta Functions for binary operators, The format like:
// //
// 1. DenseTensorMeta [OpName]InferShape(const DenseTensorMeta& x_meta, ...) // 1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...)
// {} // {}
// 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferShape(const // 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferMeta(const
// DenseTensorMeta& // DenseTensorMeta&
// x_meta, ...) {} // x_meta, ...) {}
// 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta> // 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta>
// [OpName]InferShape(const // [OpName]InferMeta(const
// DenseTensorMeta& x_meta, ...) // DenseTensorMeta& x_meta, ...)
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good.
// Because functions in this file // Because functions in this file
// not only can infer shape, but alse need infer lod or other useful data. // not only can infer shape, but alse need infer lod or other useful data.
DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta); const DenseTensorMeta& y_meta);
DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta MatmulInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta, const DenseTensorMeta& y_meta,
bool trans_x, bool trans_x,
bool trans_y); bool trans_y);
DenseTensorMeta ElementwiseInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta ElementwiseInferMeta(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta, const DenseTensorMeta& y_meta,
int axis); int axis);
} // namespace pten } // namespace pten
...@@ -17,14 +17,14 @@ limitations under the License. */ ...@@ -17,14 +17,14 @@ limitations under the License. */
namespace pten { namespace pten {
DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape, DenseTensorMeta FullInferMeta(const std::vector<int64_t>& shape,
DataType dtype, DataType dtype,
DataLayout layout) { DataLayout layout) {
const auto& out_dims = paddle::framework::make_ddim(shape); const auto& out_dims = paddle::framework::make_ddim(shape);
return {dtype, out_dims, layout}; return {dtype, out_dims, layout};
} }
DenseTensorMeta FullInferShape(const ScalarArray& shape, DenseTensorMeta FullInferMeta(const ScalarArray& shape,
DataType dtype, DataType dtype,
DataLayout layout) { DataLayout layout) {
const auto& out_dims = paddle::framework::make_ddim(shape.GetData()); const auto& out_dims = paddle::framework::make_ddim(shape.GetData());
......
...@@ -19,19 +19,19 @@ limitations under the License. */ ...@@ -19,19 +19,19 @@ limitations under the License. */
namespace pten { namespace pten {
// Common InferShape Functions for 0-nary operators(no input tensor), The format // Common InferMeta Functions for 0-nary operators(no input tensor), The format
// like: // like:
// //
// 1. DenseTensorMeta [OpName]InferShape( ...) // 1. DenseTensorMeta [OpName]InferMeta( ...)
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good.
// Because functions in this file // Because functions in this file
// not only can infer shape, but alse need infer lod or other useful data. // not only can infer shape, but alse need infer lod or other useful data.
DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape, DenseTensorMeta FullInferMeta(const std::vector<int64_t>& shape,
DataType dtype, DataType dtype,
DataLayout layout); DataLayout layout);
DenseTensorMeta FullInferShape(const ScalarArray& shape, DenseTensorMeta FullInferMeta(const ScalarArray& shape,
DataType dtype, DataType dtype,
DataLayout layout); DataLayout layout);
......
...@@ -17,17 +17,17 @@ limitations under the License. */ ...@@ -17,17 +17,17 @@ limitations under the License. */
namespace pten { namespace pten {
DenseTensorMeta UnchangedInferShape(const DenseTensorMeta& x_meta) { DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
return x_meta; return x_meta;
} }
DenseTensorMeta ReductionInferShape(const DenseTensorMeta& x_meta) { DenseTensorMeta ReductionInferMeta(const DenseTensorMeta& x_meta) {
const auto& out_dims = paddle::framework::make_ddim({1}); const auto& out_dims = paddle::framework::make_ddim({1});
DenseTensorMeta return_meta(x_meta.dtype, out_dims, x_meta.layout); DenseTensorMeta return_meta(x_meta.dtype, out_dims, x_meta.layout);
return return_meta; return return_meta;
} }
DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta,
int start_axis, int start_axis,
int stop_axis) { int stop_axis) {
auto& x_dims = x_meta.dims; auto& x_dims = x_meta.dims;
...@@ -80,7 +80,7 @@ DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, ...@@ -80,7 +80,7 @@ DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta,
return out_meta; return out_meta;
} }
DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta,
DataType dtype, DataType dtype,
DataLayout layout) { DataLayout layout) {
return {dtype == DataType::UNDEFINED ? x_meta.dtype : dtype, return {dtype == DataType::UNDEFINED ? x_meta.dtype : dtype,
...@@ -208,7 +208,7 @@ static paddle::framework::DDim ValidateShape( ...@@ -208,7 +208,7 @@ static paddle::framework::DDim ValidateShape(
return paddle::framework::make_ddim(output_shape); return paddle::framework::make_ddim(output_shape);
} }
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& shape) { const std::vector<int64_t>& shape) {
PADDLE_ENFORCE_EQ(!shape.empty(), PADDLE_ENFORCE_EQ(!shape.empty(),
true, true,
......
...@@ -19,34 +19,34 @@ limitations under the License. */ ...@@ -19,34 +19,34 @@ limitations under the License. */
namespace pten { namespace pten {
// Common InferShape Functions for unary operators, The format like: // Common InferMeta Functions for unary operators, The format like:
// //
// 1. DenseTensorMeta [OpName]InferShape(const DenseTensorMeta& x_meta, ...) // 1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...)
// {} // {}
// 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferShape(const // 2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferMeta(const
// DenseTensorMeta& // DenseTensorMeta&
// x_meta, ...) {} // x_meta, ...) {}
// 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta> // 3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta>
// [OpName]InferShape(const // [OpName]InferMeta(const
// DenseTensorMeta& x_meta, ...) // DenseTensorMeta& x_meta, ...)
// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. // NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good.
// Because functions in this file // Because functions in this file
// not only can infer shape, but alse need infer lod or other useful data. // not only can infer shape, but alse need infer lod or other useful data.
DenseTensorMeta UnchangedInferShape(const DenseTensorMeta& x_meta); DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta);
DenseTensorMeta ReductionInferShape(const DenseTensorMeta& x_meta); DenseTensorMeta ReductionInferMeta(const DenseTensorMeta& x_meta);
DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta,
int start_axis, int start_axis,
int stop_axis); int stop_axis);
DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta,
const DataType out_dtype); const DataType out_dtype);
DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta,
DataType dtype, DataType dtype,
DataLayout layout); DataLayout layout);
DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
} // namespace pten } // namespace pten
...@@ -50,7 +50,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx, ...@@ -50,7 +50,7 @@ void ReshapeFromVectorVal(const CPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int64_t>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferMetaFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
out->Resize(out_meta.dims); out->Resize(out_meta.dims);
return; return;
......
...@@ -50,7 +50,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx, ...@@ -50,7 +50,7 @@ void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int64_t>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferMetaFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
out->Resize(out_meta.dims); out->Resize(out_meta.dims);
return; return;
......
...@@ -55,7 +55,7 @@ void ReshapeFromVectorVal(const XPUContext& dev_ctx, ...@@ -55,7 +55,7 @@ void ReshapeFromVectorVal(const XPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const std::vector<int64_t>& shape, const std::vector<int64_t>& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape); auto out_meta = InferMetaFromVecValue(x.meta(), shape);
if (&x == out) { if (&x == out) {
out->Resize(out_meta.dims); out->Resize(out_meta.dims);
return; return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册