Unverified commit 8e4e19ab, authored by H hong, committed by GitHub

Add infer meta (#40544)

* add infer meta; test=develop

* add histogram infer meta; test=develop

* fix unittest bug; test=develop

* format; test=develop

* format; test=develop

* bn does not use new infer meta; test=develop

* add infer meta; test=develop

* fix bug;

* fix bug;

* recover unittest; test=develop
Parent 8e612903
@@ -113,23 +113,5 @@ class BatchNormOpInferVarType
}
};
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
@@ -27,6 +27,9 @@ limitations under the License. */
#endif
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -841,6 +844,8 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(conv2d, Conv2dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
@@ -851,6 +856,8 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad,
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise convolution op
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d, DepthwiseConv2dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
@@ -860,6 +867,8 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad,
ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);
DECLARE_INFER_SHAPE_FUNCTOR(conv3d, Conv3dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType,
ops::Conv3DGradMaker<paddle::framework::OpDesc>,
...
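The registrations above show the migration pattern this commit applies across operators: DECLARE_INFER_SHAPE_FUNCTOR generates a functor that adapts a phi InferMeta function (wrapped by PD_INFER_META) to the fluid InferShape interface, and that functor is passed as a trailing argument to REGISTER_OPERATOR so the operator class no longer needs a hand-written InferShape. A minimal sketch of the pattern; my_op, MyOp, MyOpMaker, and phi::MyOpInferMeta are hypothetical placeholders, not names from this commit:

// Sketch only: hypothetical operator wired to a phi InferMeta function.
DECLARE_INFER_SHAPE_FUNCTOR(my_op, MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
                  MyOpInferShapeFunctor);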
@@ -9,8 +9,10 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -235,10 +237,13 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(yolo_box, YoloBoxInferShapeFunctor,
PD_INFER_META(phi::YoloBoxInferMeta));
REGISTER_OPERATOR(
yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
YoloBoxInferShapeFunctor);
REGISTER_OP_VERSION(yolo_box)
.AddCheckpoint(
...
@@ -14,7 +14,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -25,17 +27,6 @@ class DropoutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dropout");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -173,7 +164,11 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(dropout, DropoutInferShapeFunctor,
PD_INFER_META(phi::DropoutInferMeta));
REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutGradOpMaker<paddle::imperative::OpBase>,
DropoutInferShapeFunctor);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
@@ -120,6 +120,142 @@ class Conv2DFusionOp : public operators::ConvOp {
ctx->SetOutputsDim("Outputs", output_shapes);
}
}
std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i], 0,
platform::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument(
"The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(Conv) should be equal. But received: the input's shape is "
"[%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i], 0,
platform::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size(), strides.size() + 2U,
platform::errors::InvalidArgument(
"The difference of input's dimension and Attr(strides)'s "
"length must be euqal to 2 for Op(Conv). "
"But received: input's dimension is %d, input's shape is [%s]; "
"Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"difference of input's dimention and Attr(strides)'s length = %u.",
in_dims.size(), in_dims, strides.size(), phi::make_ddim(strides),
in_sub_stride_size));
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(
input_channels, filter_dims[1] * groups,
platform::errors::InvalidArgument(
"The number of input's channels should be equal to filter's "
"channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels, in_dims, filter_dims[1], filter_dims, groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
platform::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0], filter_dims, groups));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GT(
filter_dims[0], 0,
platform::errors::InvalidArgument(
"the size of filter at axis 0 should be greater than 0"));
}
framework::DDim in_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (!channel_last) {
output_shape.push_back(filter_dims[0]);
}
for (int i = 0; i < in_data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(
ConvOutputSize(in_data_dims[i], filter_data_dims[i], dilations[i],
paddings[2 * i], paddings[2 * i + 1], strides[i]));
}
}
if (channel_last) {
output_shape.push_back(filter_dims[0]);
}
return output_shape;
}
};
// TODO(qingqing): add gradient operator for conv2d_fusion
...
@@ -16,7 +16,9 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -28,27 +30,6 @@ class HistogramOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "histogram");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "histogram");
const auto &nbins = ctx->Attrs().Get<int64_t>("bins");
const auto &minval = ctx->Attrs().Get<int>("min");
const auto &maxval = ctx->Attrs().Get<int>("max");
PADDLE_ENFORCE_GE(nbins, 1,
platform::errors::InvalidArgument(
"The bins should be greater than or equal to 1."
"But received nbins is %d",
nbins));
PADDLE_ENFORCE_GE(maxval, minval, platform::errors::InvalidArgument(
"max must be larger or equal to min."
"But received max is %d, min is %d",
maxval, minval));
ctx->SetOutputDim("Out", phi::make_ddim({nbins}));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
@@ -81,7 +62,12 @@ class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(histogram, HistogramInferShapeFunctor,
PD_INFER_META(phi::HistogramInferMeta));
REGISTER_OPERATOR(
histogram, ops::HistogramOp, ops::HistogramOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
HistogramInferShapeFunctor);
@@ -323,6 +323,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker,
ops::BatchNormOpInferVarType,
ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
...
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -21,16 +23,6 @@ class MaskedSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect");
// output will only be a 1-D Tensor
ctx->SetOutputDim("Y", phi::make_ddim({-1}));
ctx->ShareLoD("X", /*->*/ "Y");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -100,8 +92,13 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MaskedSelectedGradNoNeedBufferVarsInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(masked_select, MaskedSelectInferShapeFunctor,
PD_INFER_META(phi::MaskedSelectInferMeta));
REGISTER_OPERATOR(masked_select, ops::MaskedSelectOp, ops::MaskedSelectOpMaker,
ops::MaskedSelectGradOpMaker<paddle::framework::OpDesc>,
ops::MaskedSelectGradOpMaker<paddle::imperative::OpBase>,
MaskedSelectInferShapeFunctor);
REGISTER_OPERATOR(masked_select_grad, ops::MaskedSelectOpGrad,
ops::MaskedSelectedGradNoNeedBufferVarsInferer);
@@ -15,7 +15,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -57,21 +59,7 @@ where, $\sum {x^2}$ is calculated along the `axis` dimension.
};
class NormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NormOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NormOp");
auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", xdim);
if (ctx->Attrs().Get<bool>("is_test") == false) {
int axis = ctx->Attrs().Get<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
xdim[axis] = 1;
ctx->SetOutputDim("Norm", xdim);
}
}
};
class NormOpGrad : public framework::OperatorWithKernel {
@@ -111,7 +99,11 @@ class NormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(norm, NormInferShapeFunctor,
PD_INFER_META(phi::NormInferMeta));
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
ops::NormOpGradOpMaker<paddle::framework::OpDesc>,
ops::NormOpGradOpMaker<paddle::imperative::OpBase>,
NormInferShapeFunctor);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
@@ -50,6 +50,7 @@ class SyncBatchNormGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType,
ops::SyncBatchNormGradMaker<paddle::framework::OpDesc>,
...
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
namespace detail {
@@ -355,6 +357,161 @@ void CrossInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void ConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config) {
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
auto in_dims = input.dims();
auto filter_dims = filter.dims();
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0,
phi::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
const bool channel_last = (config.is_run_mkldnn_kernel == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5,
true,
phi::errors::InvalidArgument(
"The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(),
filter_dims.size(),
phi::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(Conv) should be equal. But received: the input's shape is [%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims,
in_dims.size(),
filter_dims,
filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i],
0,
phi::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size(),
strides.size() + 2U,
phi::errors::InvalidArgument(
"The difference of input's dimension and Attr(strides)'s "
"length must be euqal to 2 for Op(Conv). "
"But received: input's dimension is %d, input's shape is [%s]; "
"Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"difference of input's dimention and Attr(strides)'s length = %u.",
in_dims.size(),
in_dims,
strides.size(),
phi::make_ddim(strides),
in_sub_stride_size));
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(
input_channels,
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels,
in_dims,
filter_dims[1],
filter_dims,
groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
phi::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0],
filter_dims,
groups));
if (config.is_runtime) {
PADDLE_ENFORCE_GT(
filter_dims[0],
0,
phi::errors::InvalidArgument(
"the size of filter at axis 0 should be greater than 0"));
}
DDim in_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (!channel_last) {
output_shape.push_back(filter_dims[0]);
}
for (int i = 0; i < in_data_dims.size(); ++i) {
if ((!config.is_runtime) &&
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
const int dkernel = dilations[i] * (filter_data_dims[i] - 1) + 1;
int output_size =
(in_data_dims[i] + paddings[2 * i] + paddings[2 * i + 1] - dkernel) /
strides[i] +
1;
output_shape.push_back(output_size);
}
}
if (channel_last) {
output_shape.push_back(filter_dims[0]);
}
out->set_dims(make_ddim(output_shape));
out->set_dtype(input.dtype());
}
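The per-axis arithmetic above is the standard dilated-convolution output formula. A self-contained sketch of the same computation with a worked value; ConvOutputSizeSketch is an illustration, not the library function:

#include <cstdint>

// Mirrors the output-size arithmetic in ConvInferMeta above (sketch only).
int64_t ConvOutputSizeSketch(int64_t in, int64_t k, int64_t dilation,
                             int64_t pad0, int64_t pad1, int64_t stride) {
  int64_t dkernel = dilation * (k - 1) + 1;  // effective kernel extent
  return (in + pad0 + pad1 - dkernel) / stride + 1;
}
// Example: in = 32, k = 3, dilation = 1, pad0 = pad1 = 1, stride = 1
//   -> dkernel = 3, output = (32 + 2 - 3) / 1 + 1 = 32 (shape-preserving).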
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
@@ -815,6 +972,13 @@ void LogLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out) {
out->set_dims({-1}); // can not infer
out->set_dtype(x.dtype());
}
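The {-1} above reflects that masked_select's output length depends on the mask's runtime values, so only a 1-D tensor of unknown size can be inferred statically; the kernel resolves the real length at run time. A small worked example (values are illustrative):

// Example: x = [[1, 2], [3, 4]], mask = [[true, false], [false, true]]
//   -> selected values {1, 4}; runtime output shape [2], static shape [-1].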
void MatmulInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
@@ -1188,6 +1352,118 @@ void TriangularSolveInferMeta(const MetaTensor& x,
out->share_lod(y);
}
void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
int class_num,
float conf_thresh,
int downsample_ratio,
bool clip_bbox,
float scale_x_y,
bool iou_aware,
float iou_aware_factor,
MetaTensor* boxes,
MetaTensor* scores,
MetaConfig config) {
auto dim_x = x.dims();
auto dim_imgsize = img_size.dims();
int anchor_num = anchors.size() / 2;
PADDLE_ENFORCE_EQ(
dim_x.size(),
4,
phi::errors::InvalidArgument("Input(X) should be a 4-D tensor."
"But received X dimension(%s)",
dim_x.size()));
if (iou_aware) {
PADDLE_ENFORCE_EQ(
dim_x[1],
anchor_num * (6 + class_num),
phi::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (6 "
"+ class_num)) while iou_aware is true."
"But received dim[1](%s) != (anchor_mask_number * "
"(6+class_num)(%s).",
dim_x[1],
anchor_num * (6 + class_num)));
PADDLE_ENFORCE_GE(
iou_aware_factor,
0,
phi::errors::InvalidArgument(
"Attr(iou_aware_factor) should greater than or equal to 0."
"But received iou_aware_factor (%s)",
iou_aware_factor));
PADDLE_ENFORCE_LE(
iou_aware_factor,
1,
phi::errors::InvalidArgument(
"Attr(iou_aware_factor) should less than or equal to 1."
"But received iou_aware_factor (%s)",
iou_aware_factor));
} else {
PADDLE_ENFORCE_EQ(
dim_x[1],
anchor_num * (5 + class_num),
phi::errors::InvalidArgument(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s).",
dim_x[1],
anchor_num * (5 + class_num)));
}
PADDLE_ENFORCE_EQ(
dim_imgsize.size(),
2,
phi::errors::InvalidArgument("Input(ImgSize) should be a 2-D tensor."
"But received Imgsize size(%s)",
dim_imgsize.size()));
if ((dim_imgsize[0] > 0 && dim_x[0] > 0) || config.is_runtime) {
PADDLE_ENFORCE_EQ(
dim_imgsize[0],
dim_x[0],
phi::errors::InvalidArgument(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."));
}
PADDLE_ENFORCE_EQ(
dim_imgsize[1],
2,
phi::errors::InvalidArgument("Input(ImgSize) dim[1] should be 2."
"But received imgsize dim[1](%s).",
dim_imgsize[1]));
PADDLE_ENFORCE_GT(anchors.size(),
0,
phi::errors::InvalidArgument(
"Attr(anchors) length should be greater than 0."
"But received anchors length(%s).",
anchors.size()));
PADDLE_ENFORCE_EQ(anchors.size() % 2,
0,
phi::errors::InvalidArgument(
"Attr(anchors) length should be even integer."
"But received anchors length (%s)",
anchors.size()));
PADDLE_ENFORCE_GT(class_num,
0,
phi::errors::InvalidArgument(
"Attr(class_num) should be an integer greater than 0."
"But received class_num (%s)",
class_num));
int box_num;
if ((dim_x[2] > 0 && dim_x[3] > 0) || config.is_runtime) {
box_num = dim_x[2] * dim_x[3] * anchor_num;
} else {
box_num = -1;
}
std::vector<int64_t> dim_boxes({dim_x[0], box_num, 4});
boxes->set_dims(phi::make_ddim(dim_boxes));
boxes->set_dtype(x.dtype());
std::vector<int64_t> dim_scores({dim_x[0], box_num, class_num});
scores->set_dims(phi::make_ddim(dim_scores));
}
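A sketch of the output shapes derived above, assuming anchors are given as (width, height) pairs so anchor_num = anchors.size() / 2 and every cell of the h x w feature map predicts anchor_num boxes:

// Worked example (illustrative values, not library code):
//   x: [N, C, 13, 13], anchors = {116, 90, 156, 198, 373, 326}
//   -> anchor_num = 3, box_num = 13 * 13 * 3 = 507
//   -> boxes:  [N, 507, 4]
//   -> scores: [N, 507, class_num]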
void ValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
@@ -1201,3 +1477,4 @@ void ValueCompareInferMeta(const MetaTensor& x,
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
PD_REGISTER_INFER_META_FN(conv2d, phi::ConvInferMeta);
@@ -69,6 +69,20 @@ void CompareInferMeta(const MetaTensor& x,
int axis,
MetaTensor* out);
void ConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CrossInferMeta(const MetaTensor& x,
const MetaTensor& y,
int axis,
@@ -138,6 +152,10 @@ void LogLossInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void MaskedSelectInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out);
void MatmulInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool trans_x,
@@ -180,6 +198,20 @@ void TriangularSolveInferMeta(const MetaTensor& x,
bool unitriangular,
MetaTensor* out);
void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
int class_num,
float conf_thresh,
int downsample_ratio,
bool clip_bbox,
float scale_x_y,
bool iou_aware,
float iou_aware_factor,
MetaTensor* boxes,
MetaTensor* scores,
MetaConfig config = MetaConfig());
void ValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
...
@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/phi/infermeta/multiary.h"
#include <vector>
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi { namespace phi {
...@@ -200,6 +202,114 @@ void AucInferMeta(const MetaTensor& input, ...@@ -200,6 +202,114 @@ void AucInferMeta(const MetaTensor& input,
} }
} }
void BatchNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& mean,
const MetaTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout_str,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
MetaTensor* y,
MetaTensor* mean_out,
MetaTensor* variance_out,
MetaTensor* saved_mean,
MetaTensor* saved_variance,
MetaTensor* reserve_space,
MetaConfig config) {
const auto x_dims = x.dims();
for (int i = 0; i < x_dims.size(); i++) {
PADDLE_ENFORCE_EQ(
(x_dims[i] == -1) || (x_dims[i] > 0),
true,
phi::errors::InvalidArgument(
"Each dimension of input tensor is expected to be -1 or a "
"positive number, but recieved %d. Input's shape is [%s].",
x_dims[i],
x_dims));
}
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: the dimension of input "
"X must greater than or equal to 2. But received: the shape of input "
"X = [%s], the dimension of input X =[%d]",
x_dims,
x_dims.size()));
PADDLE_ENFORCE_LE(
x_dims.size(),
5,
phi::errors::InvalidArgument(
"ShapeError: the dimension of input X "
"must smaller than or equal to 5. But received: the shape of input X "
"= [%s], the dimension of input X = [%d]",
x_dims,
x_dims.size()));
const int64_t C = ((config.is_run_mkldnn_kernel == true) ||
(data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
auto scale_dim = scale.dims();
auto bias_dim = bias.dims();
PADDLE_ENFORCE_EQ(
scale_dim.size(),
1UL,
phi::errors::InvalidArgument(
"ShapeError: the dimension of scale must equal to 1."
"But received: the shape of scale is [%s], the dimension "
"of scale is [%d]",
scale_dim,
scale_dim.size()));
PADDLE_ENFORCE_EQ(bias_dim.size(),
1UL,
phi::errors::InvalidArgument(
"ShapeError: the dimension of bias must equal to 1."
"But received: the shape of bias is [%s],the dimension "
"of bias is [%d]",
bias_dim,
bias_dim.size()));
bool check = true;
if ((!config.is_runtime) &&
(phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0],
C,
phi::errors::InvalidArgument(
"ShapeError: the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]",
C,
scale_dim[0]));
PADDLE_ENFORCE_EQ(bias_dim[0],
C,
phi::errors::InvalidArgument(
"ShapeError: the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]",
C,
bias_dim[0]));
}
y->set_dims(x_dims);
mean_out->set_dims({C});
variance_out->set_dims({C});
saved_mean->set_dims({C});
saved_variance->set_dims({C});
y->share_lod(x);
}
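The parameter tensors only need to match the channel count C picked out above. A minimal sketch of that selection under the same layout rules; ChannelOf is a hypothetical helper, not part of this commit:

#include <cstdint>
#include <vector>

// Sketch: channel count as BatchNormInferMeta derives it. MKL-DNN kernels
// and NCHW layouts take dims[1]; NHWC/NDHWC take the trailing dimension.
int64_t ChannelOf(const std::vector<int64_t>& x_dims, bool nchw_or_mkldnn) {
  return nchw_or_mkldnn ? x_dims[1] : x_dims.back();
}
// Example: x_dims = {8, 3, 224, 224} in NCHW -> C = 3; scale, bias, mean,
// and variance must then all be 1-D tensors of length 3.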
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
@@ -577,3 +687,5 @@ void WhereInferMeta(const MetaTensor& condition,
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
@@ -72,6 +72,26 @@ void AucInferMeta(const MetaTensor& input,
MetaTensor* stat_neg_out,
MetaConfig config = MetaConfig());
void BatchNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& mean,
const MetaTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
MetaTensor* y,
MetaTensor* mean_out,
MetaTensor* variance_out,
MetaTensor* saved_mean,
MetaTensor* saved_variance,
MetaTensor* reserve_space,
MetaConfig config = MetaConfig());
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
...
@@ -304,6 +304,17 @@ void DiagonalInferMeta(const MetaTensor& input,
out->set_dims(phi::make_ddim(out_dims));
}
void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask) {
auto x_dims = x.dims();
out->set_dims(x_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
if (mask != nullptr) {
mask->set_dims(x_dims);
}
}
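Out always mirrors X here; Mask is shaped only when the caller passes a non-null mask tensor, which corresponds to the training path (the removed fluid InferShape skipped Mask when is_test was true). A short worked example:

// Example: x: [64, 128] -> out: [64, 128] with x's LoD and dtype;
//   mask: [64, 128] when requested (training), untouched when mask == nullptr.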
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
@@ -392,6 +403,26 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x,
UnchangedInferMetaCheckAxis(x, axis, out);
}
void HistogramInferMeta(
const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out) {
PADDLE_ENFORCE_GE(bins,
1,
phi::errors::InvalidArgument(
"The bins should be greater than or equal to 1."
"But received nbins is %d",
bins));
PADDLE_ENFORCE_GE(
max,
min,
phi::errors::InvalidArgument("max must be larger or equal to min."
"But received max is %d, min is %d",
max,
min));
out->set_dims({bins});
out->share_lod(input);
}
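Regardless of the input's rank, the histogram output is a 1-D tensor with one slot per bin:

// Example: input: [4, 5, 6] (120 elements), bins = 10, min = 0, max = 100
//   -> Out: [10]; each entry counts the input values falling into one bin.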
void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) {
PADDLE_ENFORCE_EQ(
product(x.dims()),
@@ -787,6 +818,24 @@ void MultinomialInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT64);
}
void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
bool is_test,
MetaTensor* out,
MetaTensor* norm) {
auto xdim = x.dims();
out->set_dims(x.dims());
out->set_dtype(x.dtype());
if (is_test == false) {
if (axis < 0) axis = xdim.size() + axis;
xdim[axis] = 1;
norm->set_dims(xdim);
norm->set_dtype(x.dtype());
}
}
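In training mode the auxiliary Norm output keeps x's shape with the normalized axis collapsed to 1. A small sketch of that shape rule with a negative-axis example; NormAuxShape is a hypothetical helper:

#include <cstdint>
#include <vector>

// Sketch of the Norm-output shape rule in NormInferMeta above.
std::vector<int64_t> NormAuxShape(std::vector<int64_t> xdim, int axis) {
  if (axis < 0) axis += static_cast<int>(xdim.size());
  xdim[axis] = 1;  // the reduced (normalized) axis keeps a size-1 slot
  return xdim;
}
// Example: x: [2, 3, 4], axis = -1 -> Norm: [2, 3, 1]; Out stays [2, 3, 4].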
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
...
@@ -74,6 +74,8 @@ void DiagInferMeta(const MetaTensor& x,
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask);
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
@@ -89,6 +91,8 @@ void GumbelSoftmaxInferMeta(const MetaTensor& x,
bool hard,
int axis,
MetaTensor* out);
void HistogramInferMeta(
const MetaTensor& input, int64_t bins, int min, int max, MetaTensor* out);
void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out);
@@ -130,6 +134,12 @@ void MultinomialInferMeta(const MetaTensor& x,
int num_samples,
bool replacement,
MetaTensor* out);
void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
bool is_test,
MetaTensor* out,
MetaTensor* norm);
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
...
@@ -19,6 +19,7 @@ import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import paddle
import hypothesis
from hypothesis import given, settings, seed, example, assume
@@ -104,4 +105,5 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()