未验证 提交 ab24b9c0 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick] Optimize dygraph performance part2 (#42224)

* Add paddle::variant and replace paddle::any (#42139)

* add variant and replace any

* split attribute

* Optimize dygraph GetExpectedKernelType perf (#42154)

* opt dygraph scheduling

* revert part impl

* fix variant compile error (#42203)

* replace any by variant in infermeta (#42181)
上级 e4da34fd
...@@ -39,6 +39,7 @@ limitations under the License. */ ...@@ -39,6 +39,7 @@ limitations under the License. */
#include "paddle/phi/api/all.h" #include "paddle/phi/api/all.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
namespace paddle { namespace paddle {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h" #include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/core/kernel_context.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -35,6 +35,7 @@ limitations under the License. */ ...@@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h" #include "paddle/phi/ops/compat/signatures.h"
...@@ -941,7 +942,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -941,7 +942,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
return ((op_with_kernel.kernel_type()) && return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ == (op_with_kernel.kernel_type()->data_layout_ ==
framework::DataLayout::kMKLDNN)); framework::DataLayout::kMKLDNN));
} catch (std::bad_cast exp) { } catch (const std::bad_cast& exp) {
return false; return false;
} }
} }
...@@ -1966,6 +1967,36 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1966,6 +1967,36 @@ Scope* OperatorWithKernel::PrepareData(
} }
void OperatorWithKernel::ParseInputDataType( void OperatorWithKernel::ParseInputDataType(
const Variable* var, const std::string& name,
proto::VarType::Type* data_type) const {
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
if (t_arr->at(j).IsInitialized()) {
t = &(t_arr->at(j));
}
}
}
if (t != nullptr) {
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(), name));
*data_type = paddle::framework::TransToProtoVarType(t->dtype());
}
}
}
void OperatorWithKernel::ParseMultiInputDataType(
const std::vector<Variable*>& vars, const std::string& name, const std::vector<Variable*>& vars, const std::string& name,
proto::VarType::Type* data_type) const { proto::VarType::Type* data_type) const {
proto::VarType::Type default_data_type = proto::VarType::Type default_data_type =
...@@ -2016,9 +2047,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -2016,9 +2047,12 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type dafault_data_type = proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
for (auto& input : ctx.InNameList()) { for (auto* name : ctx.InNameList()) {
const std::vector<Variable*> vars = ctx.MultiInputVar(input); if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(vars, input, &data_type); ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
} else {
ParseMultiInputDataType(ctx.MultiInputVar(*name), *name, &data_type);
}
} }
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
data_type, dafault_data_type, data_type, dafault_data_type,
...@@ -2032,7 +2066,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( ...@@ -2032,7 +2066,11 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto::VarType::Type dafault_data_type = proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1); static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type; proto::VarType::Type data_type = dafault_data_type;
ParseInputDataType(ctx.MultiInputVar(name), name, &data_type); if (ctx.InputSize(name) == 1UL) {
ParseInputDataType(ctx.InputVar(name), name, &data_type);
} else {
ParseMultiInputDataType(ctx.MultiInputVar(name), name, &data_type);
}
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
data_type, dafault_data_type, data_type, dafault_data_type,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -43,7 +43,6 @@ limitations under the License. */ ...@@ -43,7 +43,6 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
namespace paddle { namespace paddle {
...@@ -55,6 +54,10 @@ class Variable; ...@@ -55,6 +54,10 @@ class Variable;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
namespace phi {
class KernelContext;
}
DECLARE_int32(inner_op_parallelism); DECLARE_int32(inner_op_parallelism);
namespace paddle { namespace paddle {
...@@ -330,12 +333,12 @@ class ExecutionContext { ...@@ -330,12 +333,12 @@ class ExecutionContext {
return it->second; return it->second;
} }
virtual std::vector<std::string> InNameList() const { virtual paddle::SmallVector<const std::string*> InNameList() const {
std::vector<std::string> vec_temp; paddle::SmallVector<const std::string*> vec_temp;
vec_temp.reserve(ctx_.inputs.size()); vec_temp.reserve(ctx_.inputs.size());
for (auto& input : ctx_.inputs) { for (auto& input : ctx_.inputs) {
vec_temp.push_back(input.first); vec_temp.push_back(&input.first);
} }
return vec_temp; return vec_temp;
...@@ -677,9 +680,11 @@ class OperatorWithKernel : public OperatorBase { ...@@ -677,9 +680,11 @@ class OperatorWithKernel : public OperatorBase {
// By default all input data must be same. // By default all input data must be same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
// used for IndicateDataType // used for IndicateDataType
void ParseInputDataType(const std::vector<Variable*>& vars, void ParseInputDataType(const Variable* vars, const std::string& name,
const std::string& name,
proto::VarType::Type* data_type) const; proto::VarType::Type* data_type) const;
void ParseMultiInputDataType(const std::vector<Variable*>& vars,
const std::string& name,
proto::VarType::Type* data_type) const;
// used for IndicateOrPromoteVarDataTypes // used for IndicateOrPromoteVarDataTypes
Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const; const std::string& name) const;
......
...@@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -117,12 +117,12 @@ class DygraphExecutionContext : public framework::ExecutionContext {
return it->second; return it->second;
} }
std::vector<std::string> InNameList() const override { paddle::SmallVector<const std::string*> InNameList() const override {
std::vector<std::string> vec_temp; paddle::SmallVector<const std::string*> vec_temp;
vec_temp.reserve(var_map_in_.size()); vec_temp.reserve(var_map_in_.size());
for (auto& v : var_map_in_) { for (auto& v : var_map_in_) {
vec_temp.push_back(v.first); vec_temp.push_back(&v.first);
} }
return vec_temp; return vec_temp;
...@@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -144,11 +144,19 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
return InputNames(name).size(); auto it = var_map_in_.find(name);
PADDLE_ENFORCE_NE(
it, var_map_in_.end(),
platform::errors::NotFound("Can not find [%s] in Input", name));
return it->second.size();
} }
size_t OutputSize(const std::string& name) const override { size_t OutputSize(const std::string& name) const override {
return OutputNames(name).size(); auto it = var_map_out_.find(name);
PADDLE_ENFORCE_NE(
it, var_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name));
return it->second.size();
} }
const Variable* InputVar(const std::string& name) const override { const Variable* InputVar(const std::string& name) const override {
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
......
...@@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -90,7 +90,7 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"
namespace phi {
// Forward declaration only; keeps this header free of a place.h include.
class Place;
// Attribute is the closed set of types an operator attribute may hold,
// stored as a variant instead of paddle::any so access is type-checked
// and allocation-free.
//
// NOTE: the order of alternatives determines variant indices; append new
// types at the end rather than reordering existing ones.
// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
                                  int,
                                  int64_t,
                                  float,
                                  double,
                                  std::string,
                                  std::vector<bool>,
                                  std::vector<int>,
                                  std::vector<int64_t>,
                                  std::vector<float>,
                                  std::vector<double>,
                                  std::vector<std::string>,
                                  Scalar,
                                  std::vector<Scalar>,
                                  IntArray,
                                  DataType,
                                  DataLayout,
                                  Place>;
}  // namespace phi
...@@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) { ...@@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
outputs_.emplace_back(std::move(output)); outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1)); output_range_.emplace_back(std::pair<int, int>(index, index + 1));
} }
void InferMetaContext::EmplaceBackAttr(paddle::any attr) { void InferMetaContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr)); attrs_.emplace_back(std::move(attr));
} }
...@@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start, ...@@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
return result; return result;
} }
template <typename AttrType>
const AttrType& InferMetaContext::AttrAt(size_t idx) const {
try {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`.",
std::type_index(typeid(AttrType)).name()));
}
}
template const bool& InferMetaContext::AttrAt(size_t idx) const;
template const int& InferMetaContext::AttrAt(size_t idx) const;
template const int64_t& InferMetaContext::AttrAt(size_t idx) const;
template const float& InferMetaContext::AttrAt(size_t idx) const;
template const double& InferMetaContext::AttrAt(size_t idx) const;
template const std::string& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<bool>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<float>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<double>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& InferMetaContext::AttrAt(
size_t idx) const;
template const Scalar& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& InferMetaContext::AttrAt(size_t idx) const;
template const IntArray& InferMetaContext::AttrAt(size_t idx) const;
template const DataType& InferMetaContext::AttrAt(size_t idx) const;
template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
template const Place& InferMetaContext::AttrAt(size_t idx) const;
MetaFnFactory& MetaFnFactory::Instance() { MetaFnFactory& MetaFnFactory::Instance() {
static MetaFnFactory g_meta_fn_map; static MetaFnFactory g_meta_fn_map;
return g_meta_fn_map; return g_meta_fn_map;
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
...@@ -41,7 +42,7 @@ class InferMetaContext { ...@@ -41,7 +42,7 @@ class InferMetaContext {
void EmplaceBackInput(MetaTensor input); void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output); void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr); void EmplaceBackAttr(Attribute attr);
void EmplaceBackInputs( void EmplaceBackInputs(
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs); paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
...@@ -61,17 +62,7 @@ class InferMetaContext { ...@@ -61,17 +62,7 @@ class InferMetaContext {
size_t end); size_t end);
template <typename AttrType> template <typename AttrType>
AttrType AttrAt(size_t idx) { const AttrType& AttrAt(size_t idx) const;
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`.",
std::type_index(typeid(AttrType)).name(),
std::type_index(attrs_.at(idx).type()).name()));
}
}
const std::pair<int, int>& InputRangeAt(size_t idx) const; const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const; const std::pair<int, int>& OutputRangeAt(size_t idx) const;
...@@ -81,7 +72,7 @@ class InferMetaContext { ...@@ -81,7 +72,7 @@ class InferMetaContext {
protected: protected:
MetaConfig config_; MetaConfig config_;
paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_; paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize> paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_; input_range_;
...@@ -111,6 +102,21 @@ class InferMetaContext { ...@@ -111,6 +102,21 @@ class InferMetaContext {
} \ } \
} }
// Generates an InferMetaFnCallHelper specialization that consumes one
// `const attr_type&` argument from the InferMeta function signature: it
// fetches the attribute at attr_idx from the context by const reference
// (avoiding the copy the by-value ATTRIBUTE helper would make) and
// recurses with attr_idx advanced by one. The static_assert enforces the
// "Attributes before Outputs" argument ordering convention.
#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
  template <typename... Tail>                                                  \
  struct InferMetaFnCallHelper<const attr_type&, Tail...> {                    \
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
    static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {          \
      static_assert(out_idx == 0,                                              \
                    "InferMeta's Attributes should appear before Outputs.");   \
      const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx);                 \
      InferMetaFnCallHelper<                                                   \
          Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx,          \
                                                                 pargs...,    \
                                                                 arg);         \
    }                                                                          \
  }
template <typename T> template <typename T>
struct InferMetaTypeTag {}; struct InferMetaTypeTag {};
...@@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
} }
}; };
// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<std::string>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
// TODO(chenweihang): support vector<MetaTensor> input later PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<float>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<double>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);
template <typename... Tail> template <typename... Tail>
struct InferMetaFnCallHelper<MetaTensor*, Tail...> { struct InferMetaFnCallHelper<MetaTensor*, Tail...> {
......
...@@ -73,7 +73,7 @@ void KernelContext::EmplaceBackOutputsWithoutSetRange( ...@@ -73,7 +73,7 @@ void KernelContext::EmplaceBackOutputsWithoutSetRange(
std::make_move_iterator(outputs.end())); std::make_move_iterator(outputs.end()));
} }
void KernelContext::EmplaceBackAttr(paddle::any attr) { void KernelContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr)); attrs_.emplace_back(std::move(attr));
} }
...@@ -113,4 +113,34 @@ const std::pair<int, int>& KernelContext::OutputRangeAt(size_t idx) const { ...@@ -113,4 +113,34 @@ const std::pair<int, int>& KernelContext::OutputRangeAt(size_t idx) const {
return output_range_.at(idx); return output_range_.at(idx);
} }
// Typed accessor for the attribute stored at position `idx`.
//
// attrs_ holds phi::Attribute (a variant); paddle::get throws
// paddle::bad_variant_access when the stored alternative does not match
// AttrType, which is translated here into an InvalidArgument error.
template <typename AttrType>
const AttrType& KernelContext::AttrAt(size_t idx) const {
  try {
    return paddle::get<AttrType>(attrs_.at(idx));
  } catch (paddle::bad_variant_access const& ex) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Attribute cast error in Op Kernel Context."));
  }
}

// Explicit instantiations for every alternative of phi::Attribute, so the
// template definition can stay in this translation unit.
template const bool& KernelContext::AttrAt(size_t idx) const;
template const int& KernelContext::AttrAt(size_t idx) const;
template const int64_t& KernelContext::AttrAt(size_t idx) const;
template const float& KernelContext::AttrAt(size_t idx) const;
template const double& KernelContext::AttrAt(size_t idx) const;
template const std::string& KernelContext::AttrAt(size_t idx) const;
template const std::vector<bool>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<int>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<float>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<double>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& KernelContext::AttrAt(
    size_t idx) const;
template const Scalar& KernelContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& KernelContext::AttrAt(size_t idx) const;
template const IntArray& KernelContext::AttrAt(size_t idx) const;
template const DataType& KernelContext::AttrAt(size_t idx) const;
template const DataLayout& KernelContext::AttrAt(size_t idx) const;
template const Place& KernelContext::AttrAt(size_t idx) const;
} // namespace phi } // namespace phi
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
#include <iterator> #include <iterator>
#include <utility> #include <utility>
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h" #include "paddle/phi/core/type_defs.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
...@@ -64,7 +65,7 @@ class KernelContext { ...@@ -64,7 +65,7 @@ class KernelContext {
void EmplaceBackOutputsWithoutSetRange( void EmplaceBackOutputsWithoutSetRange(
paddle::SmallVector<TensorBase*> outputs); paddle::SmallVector<TensorBase*> outputs);
void EmplaceBackAttr(paddle::any attr); void EmplaceBackAttr(Attribute attr);
const std::pair<int, int>& InputRangeAt(size_t idx) const; const std::pair<int, int>& InputRangeAt(size_t idx) const;
...@@ -128,14 +129,7 @@ class KernelContext { ...@@ -128,14 +129,7 @@ class KernelContext {
} }
template <typename AttrType> template <typename AttrType>
AttrType AttrAt(size_t idx) const { const AttrType& AttrAt(size_t idx) const;
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast&) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in Op Kernel Context."));
}
}
size_t InputsSize() const { return inputs_.size(); } size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); } size_t OutputsSize() const { return outputs_.size(); }
...@@ -146,10 +140,11 @@ class KernelContext { ...@@ -146,10 +140,11 @@ class KernelContext {
paddle::SmallVector<const TensorBase*> inputs_; paddle::SmallVector<const TensorBase*> inputs_;
paddle::SmallVector<TensorBase*> outputs_; paddle::SmallVector<TensorBase*> outputs_;
paddle::SmallVector<paddle::any> attrs_; paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<std::pair<int, int>> input_range_; paddle::SmallVector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
paddle::SmallVector<std::pair<int, int>> output_range_; paddle::SmallVector<std::pair<int, int>, kOutputSmallVectorSize>
output_range_;
}; };
} // namespace phi } // namespace phi
...@@ -105,6 +105,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -105,6 +105,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(const StringTensor&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) { } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
default_tensor_layout, default_tensor_layout,
...@@ -153,6 +158,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -153,6 +158,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(StringTensor*))) {
args_def->AppendOutput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else { } else {
// Attribute deal with // Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe // TODO(chenweihang): now here allow any types of attribute, maybe
......
...@@ -168,6 +168,24 @@ namespace phi { ...@@ -168,6 +168,24 @@ namespace phi {
} \ } \
} }
// Generates a KernelCallHelper specialization that consumes one
// `const attr_type&` argument from the kernel signature: it fetches the
// attribute at attr_idx from the KernelContext by const reference
// (avoiding the copy the by-value ATTRIBUTE helper would make) and
// recurses with attr_idx advanced by one. The static_assert enforces the
// "Attributes before Outputs" argument ordering convention.
#define PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
  template <typename... Tail>                                             \
  struct KernelCallHelper<const attr_type&, Tail...> {                    \
    template <int dev_ctx_idx,                                            \
              int in_idx,                                                 \
              int attr_idx,                                               \
              int out_idx,                                                \
              typename... PreviousArgs>                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {     \
      static_assert(out_idx == 0,                                         \
                    "Kernel's Attributes should appear before Outputs."); \
      const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx);            \
      KernelCallHelper<Tail...>::                                         \
          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(   \
              ctx, pargs..., arg);                                        \
    }                                                                     \
  }
#define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \ #define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \
template <typename... Tail> \ template <typename... Tail> \
struct KernelCallHelper<tensor_type*, Tail...> { \ struct KernelCallHelper<tensor_type*, Tail...> { \
...@@ -270,19 +288,20 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -270,19 +288,20 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const IntArray&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<bool>);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int64_t>);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<float>);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<double>);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<Scalar>&); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<Scalar>);
/* Output Helpers */ /* Output Helpers */
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <string>
#include <vector>
namespace phi { namespace phi {
......
...@@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { ...@@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
// InferMeta for copy_to: output metadata mirrors the input unchanged.
// `backend` and `blocking` are accepted for signature compatibility but
// are not used by shape/dtype inference.
void CopyToInferMeta(const MetaTensor& x,
                     Backend backend,
                     bool blocking,
                     MetaTensor* out) {
  UnchangedInferMeta(x, out);
}
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_dims(x.dims()); out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
...@@ -3002,5 +2995,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { ...@@ -3002,5 +2995,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
} // namespace phi } // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
...@@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); ...@@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out);
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
void CumsumInferMeta(const MetaTensor& x, void CumsumInferMeta(const MetaTensor& x,
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/where_grad_kernel.h" #include "paddle/phi/kernels/where_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/where_kernel.h" #include "paddle/phi/kernels/where_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
...@@ -33,7 +34,6 @@ ...@@ -33,7 +34,6 @@
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h"
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#include "paddle/phi/kernels/where_grad_kernel.h" #include "paddle/phi/kernels/where_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
template <typename T> template <typename T>
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/where_kernel.h" #include "paddle/phi/kernels/where_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
......
...@@ -14,10 +14,7 @@ ...@@ -14,10 +14,7 @@
#pragma once #pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
......
...@@ -14,10 +14,7 @@ ...@@ -14,10 +14,7 @@
#pragma once #pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi { namespace phi {
......
...@@ -49,7 +49,6 @@ void FakeDot(const Context& dev_ctx, ...@@ -49,7 +49,6 @@ void FakeDot(const Context& dev_ctx,
float fake_attr_float, float fake_attr_float,
double fake_attr_double, double fake_attr_double,
int64_t fake_attr_int64, int64_t fake_attr_int64,
phi::dtype::float16 fake_attr_f16,
phi::DataType fake_attr_dtype, phi::DataType fake_attr_dtype,
const phi::Scalar& fake_attr_scalar, const phi::Scalar& fake_attr_scalar,
const phi::IntArray& fake_attr_int_array, const phi::IntArray& fake_attr_int_array,
...@@ -64,7 +63,6 @@ void FakeDot(const Context& dev_ctx, ...@@ -64,7 +63,6 @@ void FakeDot(const Context& dev_ctx,
std::cout << "fake_attr_float: " << fake_attr_float << std::endl; std::cout << "fake_attr_float: " << fake_attr_float << std::endl;
std::cout << "fake_attr_double: " << fake_attr_double << std::endl; std::cout << "fake_attr_double: " << fake_attr_double << std::endl;
std::cout << "fake_attr_int64: " << fake_attr_int64 << std::endl; std::cout << "fake_attr_int64: " << fake_attr_int64 << std::endl;
std::cout << "fake_attr_f16: " << fake_attr_f16 << std::endl;
std::cout << "fake_attr_dtype: " << fake_attr_dtype << std::endl; std::cout << "fake_attr_dtype: " << fake_attr_dtype << std::endl;
std::cout << "fake_attr_int64_vec: " << fake_attr_int64_vec.size() std::cout << "fake_attr_int64_vec: " << fake_attr_int64_vec.size()
<< std::endl; << std::endl;
...@@ -78,7 +76,6 @@ void FakeDot(const Context& dev_ctx, ...@@ -78,7 +76,6 @@ void FakeDot(const Context& dev_ctx,
assert(fake_attr_float == 2); assert(fake_attr_float == 2);
assert(fake_attr_double == 3); assert(fake_attr_double == 3);
assert(fake_attr_int64 == 4); assert(fake_attr_int64 == 4);
assert(fake_attr_f16 == phi::dtype::float16(5));
assert(fake_attr_dtype == phi::DataType::UINT32); assert(fake_attr_dtype == phi::DataType::UINT32);
assert(fake_attr_int64_vec.size() == 0); assert(fake_attr_int64_vec.size() == 0);
assert(fake_attr_int_vec.size() == 0); assert(fake_attr_int_vec.size() == 0);
...@@ -248,7 +245,6 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -248,7 +245,6 @@ TEST(CustomKernel, custom_kernel_dot) {
float fake_attr_float = 2.0; float fake_attr_float = 2.0;
double fake_attr_double = 3.0; double fake_attr_double = 3.0;
int64_t fake_attr_int64 = 4; int64_t fake_attr_int64 = 4;
phi::dtype::float16 fake_attr_f16 = phi::dtype::float16(5);
phi::DataType fake_attr_dtype = phi::DataType::UINT32; phi::DataType fake_attr_dtype = phi::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor; paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, phi::TransToPhiPlace(backend)); tmp_tensor.mutable_data<uint8_t>({1}, phi::TransToPhiPlace(backend));
...@@ -262,7 +258,6 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -262,7 +258,6 @@ TEST(CustomKernel, custom_kernel_dot) {
kernel_context.EmplaceBackAttr(fake_attr_float); kernel_context.EmplaceBackAttr(fake_attr_float);
kernel_context.EmplaceBackAttr(fake_attr_double); kernel_context.EmplaceBackAttr(fake_attr_double);
kernel_context.EmplaceBackAttr(fake_attr_int64); kernel_context.EmplaceBackAttr(fake_attr_int64);
kernel_context.EmplaceBackAttr(fake_attr_f16);
kernel_context.EmplaceBackAttr(fake_attr_dtype); kernel_context.EmplaceBackAttr(fake_attr_dtype);
kernel_context.EmplaceBackAttr(fake_attr_scalar); kernel_context.EmplaceBackAttr(fake_attr_scalar);
kernel_context.EmplaceBackAttr(fake_attr_int_array); kernel_context.EmplaceBackAttr(fake_attr_int_array);
......
...@@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) { ...@@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) {
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
} }
TEST(MetaFnFactory, CopyInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({3, 4});
phi::MetaTensor meta_x(&dense_x);
phi::DenseTensor dense_out1;
phi::MetaTensor meta_out(&dense_out1);
phi::UnchangedInferMeta(meta_x, &meta_out);
auto shared_meat_x = phi::MetaTensor(&dense_x);
phi::DenseTensor dense_out2;
auto shared_meta_out = phi::MetaTensor(&dense_out2);
phi::InferMetaContext ctx;
ctx.EmplaceBackInput(shared_meat_x);
ctx.EmplaceBackAttr(Backend::CPU);
ctx.EmplaceBackAttr(false);
ctx.EmplaceBackOutput(shared_meta_out);
ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false});
phi::MetaFnFactory::Instance().Get("copy_to")(&ctx);
EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}
TEST(MetaFnFactory, SplitInferMetaFn) { TEST(MetaFnFactory, SplitInferMetaFn) {
phi::DenseTensor dense_x; phi::DenseTensor dense_x;
dense_x.Resize({4, 10}); dense_x.Resize({4, 10});
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册