未验证 提交 45dd4a5f 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Remove infershape of Reshape OP (#39631)

* remove infershape and Xshape

* add xshape

* fix bugs when run ci

* fix bugs when run ci

* fix bugs when run infrt test

* pass converage
上级 f858b645
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
...@@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
return ctx_.Inputs(name).size(); if (ctx_.HasInputs(name)) {
return ctx_.Inputs(name).size();
} else if (ctx_.HasInput(name)) {
return 1;
}
return 0;
} }
size_t OutputSize(const std::string& name) const override { size_t OutputSize(const std::string& name) const override {
...@@ -288,6 +294,16 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -288,6 +294,16 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto& attr_names = std::get<1>(signature.args); auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args); auto& output_names = std::get<2>(signature.args);
auto kernels_map =
phi::KernelFactory::Instance().SelectKernelMap(signature.name);
if (kernels_map.size() == 0) {
PADDLE_THROW(
platform::errors::Unimplemented("Not find `%s` kernels when construct "
"InferMetaContext.",
signature.name));
}
auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
// TODO(chenweihang): support multiple inputs and outputs later // TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_mete_context; phi::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) { for (auto& in_name : input_names) {
...@@ -299,9 +315,70 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -299,9 +315,70 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
} }
for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
auto attr_reader = ctx->Attrs(); auto attr_reader = ctx->Attrs();
for (auto& attr_name : attr_names) { for (size_t i = 0; i < attr_names.size(); ++i) {
if (ctx->HasAttr(attr_name)) { auto attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
// When attr is a vector_tensor or tensor, transform it to ScalarArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
if (ctx->IsRuntime()) {
// If is in runtime, we will get tensor's value for ScalarArray
// and push it into attrs
std::vector<Variable*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
}
if (infershape_inputs.size() != 1) {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
} else {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
}
} else {
// If is not in runtime, we will set default value(-1) for ScalarArray
int64_t num_ele = 1;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
}
} else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
attr_name));
}
}
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
...@@ -345,17 +422,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -345,17 +422,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call " "Unsupported attribute type is received when call "
"InferShapeFunctor.")); "InferShapeFunctor."));
} }
} else {
// do nothing
}
}
for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
} }
} }
......
...@@ -23,8 +23,11 @@ limitations under the License. */ ...@@ -23,8 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -93,6 +96,17 @@ phi::KernelSignature InferShapeUtilsTestOpArgumentMapping( ...@@ -93,6 +96,17 @@ phi::KernelSignature InferShapeUtilsTestOpArgumentMapping(
{}); {});
} }
template <typename T, typename Context>
void InferShapeUtilsTestKernel(
const Context& dev_ctx, const phi::DenseTensor& x, bool attr1, int attr2,
int64_t attr3, float attr4, const std::string& attr5,
const std::vector<bool>& attr6, const std::vector<int>& attr7,
const std::vector<int64_t>& attr8, const std::vector<float>& attr9,
const std::vector<double>& attr10, const std::vector<std::string>& attr11,
phi::DenseTensor* out) {
VLOG(6) << "Come into InferShapeUtilsTestKernel";
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test, ...@@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOpMaker, paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor); InferShapeUtilsTestInferShapeFunctor);
PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
paddle::framework::InferShapeUtilsTestKernel, int) {}
TEST(InferShapeUtilsTest, ALL) { TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog; paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block; paddle::framework::proto::BlockDesc proto_block;
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
...@@ -21,8 +22,11 @@ limitations under the License. */ ...@@ -21,8 +22,11 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/reshape_grad_kernel.h" #include "paddle/phi/kernels/reshape_grad_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InferShapeContext; class InferShapeContext;
...@@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp { ...@@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: ReshapeOp(type, inputs, outputs, attrs) {} : ReshapeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
platform::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto &x_dims = ctx->GetInputDim("X");
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
ReshapeOp::InferShape(ctx);
}
}; };
class Reshape2OpMaker : public ReshapeOpMaker { class Reshape2OpMaker : public ReshapeOpMaker {
...@@ -647,10 +635,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ...@@ -647,10 +635,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel); ops::ReshapeGradKernel);
DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
PT_INFER_META(phi::ReshapeWithXShapeInferMeta));
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>, ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>, ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::ReshapeOpInplaceInferer); ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>, ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>, ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
......
...@@ -131,7 +131,7 @@ phi::ScalarArray MakePtenScalarArrayFromVarList( ...@@ -131,7 +131,7 @@ phi::ScalarArray MakePtenScalarArrayFromVarList(
} }
phi::ScalarArray result{vector_data}; phi::ScalarArray result{vector_data};
result.setInitByTensor(true); result.SetFromTensor(true);
return result; return result;
} }
......
...@@ -25,7 +25,7 @@ namespace experimental { ...@@ -25,7 +25,7 @@ namespace experimental {
template <typename T> template <typename T>
class ScalarBase { class ScalarBase {
public: public:
bool IsInitByTensor() const { return is_init_by_tensor_; } bool FromTensor() const { return is_from_tensor_; }
// Constructor support implicit // Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val; data_.f64 = val;
...@@ -104,7 +104,7 @@ class ScalarBase { ...@@ -104,7 +104,7 @@ class ScalarBase {
// The Tensor must have one dim // The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_init_by_tensor_ = true; is_from_tensor_ = true;
PD_CHECK( PD_CHECK(
tensor.numel() == 1, tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `", "The Scalar only supports Tensor with 1 element, but now Tensor has `",
...@@ -196,7 +196,7 @@ class ScalarBase { ...@@ -196,7 +196,7 @@ class ScalarBase {
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst); friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
private: private:
bool is_init_by_tensor_{false}; bool is_from_tensor_{false};
DataType dtype_; DataType dtype_;
union data { union data {
bool b; bool b;
......
...@@ -43,13 +43,13 @@ class ScalarArrayBase { ...@@ -43,13 +43,13 @@ class ScalarArrayBase {
AssignData(date_value, n); AssignData(date_value, n);
} }
bool IsInitByTensor() const { return is_init_by_tensor_; } bool FromTensor() const { return is_from_tensor_; }
void setInitByTensor(bool val) { is_init_by_tensor_ = val; } void SetFromTensor(bool val) { is_from_tensor_ = val; }
// The Tensor must have one dim // The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT ScalarArrayBase(const T& tensor) { // NOLINT
is_init_by_tensor_ = true; is_from_tensor_ = true;
size_t n = tensor.numel(); size_t n = tensor.numel();
array_.reserve(n); array_.reserve(n);
switch (tensor.dtype()) { switch (tensor.dtype()) {
...@@ -71,7 +71,7 @@ class ScalarArrayBase { ...@@ -71,7 +71,7 @@ class ScalarArrayBase {
// The Tensor in vec must have only one element // The Tensor in vec must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
is_init_by_tensor_ = true; is_from_tensor_ = true;
for (size_t i = 0; i < tensor_list.size(); ++i) { for (size_t i = 0; i < tensor_list.size(); ++i) {
DataType data_type = tensor_list[i].dtype(); DataType data_type = tensor_list[i].dtype();
...@@ -117,7 +117,7 @@ class ScalarArrayBase { ...@@ -117,7 +117,7 @@ class ScalarArrayBase {
// TODO(zhangyunfei) Replace std::vector with a more efficient container // TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure. // structure.
std::vector<int64_t> array_; std::vector<int64_t> array_;
bool is_init_by_tensor_{false}; bool is_from_tensor_{false};
}; };
using ScalarArray = using ScalarArray =
......
...@@ -241,6 +241,10 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -241,6 +241,10 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);
/* Output Helpers */ /* Output Helpers */
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#include <set> #include <set>
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
namespace phi { namespace phi {
...@@ -217,7 +217,7 @@ void InferMetaFromVecValue(const MetaTensor& x, ...@@ -217,7 +217,7 @@ void InferMetaFromVecValue(const MetaTensor& x,
MetaTensor* out) { MetaTensor* out) {
PADDLE_ENFORCE_EQ(!shape.empty(), PADDLE_ENFORCE_EQ(!shape.empty(),
true, true,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. " "The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty.")); "But received 'shape' is empty."));
auto x_dims = x.dims(); auto x_dims = x.dims();
...@@ -234,8 +234,42 @@ void InferMetaFromVecValue(const MetaTensor& x, ...@@ -234,8 +234,42 @@ void InferMetaFromVecValue(const MetaTensor& x,
void ReshapeInferMeta(const MetaTensor& x, void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* out) { MetaTensor* out,
InferMetaFromVecValue(x, shape.GetData(), out); MetaConfig config) {
auto& shape_data = shape.GetData();
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output(Out) of ReshapeOp should not be null."));
if (!config.is_runtime && shape.FromTensor()) {
out->set_dims(phi::make_ddim(shape_data));
out->share_lod(x);
return;
}
PADDLE_ENFORCE_GT(shape_data.size(),
0,
phi::errors::InvalidArgument(
"The shape's size in ReshapeOp can't be zero."));
InferMetaFromVecValue(x, shape_data, out);
}
void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(
xshape,
phi::errors::InvalidArgument(
"Output(XShape) of ReshapeOp should not be null."));
const auto& x_dims = x.dims();
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
ReshapeInferMeta(x, shape, out, config);
} }
/* Why not use ReduceInferMeta directly? /* Why not use ReduceInferMeta directly?
......
...@@ -54,7 +54,14 @@ void InferMetaFromVecValue(const MetaTensor& x, ...@@ -54,7 +54,14 @@ void InferMetaFromVecValue(const MetaTensor& x,
void ReshapeInferMeta(const MetaTensor& x, void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* out); MetaTensor* out,
MetaConfig config = MetaConfig());
void ReshapeWithXShapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* xshape,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ReduceInferMetaBase(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
......
...@@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx, ...@@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx,
const Scalar& axis_scalar, const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) { std::vector<DenseTensor*> outs) {
// need to infershape output // need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas; std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]); out_metas.push_back(outs[i]);
......
...@@ -28,7 +28,7 @@ void SplitKernel(const Context& dev_ctx, ...@@ -28,7 +28,7 @@ void SplitKernel(const Context& dev_ctx,
const Scalar& axis_scalar, const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) { std::vector<DenseTensor*> outs) {
// need to infershape output // need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas; std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]); out_metas.push_back(outs[i]);
......
...@@ -47,7 +47,6 @@ void ReshapeWithXShape(const Context& dev_ctx, ...@@ -47,7 +47,6 @@ void ReshapeWithXShape(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out) { DenseTensor* out) {
funcs::SetXShape(x, xshape);
ReshapeKernel(dev_ctx, x, shape, out); ReshapeKernel(dev_ctx, x, shape, out);
} }
......
...@@ -17,13 +17,19 @@ limitations under the License. */ ...@@ -17,13 +17,19 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensor") > 0) { if (ctx.HasOutput("XShape")) {
return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"}); if (ctx.InputSize("ShapeTensor") > 0) {
} else if (ctx.HasInput("Shape")) { return KernelSignature(
return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"}); "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"});
} else { } else if (ctx.HasInput("Shape")) {
return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); return KernelSignature(
"reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"});
} else {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
}
} }
return KernelSignature("unregistered", {}, {}, {});
} }
KernelSignature ReshapeGradOpArgumentMapping( KernelSignature ReshapeGradOpArgumentMapping(
......
...@@ -91,14 +91,10 @@ class TRTReshapeTest2(TRTReshapeTest): ...@@ -91,14 +91,10 @@ class TRTReshapeTest2(TRTReshapeTest):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data( data = fluid.data(
name='data', shape=self.data_shape, dtype='float32') name='data', shape=self.data_shape, dtype='float32')
actual_reshape = fluid.data( reshape_out = fluid.layers.reshape(x=data, shape=self.reshape)
name='actual_reshape', shape=[4], dtype='int32')
reshape_out = fluid.layers.reshape(
x=data, shape=self.reshape, actual_shape=actual_reshape)
out = fluid.layers.batch_norm(reshape_out, is_test=True) out = fluid.layers.batch_norm(reshape_out, is_test=True)
self.feeds = { self.feeds = {
'data': np.random.random(self.data_shape).astype('float32'), 'data': np.random.random(self.data_shape).astype('float32')
'actual_reshape': np.array([2, 0, -1, 6]).astype('int32')
} }
self.enable_trt = True self.enable_trt = True
self.trt_parameters = TRTReshapeTest.TensorRTParam( self.trt_parameters = TRTReshapeTest.TensorRTParam(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册