未验证 提交 6d3db9c7 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move batch size like infershape into phi (#40847)

* move batch size like infershape

* revert other op change

* call infermeta in infershape

* adjust batchsize like pos
上级 92afe146
......@@ -15,61 +15,32 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using MetaTensor = framework::CompatMetaTensor;
class BatchSizeLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type());
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0,
platform::errors::InvalidArgument(
"Shape size must be larger than 0, but received: %s.",
shape.size()));
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = phi::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
int input_dim_size = static_cast<int>(ctx->GetInputDim("Input").size());
PADDLE_ENFORCE_GE(input_dim_idx, 0,
platform::errors::InvalidArgument(
"Input dimension index must be larger "
"equal than 0, but received: %s.",
input_dim_idx));
PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx,
platform::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size, input_dim_idx));
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(output_dim_idx, 0,
platform::errors::InvalidArgument(
"Output dimension index must be larger "
"equal than 0, but received: %s.",
output_dim_idx));
PADDLE_ENFORCE_GT(
output_dim_size, output_dim_idx,
platform::errors::InvalidArgument(
"Output dimension size must be larger than output dimension index, "
"but received output dimension size: %s, output dimension index: "
"%s.",
output_dim_size, output_dim_idx));
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
MetaTensor x(ctx->GetInputVarPtrs("Input")[0], ctx->IsRuntime());
MetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
int x_batch_size_dim = ctx->Attrs().Get<int>("input_dim_idx");
int out_batch_size_dim = ctx->Attrs().Get<int>("output_dim_idx");
phi::BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim,
&out);
}
};
......
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/batch_size_like.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -61,9 +64,13 @@ obtained from the `input` tensor.
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fill_constant_batch_size_like,
FillConstantBatchSizeLikeInferShapeFunctor,
PD_INFER_META(phi::FullBatchSizeLikeInferMeta));
REGISTER_OPERATOR(
fill_constant_batch_size_like, ops::FillConstantBatchSizeLikeOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInferer);
ops::BatchSizeLikeNoNeedBufferVarsInferer,
FillConstantBatchSizeLikeInferShapeFunctor);
......@@ -133,6 +133,59 @@ void ArgsortInferMeta(const MetaTensor& input,
indices->share_lod(input);
}
void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
int x_batch_size_dim,
int out_batch_size_dim,
MetaTensor* out) {
PADDLE_ENFORCE_GT(
shape.size(),
0UL,
phi::errors::InvalidArgument(
"Shape size must be larger than 0, but received: %s.", shape.size()));
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) {
return static_cast<int64_t>(a);
});
auto output_dim = phi::make_ddim(shape_int64);
int input_dim_size = static_cast<int>(x.dims().size());
PADDLE_ENFORCE_GE(
x_batch_size_dim,
0,
phi::errors::InvalidArgument("Input dimension index must be larger "
"equal than 0, but received: %s.",
x_batch_size_dim));
PADDLE_ENFORCE_GT(input_dim_size,
x_batch_size_dim,
phi::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size,
x_batch_size_dim));
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(
out_batch_size_dim,
0,
phi::errors::InvalidArgument("Output dimension index must be larger "
"equal than 0, but received: %s.",
out_batch_size_dim));
PADDLE_ENFORCE_GT(
output_dim_size,
out_batch_size_dim,
phi::errors::InvalidArgument(
"Output dimension size must be larger than output dimension index, "
"but received output dimension size: %s, output dimension index: "
"%s.",
output_dim_size,
out_batch_size_dim));
output_dim[out_batch_size_dim] = x.dims()[x_batch_size_dim];
out->set_dims(output_dim);
}
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(out_dtype);
......@@ -413,6 +466,17 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
xshape->share_lod(x);
}
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
const Scalar& val,
DataType dtype,
int x_batch_size_dim,
int out_batch_size_dim,
MetaTensor* out) {
BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim, out);
out->set_dtype(dtype);
}
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
......
......@@ -48,6 +48,12 @@ void ArgsortInferMeta(const MetaTensor& input,
MetaTensor* output,
MetaTensor* indices);
void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
int x_batch_size_dim,
int out_batch_size_dim,
MetaTensor* out);
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
......@@ -92,6 +98,14 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* xshape);
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
const Scalar& val,
DataType dtype,
int x_batch_size_dim,
int out_batch_size_dim,
MetaTensor* out);
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册