Unverified commit 7e29cea9 authored by Chen Weihang, committed by GitHub

[PTen] Change all InferMeta functions (#39222)

* change unary infermeta

* change other infermeta

* change all infermeta format

* resolve conflict

* fix failing tests

* resolve reshape conflict

* fix compile failure

* adapt auto api gen

* fix reshape failure

* fix concat failure

* resolve conflict
Parent e7b442cc
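Every change below follows the same pattern: InferMeta functions no longer consume `DenseTensorMeta` values and return a freshly built meta; they now read inputs through `const MetaTensor&` and write dims, dtype, layout, and lod through a trailing `MetaTensor*` output. A minimal before/after sketch of that pattern, modeled on the UnchangedInferMeta change in this diff (`XxxInferMeta` is a placeholder name, not a function in the tree):

#include "paddle/pten/core/meta_tensor.h"

namespace pten {

// Old style (removed in this commit): build and return a new meta value.
// DenseTensorMeta XxxInferMeta(const DenseTensorMeta& x_meta) {
//   return DenseTensorMeta(x_meta.dtype, x_meta.dims, x_meta.layout);
// }

// New style: write the inferred meta through the output MetaTensor.
void XxxInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
  out->share_lod(x);
}

}  // namespace pten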
@@ -212,11 +212,13 @@ TEST(CustomKernel, custom_kernel_dot) {
   kernel_context.EmplaceBackAttr(fake_attr_int64_vec);
   kernel_context.EmplaceBackAttr(fake_attr_int_vec);

-  auto out_meta = pten::DotInferMeta(dense_x->meta(), dense_y->meta());
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::DotInferMeta(*dense_x, *dense_y, &meta_out);
   kernel_context.EmplaceBackOutput(dense_out.get());  // idx:0 index:[0,1)

   // fake_input_vec: idx:1, index:[1,3)
......
@@ -186,6 +186,14 @@ class CompatMetaTensor : public pten::MetaTensor {
     }
   }

+  void share_meta(const MetaTensor& meta_tensor) override {
+    set_dims(meta_tensor.dims());
+    set_dtype(meta_tensor.dtype());
+    // VarDesc doesn't contain layout, so we cannot share layout
+    // set_layout(meta_tensor.layout());
+    share_lod(meta_tensor);
+  }
+
  private:
   const LoD& GetRuntimeLoD() const {
     auto* var = BOOST_GET_CONST(Variable*, var_);
......
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/pten/api/lib/utils/storage.h"
 #include "paddle/pten/core/compat/convert_utils.h"
 #include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/meta_tensor.h"

 namespace paddle {
 namespace experimental {
@@ -44,44 +45,38 @@ inline std::unique_ptr<std::vector<pten::DenseTensor>> TensorToDenseTensor(

 /* ----------------- for infer_meta --------------------- */

-inline const pten::DenseTensorMeta& GetDenseTensorMeta(
-    const pten::DenseTensor& tensor) {
-  return tensor.meta();
+inline pten::MetaTensor MakeMetaTensor(const pten::DenseTensor& tensor) {
+  return pten::MetaTensor(tensor);
 }

-inline std::vector<pten::DenseTensorMeta> GetDenseTensorMeta(
+inline std::vector<pten::MetaTensor> MakeMetaTensor(
     const std::vector<pten::DenseTensor>& tensors) {
-  std::vector<pten::DenseTensorMeta> metas;
-  metas.reserve(tensors.size());
+  std::vector<pten::MetaTensor> meta_tensors;
+  meta_tensors.reserve(tensors.size());
   for (const auto& t : tensors) {
-    metas.push_back(t.meta());
+    meta_tensors.emplace_back(t);
   }
-  return metas;
+  return meta_tensors;
 }

 /* ------------------ for output ----------------------- */

-inline pten::DenseTensor* SetKernelOutput(const pten::DenseTensorMeta& meta,
-                                          Backend backend,
-                                          Tensor* out) {
+inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
   auto dense_tensor = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<SharedStorage>(pten::TransToFluidPlace(backend)),
-      meta);
+      pten::DenseTensorMeta());
   out->set_impl(dense_tensor);
   return dense_tensor.get();
 }

 inline std::vector<pten::DenseTensor*> SetKernelOutput(
-    const std::vector<pten::DenseTensorMeta>& metas,
-    Backend backend,
-    std::vector<Tensor>* out) {
-  size_t n = metas.size();
-  out->reserve(n);
-  std::vector<pten::DenseTensor*> results(n);
-  for (size_t i = 0; i < n; ++i) {
+    size_t out_size, Backend backend, std::vector<Tensor>* out) {
+  out->reserve(out_size);
+  std::vector<pten::DenseTensor*> results(out_size);
+  for (size_t i = 0; i < out_size; ++i) {
     auto tensor_ptr = std::make_shared<pten::DenseTensor>(
         pten::make_intrusive<SharedStorage>(pten::TransToFluidPlace(backend)),
-        metas[i]);
+        pten::DenseTensorMeta());
     results[i] = tensor_ptr.get();
     out->emplace_back();
     out->back().set_impl(tensor_ptr);
......
@@ -57,20 +57,19 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
   kernel_context.EmplaceBackInput(dense_x.get());
   kernel_context.EmplaceBackAttr(blocking);

-  // 4. InferMeta
-  auto out_meta = UnchangedInferMeta(dense_x->meta());
-
-  // 5. Prepare outputs
+  // 4. Prepare outputs & InferMeta
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::UnchangedInferMeta(*dense_x, &meta_out);
   dense_out->mutable_data(pten::TransToFluidPlace(backend));
   kernel_context.EmplaceBackOutput(dense_out.get());
   Tensor out;
   out.set_impl(dense_out);

-  // 6. Call kernel
+  // 5. Call kernel
   kernel(&kernel_context);

   return out;
......
@@ -26,16 +26,6 @@ limitations under the License. */

 namespace pten {

-// TODO(chenweihang): add other flags if needed
-struct MetaConfig {
-  bool is_runtime{true};
-
-  MetaConfig() = default;
-
-  // supporting implicit construction is easier to use
-  MetaConfig(bool is_runtime) : is_runtime(is_runtime) {}  // NOLINT
-};
-
 class InferMetaContext {
  public:
   InferMetaContext() = default;
......
@@ -33,7 +33,7 @@ void MetaTensor::set_dims(const DDim& dims) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
         dims;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -43,7 +43,7 @@ void MetaTensor::set_dtype(DataType dtype) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
         ->dtype = dtype;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -53,7 +53,7 @@ void MetaTensor::set_layout(DataLayout layout) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
         ->layout = layout;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -63,9 +63,9 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
         meta_tensor.lod();
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
-        "Unsupported share lod inplace for `%s`.",
-        tensor_->type_info().name()));
+    PADDLE_THROW(
+        pten::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
+                                    tensor_->type_info().name()));
   }
 }
@@ -73,8 +73,20 @@ const LoD& MetaTensor::lod() const {
   if (pten::DenseTensor::classof(tensor_)) {
     return static_cast<DenseTensor*>(tensor_)->lod();
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
-        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
+    PADDLE_THROW(pten::errors::Unimplemented("Unsupported getting lod of `%s`.",
+                                             tensor_->type_info().name()));
+  }
+}
+
+void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
+  if (pten::DenseTensor::classof(tensor_)) {
+    set_dims(meta_tensor.dims());
+    set_dtype(meta_tensor.dtype());
+    set_layout(meta_tensor.layout());
+    share_lod(meta_tensor);
+  } else {
+    PADDLE_THROW(pten::errors::Unimplemented(
+        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
   }
 }
......
@@ -23,11 +23,26 @@ limitations under the License. */

 namespace pten {

+// TODO(chenweihang): add other flags if needed
+struct MetaConfig {
+  bool is_runtime{true};
+
+  MetaConfig() = default;
+
+  // supporting implicit construction is easier to use
+  MetaConfig(bool is_runtime) : is_runtime(is_runtime) {}  // NOLINT
+};
+
 class MetaTensor {
  public:
-  explicit MetaTensor(TensorBase* tensor) : tensor_(tensor) {}
   MetaTensor() = default;
+
+  // supporting implicit construction is easier to use
+  MetaTensor(TensorBase* tensor) : tensor_(tensor) {}  // NOLINT
+  MetaTensor(const TensorBase& tensor)  // NOLINT
+      : tensor_(const_cast<TensorBase*>(&tensor)) {}
+  MetaTensor(TensorBase& tensor) : tensor_(&tensor) {}  // NOLINT
+
   MetaTensor(const MetaTensor&) = default;
   MetaTensor(MetaTensor&&) = default;
   MetaTensor& operator=(const MetaTensor&) = delete;
@@ -42,7 +57,9 @@ class MetaTensor {
   virtual void set_dims(const DDim& dims);
   virtual void set_dtype(DataType dtype);
   virtual void set_layout(DataLayout layout);
+
   virtual void share_lod(const MetaTensor& meta_tensor);
+  virtual void share_meta(const MetaTensor& meta_tensor);

  private:
   // Because the lod in compiletime and runtime is different,
......
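Because `MetaTensor` now allows implicit construction from `TensorBase*` and `TensorBase&`, callers can wrap an output `DenseTensor` and run shape inference directly into it, which is the pattern the kernel headers below switch to. A minimal usage sketch, assuming the pten headers in this tree (`InferInPlace` is a hypothetical helper, not part of the change):

#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/unary.h"

// Wrap an already-constructed output tensor and infer its meta in place.
void InferInPlace(const pten::DenseTensor& x, pten::DenseTensor* dense_out) {
  pten::MetaTensor meta_out(dense_out);    // wraps the tensor, no copy
  pten::UnchangedInferMeta(x, &meta_out);  // writes dims/dtype/layout/lod
}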
-cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils infermeta_utils)
-cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils)
+cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils meta_tensor infermeta_utils)
+cc_library(backward_infermeta SRCS backward.cc DEPS meta_tensor convert_utils)
@@ -16,13 +16,15 @@ limitations under the License. */

 namespace pten {

-std::tuple<DenseTensorMeta, DenseTensorMeta> MatmulGradInferMeta(
-    const DenseTensorMeta& x_meta,
-    const DenseTensorMeta& y_meta,
-    const DenseTensorMeta& out_grad_meta,
-    bool transpose_x,
-    bool transpose_y) {
-  return std::make_tuple(x_meta, y_meta);
+void MatmulGradInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& out_grad_meta,
+                         bool transpose_x,
+                         bool transpose_y,
+                         MetaTensor* dx,
+                         MetaTensor* dy) {
+  dx->share_meta(x);
+  dy->share_meta(y);
 }

 }  // namespace pten
@@ -15,15 +15,17 @@ limitations under the License. */

 #pragma once

 #include <tuple>
-#include "paddle/pten/core/tensor_meta.h"
+#include "paddle/pten/core/meta_tensor.h"

 namespace pten {

-std::tuple<DenseTensorMeta, DenseTensorMeta> MatmulGradInferMeta(
-    const DenseTensorMeta& x_meta,
-    const DenseTensorMeta& y_meta,
-    const DenseTensorMeta& out_grad_meta,
-    bool transpose_x,
-    bool transpose_y);
+void MatmulGradInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& out_grad_meta,
+                         bool transpose_x,
+                         bool transpose_y,
+                         MetaTensor* dx,
+                         MetaTensor* dy);

 }  // namespace pten
@@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-// See Note [ Why still include the fluid headers? ]
 #include "paddle/pten/infermeta/binary.h"
 #include "paddle/pten/kernels/funcs/common_shape.h"

 namespace pten {

-DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
-                             const DenseTensorMeta& y_meta) {
-  auto x_dims = x_meta.dims;
+void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
+  auto x_dims = x.dims();
   auto x_rank = static_cast<size_t>(x_dims.size());
   PADDLE_ENFORCE_EQ(true,
                     1 == x_rank || 2 == x_rank,
@@ -29,10 +27,10 @@ DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
                         "should be 1 or 2",
                         x_dims.to_str()));

-  auto y_dims = y_meta.dims;
+  auto y_dims = y.dims();
   PADDLE_ENFORCE_EQ(
       true,
-      x_rank == (size_t)y_dims.size(),
+      x_rank == static_cast<size_t>(y_dims.size()),
       paddle::platform::errors::PreconditionNotMet(
           "ShapeError: The shape of input tensor Y: %s should match with "
           "input tensor X: %s",
@@ -56,25 +54,27 @@ DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
           y_dims.to_str()));

   x_dims[x_dims.size() - 1] = 1;
-  DenseTensorMeta return_meta(x_meta.dtype, x_dims, x_meta.layout);
-  return return_meta;
+  out->set_dims(x_dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
 }

-DenseTensorMeta MatmulInferMeta(const DenseTensorMeta& x_meta,
-                                const DenseTensorMeta& y_meta,
-                                bool trans_x,
-                                bool trans_y) {
-  std::vector<int64_t> dims_x = pten::framework::vectorize(x_meta.dims);
-  std::vector<int64_t> dims_y = pten::framework::vectorize(y_meta.dims);
+void MatmulInferMeta(const MetaTensor& x,
+                     const MetaTensor& y,
+                     bool trans_x,
+                     bool trans_y,
+                     MetaTensor* out) {
+  std::vector<int64_t> dims_x = pten::framework::vectorize(x.dims());
+  std::vector<int64_t> dims_y = pten::framework::vectorize(y.dims());
   auto ndims_x = dims_x.size();
   auto ndims_y = dims_y.size();
   PADDLE_ENFORCE_GT(ndims_x,
-                    0,
+                    0UL,
                     paddle::platform::errors::InvalidArgument(
                         "The Input(x) dims size must be greater than 0,"
                         " but received dims size is 0. "));
   PADDLE_ENFORCE_GT(ndims_y,
-                    0,
+                    0UL,
                     paddle::platform::errors::InvalidArgument(
                         "The Input(y) dims size must be greater than 0,"
                         " but received dims size is 0. "));
@@ -127,21 +127,24 @@ DenseTensorMeta MatmulInferMeta(const DenseTensorMeta& x_meta,

   auto ddim_out = pten::framework::make_ddim(new_dims);

-  return {x_meta.dtype, ddim_out, x_meta.layout};
+  out->set_dims(ddim_out);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
 }

-DenseTensorMeta ElementwiseInferMeta(const DenseTensorMeta& x_meta,
-                                     const DenseTensorMeta& y_meta) {
-  return ElementwiseRawInferMeta(x_meta, y_meta, -1);
+void ElementwiseInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          MetaTensor* out) {
+  return ElementwiseRawInferMeta(x, y, -1, std::move(out));
 }

-DenseTensorMeta ElementwiseRawInferMeta(const DenseTensorMeta& x_meta,
-                                        const DenseTensorMeta& y_meta,
-                                        int axis) {
-  DenseTensorMeta return_meta(x_meta.dtype, x_meta.dims, x_meta.layout);
-  if (x_meta.dims != y_meta.dims) {
-    auto x_dims = x_meta.dims;
-    auto y_dims = y_meta.dims;
+void ElementwiseRawInferMeta(const MetaTensor& x,
+                             const MetaTensor& y,
+                             int axis,
+                             MetaTensor* out) {
+  if (x.dims() != y.dims()) {
+    auto x_dims = x.dims();
+    auto y_dims = y.dims();
     int max_dim = std::max(x_dims.size(), y_dims.size());
     if (x_dims.size() == y_dims.size()) {
       PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
@@ -174,10 +177,15 @@ DenseTensorMeta ElementwiseRawInferMeta(const DenseTensorMeta& x_meta,
         out_dims_array.data(),
         max_dim,
         axis);
-    return_meta.dims = pten::framework::make_ddim(out_dims_array);
+    auto out_dims = pten::framework::make_ddim(out_dims_array);
+    out->set_dims(out_dims);
+  } else {
+    out->set_dims(x.dims());
   }
-  return_meta.lod = x_meta.lod;
-  return return_meta;
+
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+  out->share_lod(x);
 }

 }  // namespace pten
@@ -14,38 +14,35 @@ limitations under the License. */

 #pragma once

-// See Note [ Why still include the fluid headers? ]
-#include "paddle/pten/core/tensor_meta.h"
+#include "paddle/pten/core/meta_tensor.h"

 namespace pten {

 // Common InferMeta Functions for binary operators, The format like:
 //
-//   1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...)
-//   {}
-//   2. std::pair<DenseTensorMeta, DenseTensorMeta> [OpName]InferMeta(const
-//   DenseTensorMeta& x_meta, ...) {}
-//   3. std::tuple<DenseTensorMeta, DenseTensorMeta, DenseTensorMeta>
-//   [OpName]InferMeta(const DenseTensorMeta& x_meta, ...)
-// NOTE: The name "InferMeta" may not be appropriate. "InferMeta" may be good.
-// Because functions in this file
-// not only can infer shape, but also need infer lod or other useful data.
+//   1. void [FunctionDesc|OpName]InferMeta(const MetaTensor& x,
+//                                          const MetaTensor& y,
+//                                          ...,
+//                                          MetaTensor* out) {}
+//
+// NOTE: The name "InferShape" may not be appropriate. "InferMeta" may be good.
+// Because functions in this file not only can infer shape, but also need
+// infer lod or other useful data.

-DenseTensorMeta DotInferMeta(const DenseTensorMeta& x_meta,
-                             const DenseTensorMeta& y_meta);
+void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

-DenseTensorMeta MatmulInferMeta(const DenseTensorMeta& x_meta,
-                                const DenseTensorMeta& y_meta,
-                                bool trans_x,
-                                bool trans_y);
+void MatmulInferMeta(const MetaTensor& x,
+                     const MetaTensor& y,
+                     bool trans_x,
+                     bool trans_y,
+                     MetaTensor* out);

-DenseTensorMeta ElementwiseInferMeta(const DenseTensorMeta& x_meta,
-                                     const DenseTensorMeta& y_meta);
+void ElementwiseInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          MetaTensor* out);

-DenseTensorMeta ElementwiseRawInferMeta(const DenseTensorMeta& x_meta,
-                                        const DenseTensorMeta& y_meta,
-                                        int axis);
+void ElementwiseRawInferMeta(const MetaTensor& x_meta,
+                             const MetaTensor& y_meta,
+                             int axis,
+                             MetaTensor* out);

 }  // namespace pten
@@ -18,18 +18,19 @@ limitations under the License. */
 #include "paddle/pten/kernels/funcs/concat_funcs.h"

 namespace pten {

-DenseTensorMeta ConcatInferMeta(const std::vector<DenseTensorMeta>& x_meta,
-                                const Scalar& axis_scalar,
-                                bool is_runtime) {
-  PADDLE_ENFORCE_GE(x_meta.size(),
-                    0,
+void ConcatInferMeta(const std::vector<MetaTensor>& x,
+                     const Scalar& axis_scalar,
+                     MetaTensor* out,
+                     MetaConfig config) {
+  PADDLE_ENFORCE_GE(x.size(),
+                    0UL,
                     paddle::platform::errors::InvalidArgument(
                         "The size of input meta vector should be greater "
                         "than 0."));

   int axis = axis_scalar.to<int>();
   // 1. calculate axis
-  int rank = x_meta[0].dims.size();
+  int rank = x.at(0).dims().size();
   PADDLE_ENFORCE_EQ(
       axis >= -rank && axis < rank,
       true,
@@ -44,13 +45,15 @@ DenseTensorMeta ConcatInferMeta(const std::vector<DenseTensorMeta>& x_meta,

   // 2. calculate out dims
   std::vector<pten::DDim> x_dims;
-  for (auto meta : x_meta) {
-    x_dims.push_back(meta.dims);
+  for (auto& x_t : x) {
+    x_dims.push_back(x_t.dims());
   }
   pten::DDim out_dim =
-      pten::funcs::ComputeAndCheckShape(is_runtime, x_dims, axis);
+      pten::funcs::ComputeAndCheckShape(config.is_runtime, x_dims, axis);

-  return {x_meta[0].dtype, out_dim, x_meta[0].layout};
+  out->set_dims(out_dim);
+  out->set_dtype(x.at(0).dtype());
+  out->set_layout(x.at(0).layout());
 }

 }  // namespace pten
@@ -15,12 +15,12 @@ limitations under the License. */

 #pragma once

 #include "paddle/pten/common/scalar.h"
-#include "paddle/pten/core/tensor_meta.h"
+#include "paddle/pten/core/meta_tensor.h"

 namespace pten {

-// TODO(chentianyu03) use std::vector<DenseTensor> as InferMeta inputs
-DenseTensorMeta ConcatInferMeta(const std::vector<DenseTensorMeta>& x_meta,
-                                const Scalar& axis_scalar,
-                                bool is_runtime);
+void ConcatInferMeta(const std::vector<MetaTensor>& x,
+                     const Scalar& axis_scalar,
+                     MetaTensor* out,
+                     MetaConfig config = MetaConfig());

 }  // namespace pten
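The trailing `MetaConfig` parameter replaces the old `bool is_runtime` flag; since `MetaConfig` is implicitly constructible from `bool` (see `meta_tensor.h` above), call sites can keep passing a plain flag. A short sketch of both call styles, assuming the headers in this tree (`RunConcatInfer` is a hypothetical helper, not part of the change):

#include <vector>

#include "paddle/pten/infermeta/multiary.h"

void RunConcatInfer(const std::vector<pten::MetaTensor>& meta_x,
                    pten::MetaTensor* meta_out) {
  // Default config: MetaConfig() with is_runtime == true.
  pten::ConcatInferMeta(meta_x, /*axis_scalar=*/0, meta_out);
  // Explicit flag, converted implicitly to MetaConfig, as the Concat
  // kernel in this change does with /*is_runtime=*/true.
  pten::ConcatInferMeta(meta_x, /*axis_scalar=*/0, meta_out,
                        /*is_runtime=*/true);
}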
@@ -12,23 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-// See Note [ Why still include the fluid headers? ]
 #include "paddle/pten/infermeta/nullary.h"

 namespace pten {

-DenseTensorMeta CreateInferMeta(const std::vector<int64_t>& shape,
-                                DataType dtype,
-                                DataLayout layout) {
-  const auto& out_dims = pten::framework::make_ddim(shape);
-  return {dtype, out_dims, layout};
+void CreateInferMeta(const std::vector<int64_t>& shape,
+                     DataType dtype,
+                     DataLayout layout,
+                     MetaTensor* out) {
+  auto out_dims = pten::framework::make_ddim(shape);
+  out->set_dims(out_dims);
+  out->set_dtype(dtype);
+  out->set_layout(layout);
 }

-DenseTensorMeta CreateInferMeta(const ScalarArray& shape,
-                                DataType dtype,
-                                DataLayout layout) {
-  const auto& out_dims = pten::framework::make_ddim(shape.GetData());
-  return {dtype, out_dims, layout};
+void CreateInferMeta(const ScalarArray& shape,
+                     DataType dtype,
+                     DataLayout layout,
+                     MetaTensor* out) {
+  CreateInferMeta(shape.GetData(), dtype, layout, out);
 }

 }  // namespace pten
@@ -15,24 +15,27 @@ limitations under the License. */

 #pragma once

 #include "paddle/pten/common/scalar_array.h"
-#include "paddle/pten/core/tensor_meta.h"
+#include "paddle/pten/core/meta_tensor.h"

 namespace pten {

 // Common InferMeta Functions for 0-nary operators(no input tensor), The format
 // like:
 //
-//   1. DenseTensorMeta [OpName]InferMeta( ...)
-// NOTE: The name "InferMeta" may not be appropriate. "InferMeta" may be good.
-// Because functions in this file
-// not only can infer shape, but also need infer lod or other useful data.
+//   1. void [FunctionDesc|OpName]InferMeta(..., MetaTensor* out)
+//
+// NOTE: The name "InferShape" may not be appropriate. "InferMeta" may be good.
+// Because functions in this file not only can infer shape, but also need
+// infer lod or other useful data.

-DenseTensorMeta CreateInferMeta(const std::vector<int64_t>& shape,
-                                DataType dtype,
-                                DataLayout layout);
+void CreateInferMeta(const std::vector<int64_t>& shape,
+                     DataType dtype,
+                     DataLayout layout,
+                     MetaTensor* out);

-DenseTensorMeta CreateInferMeta(const ScalarArray& shape,
-                                DataType dtype,
-                                DataLayout layout);
+void CreateInferMeta(const ScalarArray& shape,
+                     DataType dtype,
+                     DataLayout layout,
+                     MetaTensor* out);

 }  // namespace pten
@@ -16,6 +16,7 @@ limitations under the License. */

 #include <set>

+#include "paddle/pten/common/data_type.h"
 #include "paddle/pten/core/infermeta_utils.h"

 namespace pten {
@@ -23,26 +24,22 @@ namespace pten {
 void UnchangedInferMetaNew(MetaConfig config,
                            const MetaTensor& x,
                            MetaTensor* out) {
-  out->set_dims(x.dims());
-  out->set_dtype(x.dtype());
-  out->set_layout(x.layout());
-  out->share_lod(x);
+  out->share_meta(x);
 }

 DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
   return x_meta;
 }

-DenseTensorMeta ReductionInferMeta(const DenseTensorMeta& x_meta) {
-  const auto& out_dims = pten::framework::make_ddim({1});
-  DenseTensorMeta return_meta(x_meta.dtype, out_dims, x_meta.layout);
-  return return_meta;
+void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
+  out->share_meta(x);
 }

-DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta,
-                                 int start_axis,
-                                 int stop_axis) {
-  auto& x_dims = x_meta.dims;
+void FlattenInferMeta(const MetaTensor& x,
+                      int start_axis,
+                      int stop_axis,
+                      MetaTensor* out) {
+  auto x_dims = x.dims();
   int in_dims_size = x_dims.size();
   if (start_axis < 0) {
     start_axis = start_axis + in_dims_size;
@@ -75,29 +72,30 @@ DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta,
     out_shape.push_back(x_dims[i]);
   }
   const auto& out_dims = pten::framework::make_ddim(out_shape);
-  DenseTensorMeta return_meta(x_meta.dtype, out_dims, x_meta.layout);
+  out->set_dims(out_dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());

-  if (x_dims[0] == return_meta.dims[0]) {
+  if (x_dims[0] == out_dims[0]) {
     // Only pass LoD when the first dimension of output and Input(X)
     // are the same.
-    return_meta.lod = x_meta.lod;
+    out->share_lod(x);
   }
-
-  return return_meta;
 }

-DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta,
-                              const DataType out_dtype) {
-  DenseTensorMeta out_meta(out_dtype, x_meta.dims, x_meta.layout);
-  return out_meta;
+void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_dtype(out_dtype);
+  out->set_layout(x.layout());
 }

-DenseTensorMeta CreateLikeInferMeta(const DenseTensorMeta& x_meta,
-                                    DataType dtype,
-                                    DataLayout layout) {
-  return {dtype == DataType::UNDEFINED ? x_meta.dtype : dtype,
-          x_meta.dims,
-          layout == DataLayout::UNDEFINED ? x_meta.layout : layout};
+void CreateLikeInferMeta(const MetaTensor& x,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
+  out->set_layout(layout == DataLayout::UNDEFINED ? x.layout() : layout);
 }

 static pten::framework::DDim ValidateShape(
@@ -220,46 +218,51 @@ static pten::framework::DDim ValidateShape(
   return pten::framework::make_ddim(output_shape);
 }

-DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
-                                      const std::vector<int64_t>& shape) {
+void InferMetaFromVecValue(const MetaTensor& x,
+                           const std::vector<int64_t>& shape,
+                           MetaTensor* out) {
   PADDLE_ENFORCE_EQ(!shape.empty(),
                     true,
                     paddle::platform::errors::InvalidArgument(
                         "The parameter 'shape' in ReshapeOp must be set. "
                         "But received 'shape' is empty."));
-  auto x_dims = x_meta.dims;
+  auto x_dims = x.dims();
   auto out_dims = ValidateShape(shape, x_dims);
-  DenseTensorMeta return_meta(x_meta.dtype, out_dims, x_meta.layout);
-  if (x_dims[0] == return_meta.dims[0]) {
+  out->set_dims(out_dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+  if (x_dims[0] == out_dims[0]) {
     // Only pass LoD when the first dimension of output and Input(X)
     // are the same.
-    return_meta.lod = x_meta.lod;
+    out->share_lod(x);
   }
-  return return_meta;
 }

-DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
-                                 const ScalarArray& shape) {
-  return InferMetaFromVecValue(x_meta, shape.GetData());
+void ReshapeInferMeta(const MetaTensor& x,
+                      const ScalarArray& shape,
+                      MetaTensor* out) {
+  InferMetaFromVecValue(x, shape.GetData(), out);
 }

 /* Why not use ReduceInferMeta directly?
    Because we need to make InferMetaFunction's args follow the design of
    api.yaml
 */
-DenseTensorMeta SumInferMeta(const DenseTensorMeta& x_meta,
-                             const std::vector<int64_t>& axis,
-                             DataType dtype,
-                             bool keep_dim) {
-  return ReduceInferMeta(x_meta, axis, keep_dim, dtype);
+void SumInferMeta(const MetaTensor& x,
+                  const std::vector<int64_t>& axis,
+                  DataType dtype,
+                  bool keep_dim,
+                  MetaTensor* out) {
+  ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out));
 }

-DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
-                                const std::vector<int64_t>& axis,
-                                bool keep_dim,
-                                DataType dtype) {
+void ReduceInferMeta(const MetaTensor& x,
+                     const std::vector<int64_t>& axis,
+                     bool keep_dim,
+                     DataType dtype,
+                     MetaTensor* out) {
   bool reduce_all = true;
   std::set<int64_t> dims_set(axis.begin(), axis.end());
-  for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
+  for (int64_t i = 0; i < x.dims().size(); ++i) {
     if (dims_set.find(i) == dims_set.end()) {
       reduce_all = false;
       break;
@@ -268,19 +271,19 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,

   std::vector<int64_t> out_dim_vector;
   if (keep_dim) {
-    for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
+    for (int64_t i = 0; i < x.dims().size(); ++i) {
       if (reduce_all || dims_set.find(i) != dims_set.end()) {
         out_dim_vector.push_back(1);
       } else {
-        out_dim_vector.push_back(x_meta.dims.at(i));
+        out_dim_vector.push_back(x.dims().at(i));
       }
     }
   } else {
-    for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
+    for (int64_t i = 0; i < x.dims().size(); ++i) {
       if (reduce_all || dims_set.find(i) != dims_set.end()) {
         continue;
       } else {
-        out_dim_vector.push_back(x_meta.dims.at(i));
+        out_dim_vector.push_back(x.dims().at(i));
       }
     }
@@ -294,16 +297,24 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
   if (dtype != DataType::UNDEFINED) {
     out_dtype = dtype;
   } else {
-    if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
-        x_meta.dtype == DataType::INT64) {
+    if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32 ||
+        x.dtype() == DataType::INT64) {
       out_dtype = DataType::INT64;
     } else {
-      out_dtype = x_meta.dtype;
+      out_dtype = x.dtype();
     }
   }

-  DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout);
-  return return_meta;
+  out->set_dims(out_dim);
+  out->set_dtype(out_dtype);
+  out->set_layout(x.layout());
+}
+
+void ReduceInferMeta(const MetaTensor& x,
+                     const std::vector<int64_t>& axis,
+                     bool keep_dim,
+                     MetaTensor* out) {
+  ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
 }

 }  // namespace pten
......
@@ -16,9 +16,7 @@ limitations under the License. */

 // See Note [ Why still include the fluid headers? ]
 #include "paddle/pten/common/scalar_array.h"
-#include "paddle/pten/core/infermeta_utils.h"
 #include "paddle/pten/core/meta_tensor.h"
-#include "paddle/pten/core/tensor_meta.h"

 namespace pten {
@@ -26,45 +24,54 @@ class MetaConfig;

 // Common InferMeta Functions for unary operators, The format like:
 //
-//   void [OpName]InferMeta(const MetaTensor& x, ..., MetaTensor* out) {}
+//   void [FunctionDesc|OpName]InferMeta(const MetaTensor& x, ..., MetaTensor*
+//   out) {}
 //
 // NOTE: The name "InferShape" may not be appropriate. "InferMeta" may be good.
 // Because functions in this file not only can infer shape, but also need
 // infer lod or other useful data.

-// TODO(chenweihang): update all InferMeta function format in next pr,
-// now add UnchangedInferMetaNew for test new format
+// TODO(chenweihang): to avoid conflict, remove this function in next PR
 void UnchangedInferMetaNew(MetaConfig config,
                            const MetaTensor& x,
                            MetaTensor* out);

-DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta);
+void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);

-DenseTensorMeta ReductionInferMeta(const DenseTensorMeta& x_meta);
+void FlattenInferMeta(const MetaTensor& x,
+                      int start_axis,
+                      int stop_axis,
+                      MetaTensor* out);

-DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta,
-                                 int start_axis,
-                                 int stop_axis);
-DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta,
-                              const DataType out_dtype);
+void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

-DenseTensorMeta CreateLikeInferMeta(const DenseTensorMeta& x_meta,
-                                    DataType dtype,
-                                    DataLayout layout);
+void CreateLikeInferMeta(const MetaTensor& x,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out);

-DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
-                                      const std::vector<int64_t>& shape);
+void InferMetaFromVecValue(const MetaTensor& x,
+                           const std::vector<int64_t>& shape,
+                           MetaTensor* out);

-DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
-                                 const ScalarArray& shape);
+void ReshapeInferMeta(const MetaTensor& x,
+                      const ScalarArray& shape,
+                      MetaTensor* out);

-DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
-                                const std::vector<int64_t>& axis,
-                                bool keep_dim,
-                                DataType dtype = DataType::UNDEFINED);
+void ReduceInferMeta(const MetaTensor& x,
+                     const std::vector<int64_t>& axis,
+                     bool keep_dim,
+                     DataType dtype,
+                     MetaTensor* out);

-DenseTensorMeta SumInferMeta(const DenseTensorMeta& x_meta,
-                             const std::vector<int64_t>& axis,
-                             DataType dtype,
-                             bool keep_dim);
+void ReduceInferMeta(const MetaTensor& x,
+                     const std::vector<int64_t>& axis,
+                     bool keep_dim,
+                     MetaTensor* out);
+
+void SumInferMeta(const MetaTensor& x,
+                  const std::vector<int64_t>& axis,
+                  DataType dtype,
+                  bool keep_dim,
+                  MetaTensor* out);

 }  // namespace pten
@@ -29,8 +29,9 @@ template <typename T, typename Context>
 DenseTensor Cast(const Context& dev_ctx,
                  const DenseTensor& x,
                  DataType out_dtype) {
-  auto out_meta = CastInferMeta(x.meta(), out_dtype);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  CastInferMeta(x, out_dtype, &meta_out);
   CastKernel<T, Context>(dev_ctx, x, out_dtype, &dense_out);
   return dense_out;
 }
......
@@ -32,8 +32,9 @@ template <typename T,
               std::is_same<T, paddle::platform::complex<double>>::value,
           bool> = true>
 DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
-  auto out_meta = UnchangedInferMeta(x.meta());
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  UnchangedInferMeta(x, &meta_out);
   ConjKernel<T>(dev_ctx, x, &dense_out);
   return dense_out;
 }
......
@@ -30,14 +30,16 @@ template <typename T, typename Context>
 DenseTensor Concat(const Context& dev_ctx,
                    const std::vector<DenseTensor>& x,
                    const Scalar& axis) {
-  std::vector<DenseTensorMeta> x_meta;
-  for (auto t : x) {
-    x_meta.push_back(t.meta());
+  std::vector<MetaTensor> meta_x;
+  for (const auto& t : x) {
+    meta_x.emplace_back(t);
   }

-  auto out_meta = ConcatInferMeta(x_meta, axis.to<int>(), true);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ConcatInferMeta(meta_x, axis.to<int>(), &meta_out, /*is_runtime=*/true);
   ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
   return dense_out;
 }

 }  // namespace pten
@@ -29,8 +29,9 @@ template <typename T, typename Context>
 DenseTensor Dot(const Context& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& y) {
-  auto out_meta = DotInferMeta(x.meta(), y.meta());
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  DotInferMeta(x, y, &meta_out);
   DotKernel<T, Context>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
......
@@ -55,8 +55,9 @@ DenseTensor Empty(const Context& dev_ctx,
                   DataType dtype = DataType::FLOAT32,
                   Backend backend = Backend::CPU,  // Is backend needed here?
                   DataLayout layout = DataLayout::NCHW) {
-  auto out_meta = CreateInferMeta(shape, dtype, layout);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  CreateInferMeta(shape, dtype, layout, &meta_out);
   EmptyKernel<T, Context>(dev_ctx, shape, &dense_out);
   return dense_out;
 }
@@ -68,8 +69,9 @@ DenseTensor EmptyLike(
     DataType dtype = DataType::UNDEFINED,
     Backend backend = Backend::UNDEFINED,  // Is backend needed here?
     DataLayout layout = DataLayout::UNDEFINED) {
-  auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  CreateLikeInferMeta(x, dtype, layout, &meta_out);
   EmptyLikeKernel<T, Context>(dev_ctx, &dense_out);
   return dense_out;
 }
......
@@ -40,8 +40,9 @@ DenseTensor Flatten(const Context& dev_ctx,
                     const DenseTensor& x,
                     int start_axis,
                     int stop_axis) {
-  auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  FlattenInferMeta(x, start_axis, stop_axis, &meta_out);
   FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
   return dense_out;
 }
......
@@ -41,8 +41,9 @@ DenseTensor Full(const Context& dev_ctx,
                  DataType dtype = DataType::FLOAT32,
                  Backend backend = Backend::CPU,  // Is backend needed here?
                  DataLayout layout = DataLayout::NCHW) {
-  auto out_meta = CreateInferMeta(shape, dtype, layout);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  CreateInferMeta(shape, dtype, layout, &meta_out);
   FullKernel<T, Context>(dev_ctx, shape, val, &dense_out);
   return dense_out;
 }
@@ -55,8 +56,9 @@ DenseTensor FullLike(
     DataType dtype = DataType::UNDEFINED,
     Backend backend = Backend::UNDEFINED,  // Is backend needed here?
     DataLayout layout = DataLayout::UNDEFINED) {
-  auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  CreateLikeInferMeta(x, dtype, layout, &meta_out);
   FullLikeKernel<T, Context>(dev_ctx, val, &dense_out);
   return dense_out;
 }
......
@@ -35,7 +35,7 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
 }

 static inline pten::DDim ComputeAndCheckShape(
-    const bool is_runtime,
+    bool is_runtime,
     const std::vector<pten::DDim>& inputs_dims,
     const size_t axis) {
   const size_t n = inputs_dims.size();
......
@@ -109,8 +109,9 @@ template <typename T, typename Context>
 DenseTensor Add(const Context& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& y) {
-  auto out_meta = ElementwiseRawInferMeta(x.meta(), y.meta(), -1);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ElementwiseInferMeta(x, y, &meta_out);
   AddKernel<T, Context>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
@@ -119,8 +120,9 @@ template <typename T, typename Context>
 DenseTensor Subtract(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y) {
-  auto out_meta = ElementwiseRawInferMeta(x.meta(), y.meta(), -1);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ElementwiseInferMeta(x, y, &meta_out);
   SubtractKernel<T, Context>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
@@ -129,8 +131,9 @@ template <typename T, typename Context>
 DenseTensor Divide(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y) {
-  auto out_meta = ElementwiseRawInferMeta(x.meta(), y.meta(), -1);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ElementwiseInferMeta(x, y, &meta_out);
   DivideKernel<T, Context>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
@@ -139,8 +142,9 @@ template <typename T, typename Context>
 DenseTensor Multiply(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y) {
-  auto out_meta = ElementwiseRawInferMeta(x.meta(), y.meta(), -1);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ElementwiseInferMeta(x, y, &meta_out);
   MultiplyKernel<T, Context>(dev_ctx, x, y, &dense_out);
   return dense_out;
 }
@@ -150,8 +154,9 @@ DenseTensor Mean(const Context& dev_ctx,
                  const DenseTensor& x,
                  const std::vector<int64_t>& axis,
                  bool keep_dim) {
-  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out);
   MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
   return dense_out;
 }
@@ -162,9 +167,9 @@ DenseTensor Sum(const Context& dev_ctx,
                 const std::vector<int64_t>& axis,
                 DataType dtype,
                 bool keep_dim) {
-  auto out_meta = SumInferMeta(x.meta(), axis, dtype, keep_dim);
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
-
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  SumInferMeta(x, axis, dtype, keep_dim, &meta_out);
   SumKernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, &dense_out);
   return dense_out;
 }
......
@@ -35,8 +35,9 @@ DenseTensor Matmul(const Context& dev_ctx,
                    const DenseTensor& y,
                    bool transpose_x,
                    bool transpose_y) {
-  auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  MatmulInferMeta(x, y, transpose_x, transpose_y, &meta_out);
   MatmulKernel<T, Context>(dev_ctx, x, y, transpose_x, transpose_y, &dense_out);
   return dense_out;
 }
......
@@ -26,15 +26,18 @@ void ReshapeKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const ScalarArray& shape,
                    DenseTensor* out) {
-  auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
+  MetaTensor meta_out(out);
+  InferMetaFromVecValue(x, shape.GetData(), &meta_out);
   if (x.initialized() && x.Holder() == out->Holder()) {
-    out->ResizeAndAllocate(out_meta.dims);
+    dev_ctx.Alloc(out);
     return;
   }
-  out->set_meta(out_meta);
   dev_ctx.Alloc(out);
+  // TODO(chenweihang): the output dims are overwritten after copying,
+  // here we need to use a copy method that only copies data
+  auto dims = out->dims();
   pten::Copy(dev_ctx, x, false, out);
-  out->Resize(out_meta.dims);
+  out->Resize(dims);
   out->ResetLoD(x.lod());
 }
......
@@ -38,8 +38,9 @@ template <typename T, typename Context>
 DenseTensor Reshape(const Context& dev_ctx,
                     const DenseTensor& x,
                     const std::vector<int64_t>& shape) {
-  auto out_meta = InferMetaFromVecValue(x.meta(), shape);
-  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  InferMetaFromVecValue(x, shape, &meta_out);
   ReshapeKernel<Context>(dev_ctx, x, ScalarArray(shape), &dense_out);
   return dense_out;
 }
......
@@ -43,8 +43,9 @@ DenseTensor Scale(const Context& dev_ctx,
                   const Scalar& scale,
                   float bias,
                   bool bias_after_scale) {
-  auto out_meta = UnchangedInferMeta(x.meta());
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  UnchangedInferMeta(x, &meta_out);
   ScaleKernel<T, Context>(
       dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
   return dense_out;
......
@@ -25,8 +25,9 @@ void SignKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

 template <typename T, typename Context>
 DenseTensor Sign(const Context& dev_ctx, const DenseTensor& x) {
-  auto out_meta = UnchangedInferMeta(x.meta());
-  auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
+  auto dense_out = pten::Empty<T, Context>(dev_ctx);
+  MetaTensor meta_out(&dense_out);
+  UnchangedInferMeta(x, &meta_out);
   SignKernel<T, Context>(dev_ctx, x, &dense_out);
   return dense_out;
 }
......
@@ -23,6 +23,7 @@
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/core/meta_tensor.h"
 #include "paddle/pten/infermeta/unary.h"
 #include "paddle/pten/kernels/scale_kernel.h"
@@ -68,11 +69,12 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
   kernel_context.EmplaceBackAttr(bias);
   kernel_context.EmplaceBackAttr(bias_after_scale);

-  auto out_meta = pten::UnchangedInferMeta(dense_x->meta());
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(kernel_backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::UnchangedInferMeta(*dense_x, &meta_out);
   kernel_context.EmplaceBackOutput(dense_out.get());

   Tensor out;
@@ -234,11 +236,12 @@ Tensor scale_switch_case(const Tensor& x,

   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());

-  auto out_meta = pten::UnchangedInferMeta(dense_x->meta());
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(kernel_backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::UnchangedInferMeta(*dense_x, &meta_out);

   Tensor out;
   out.set_impl(dense_out);
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ConcatInferMeta func : ConcatInferMeta
param : [x, axis, true] param : [x, axis]
kernel : kernel :
func : concat func : concat
......
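The dropped trailing `true` in `param` suggests the runtime flag is no longer threaded through the yaml; presumably it moved into a defaulted config argument now that the output is passed as an out-parameter. A hedged sketch of what the updated signature would look like — inferred from the param change, not quoted from this diff:

```cpp
// Assumed post-change shape of ConcatInferMeta; parameter names and the
// MetaConfig default are this note's assumptions.
void ConcatInferMeta(const std::vector<MetaTensor>& x,
                     const Scalar& axis,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());
```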
...@@ -65,13 +65,15 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']}); ...@@ -65,13 +65,15 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
def gene_output(self, output_type_list): def gene_output(self, output_type_list):
kernel_output = "" kernel_output = ""
output_names = []
output_create = "" output_create = ""
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'dense_out' kernel_output = 'dense_out'
output_names.append('dense_out')
output_create = f""" output_create = f"""
{self.return_type} out; {self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);""" auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
...@@ -79,8 +81,9 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']}); ...@@ -79,8 +81,9 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'dense_out_{i}, ' kernel_output = kernel_output + f'dense_out_{i}, '
output_names.append(f'dense_out_{i}')
output_create = output_create + f""" output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));""" auto dense_out_{i} = SetKernelOutput(kernel_backend, &std::get<{i}>(out));"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -88,22 +91,23 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']}); ...@@ -88,22 +91,23 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
"{} : Output error: the output should not be empty.".format( "{} : Output error: the output should not be empty.".format(
self.api)) self.api))
return kernel_output, output_create return kernel_output, output_names, output_create
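With `out_meta` gone, `SetKernelOutput` only needs the backend and the output slot, and `gene_output` now also returns `output_names` so the infer-meta generator can wire the created tensors up later. Roughly what the multi-output template expands to for a two-output API (the `std::tuple` declaration is elided in the hunk above and is shown here as an assumption):

```cpp
std::tuple<Tensor, Tensor> out;
auto dense_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(out));
auto dense_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(out));
```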
def gene_api_code(self): def gene_api_code(self):
if self.is_base_api: if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args( input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list, self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param']) self.kernel['param'])
outputs_args, output_create = self.gene_output(self.out_type_list) outputs_args, output_names, output_create = self.gene_output(
self.out_type_list)
return f""" return f"""
PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{ PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)} {gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors} {input_tensors}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create} {output_create}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], output_names, self.infer_meta)}
using kernel_signature = {kernel_signature}; using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args}); (*kernel_fn)({kernel_args}, {outputs_args});
......
...@@ -105,13 +105,15 @@ class BackwardAPI: ...@@ -105,13 +105,15 @@ class BackwardAPI:
def gene_output(self, output_type_list): def gene_output(self, output_type_list):
kernel_output = "" kernel_output = ""
output_names = []
output_create = "" output_create = ""
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'dense_out' kernel_output = 'dense_out'
output_names.append('dense_out')
output_create = f""" output_create = f"""
{self.return_type} out; {self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);""" auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
...@@ -119,6 +121,7 @@ class BackwardAPI: ...@@ -119,6 +121,7 @@ class BackwardAPI:
for i, out_type_item in enumerate(output_type_list): for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, ' kernel_output = kernel_output + f'dense_out_{i}, '
output_names.append(f'dense_out_{i}')
if out_type_item == 'Tensor': if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]' get_out_code = f'&out[{i}][0]'
output_create = output_create + f""" output_create = output_create + f"""
...@@ -127,7 +130,7 @@ class BackwardAPI: ...@@ -127,7 +130,7 @@ class BackwardAPI:
else: else:
get_out_code = f'&out[{i}]' get_out_code = f'&out[{i}]'
output_create = output_create + f""" output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});""" auto dense_out_{i} = SetKernelOutput(kernel_backend, {get_out_code});"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -135,14 +138,14 @@ class BackwardAPI: ...@@ -135,14 +138,14 @@ class BackwardAPI:
"{} : Output error: the output should not be empty.".format( "{} : Output error: the output should not be empty.".format(
self.backward_api)) self.backward_api))
return kernel_output, output_create return kernel_output, output_names, output_create
def gene_api_code(self): def gene_api_code(self):
if self.is_base_api: if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args( input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list, self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param']) self.kernel['param'])
outputs_args, output_create = self.gene_output( outputs_args, output_names, output_create = self.gene_output(
self.output_type_list) self.output_type_list)
return f""" return f"""
// {self.return_comment} // {self.return_comment}
...@@ -151,8 +154,8 @@ class BackwardAPI: ...@@ -151,8 +154,8 @@ class BackwardAPI:
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors} {input_tensors}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create} {output_create}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], output_names, self.infer_meta)}
using kernel_signature = {kernel_signature}; using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import re import re
PREFIX_TENSOR_NAME = 'dense_' PREFIX_TENSOR_NAME = 'dense_'
PREFIX_META_TENSOR_NAME = 'meta_'
def parse_args(api_name, args_str): def parse_args(api_name, args_str):
...@@ -265,13 +266,21 @@ def gene_kernel_select(api, input_names, attrs, kernel) -> str: ...@@ -265,13 +266,21 @@ def gene_kernel_select(api, input_names, attrs, kernel) -> str:
return kernel_select_code return kernel_select_code
def gene_infer_meta(input_names, attr_names, infer_meta) -> str: def gene_infer_meta(input_names, attr_names, output_names, infer_meta) -> str:
infer_meta_params = infer_meta['param'] if infer_meta[ infer_meta_params = infer_meta['param'] + output_names if infer_meta[
'param'] is not None else input_names + attr_names 'param'] is not None else input_names + attr_names + output_names
# generate meta tensors
meta_tensor_code = ""
param_code = "" param_code = ""
for param in infer_meta_params: for param in infer_meta_params:
if param in input_names: if param in input_names:
param_code = param_code + "GetDenseTensorMeta(*" + PREFIX_TENSOR_NAME + param + "), " param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
elif param in output_names:
meta_tensor_code = meta_tensor_code + " pten::MetaTensor " + param.replace(
PREFIX_TENSOR_NAME,
PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
param_code = param_code + "&" + param.replace(
PREFIX_TENSOR_NAME, PREFIX_META_TENSOR_NAME) + ", "
elif param in attr_names: elif param in attr_names:
param_code = param_code + param + ", " param_code = param_code + param + ", "
elif isinstance(param, str): elif isinstance(param, str):
...@@ -282,8 +291,8 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str: ...@@ -282,8 +291,8 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str:
param_code = param_code + str(param) + ", " param_code = param_code + str(param) + ", "
param_code = param_code[:-2] param_code = param_code[:-2]
return f""" return f"""{meta_tensor_code}
auto out_meta = pten::{infer_meta['func']}({param_code}); pten::{infer_meta['func']}({param_code});
""" """
......
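Putting the generator pieces together: output creation now precedes the infer-meta call in the template, and `gene_infer_meta` declares a `MetaTensor` per output (renaming `dense_` to `meta_`) before invoking the InferMeta function. For a hypothetical one-input, one-output op such as sign with `UnchangedInferMeta`, the emitted code is roughly (assembled from `output_create` plus the new `gene_infer_meta`; whitespace illustrative):

```cpp
Tensor out;
auto dense_out = SetKernelOutput(kernel_backend, &out);
pten::MetaTensor meta_out(dense_out);
pten::UnchangedInferMeta(MakeMetaTensor(*dense_x), &meta_out);
```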