Unverified commit 6d3db9c7, authored by Chen Weihang, committed by GitHub

[Phi] Move batch size like infershape into phi (#40847)

* move batch size like infershape

* revert other op change

* call infermeta in infershape

* adjust batchsize like pos
Parent commit: 92afe146
...@@ -15,61 +15,32 @@ limitations under the License. */ ...@@ -15,61 +15,32 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using MetaTensor = framework::CompatMetaTensor;
class BatchSizeLikeOp : public framework::OperatorWithKernel { class BatchSizeLikeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type()); OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type());
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type()); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); MetaTensor x(ctx->GetInputVarPtrs("Input")[0], ctx->IsRuntime());
PADDLE_ENFORCE_GT(shape.size(), 0, MetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
platform::errors::InvalidArgument( auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
"Shape size must be larger than 0, but received: %s.", int x_batch_size_dim = ctx->Attrs().Get<int>("input_dim_idx");
shape.size())); int out_batch_size_dim = ctx->Attrs().Get<int>("output_dim_idx");
std::vector<int64_t> shape_int64(shape.size(), 0); phi::BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim,
std::transform(shape.begin(), shape.end(), shape_int64.begin(), &out);
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = phi::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
int input_dim_size = static_cast<int>(ctx->GetInputDim("Input").size());
PADDLE_ENFORCE_GE(input_dim_idx, 0,
platform::errors::InvalidArgument(
"Input dimension index must be larger "
"equal than 0, but received: %s.",
input_dim_idx));
PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx,
platform::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size, input_dim_idx));
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(output_dim_idx, 0,
platform::errors::InvalidArgument(
"Output dimension index must be larger "
"equal than 0, but received: %s.",
output_dim_idx));
PADDLE_ENFORCE_GT(
output_dim_size, output_dim_idx,
platform::errors::InvalidArgument(
"Output dimension size must be larger than output dimension index, "
"but received output dimension size: %s, output dimension index: "
"%s.",
output_dim_size, output_dim_idx));
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
} }
}; };
......
...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/batch_size_like.h" #include "paddle/fluid/operators/batch_size_like.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -61,9 +64,13 @@ obtained from the `input` tensor. ...@@ -61,9 +64,13 @@ obtained from the `input` tensor.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fill_constant_batch_size_like,
FillConstantBatchSizeLikeInferShapeFunctor,
PD_INFER_META(phi::FullBatchSizeLikeInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
fill_constant_batch_size_like, ops::FillConstantBatchSizeLikeOp, fill_constant_batch_size_like, ops::FillConstantBatchSizeLikeOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillConstantBatchSizeLikeOpMaker, ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInferer); ops::BatchSizeLikeNoNeedBufferVarsInferer,
FillConstantBatchSizeLikeInferShapeFunctor);
...@@ -133,6 +133,59 @@ void ArgsortInferMeta(const MetaTensor& input, ...@@ -133,6 +133,59 @@ void ArgsortInferMeta(const MetaTensor& input,
indices->share_lod(input); indices->share_lod(input);
} }
void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
int x_batch_size_dim,
int out_batch_size_dim,
MetaTensor* out) {
PADDLE_ENFORCE_GT(
shape.size(),
0UL,
phi::errors::InvalidArgument(
"Shape size must be larger than 0, but received: %s.", shape.size()));
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) {
return static_cast<int64_t>(a);
});
auto output_dim = phi::make_ddim(shape_int64);
int input_dim_size = static_cast<int>(x.dims().size());
PADDLE_ENFORCE_GE(
x_batch_size_dim,
0,
phi::errors::InvalidArgument("Input dimension index must be larger "
"equal than 0, but received: %s.",
x_batch_size_dim));
PADDLE_ENFORCE_GT(input_dim_size,
x_batch_size_dim,
phi::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size,
x_batch_size_dim));
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(
out_batch_size_dim,
0,
phi::errors::InvalidArgument("Output dimension index must be larger "
"equal than 0, but received: %s.",
out_batch_size_dim));
PADDLE_ENFORCE_GT(
output_dim_size,
out_batch_size_dim,
phi::errors::InvalidArgument(
"Output dimension size must be larger than output dimension index, "
"but received output dimension size: %s, output dimension index: "
"%s.",
output_dim_size,
out_batch_size_dim));
output_dim[out_batch_size_dim] = x.dims()[x_batch_size_dim];
out->set_dims(output_dim);
}
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims()); out->set_dims(x.dims());
out->set_dtype(out_dtype); out->set_dtype(out_dtype);
...@@ -413,6 +466,17 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, ...@@ -413,6 +466,17 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
xshape->share_lod(x); xshape->share_lod(x);
} }
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
                                const std::vector<int>& shape,
                                const Scalar& val,
                                DataType dtype,
                                int x_batch_size_dim,
                                int out_batch_size_dim,
                                MetaTensor* out) {
  // Only dims and dtype are metadata here; the fill value `val` plays no
  // role in shape inference. Dtype is fixed first, then the dims are derived
  // from `shape` with the batch dimension taken from the input.
  out->set_dtype(dtype);
  BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim, out);
}
void GumbelSoftmaxInferMeta(const MetaTensor& x, void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature, float temperature,
bool hard, bool hard,
......
...@@ -48,6 +48,12 @@ void ArgsortInferMeta(const MetaTensor& input, ...@@ -48,6 +48,12 @@ void ArgsortInferMeta(const MetaTensor& input,
MetaTensor* output, MetaTensor* output,
MetaTensor* indices); MetaTensor* indices);
// Infers output metadata for *_batch_size_like ops: the output dims are
// `shape` verbatim, except that dimension `out_batch_size_dim` is replaced by
// x.dims()[x_batch_size_dim]. Enforces that `shape` is non-empty and that both
// dimension indices are non-negative and in range for their respective ranks.
void BatchSizeLikeInferMeta(const MetaTensor& x,
                            const std::vector<int>& shape,
                            int x_batch_size_dim,
                            int out_batch_size_dim,
                            MetaTensor* out);
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
...@@ -92,6 +98,14 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, ...@@ -92,6 +98,14 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaTensor* xshape); MetaTensor* xshape);
// InferMeta for fill_constant_batch_size_like: dims follow the
// BatchSizeLikeInferMeta rule and the output dtype is set to `dtype`.
// The fill value `val` does not influence the inferred metadata.
void FullBatchSizeLikeInferMeta(const MetaTensor& x,
                                const std::vector<int>& shape,
                                const Scalar& val,
                                DataType dtype,
                                int x_batch_size_dim,
                                int out_batch_size_dim,
                                MetaTensor* out);
void GumbelSoftmaxInferMeta(const MetaTensor& x, void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature, float temperature,
bool hard, bool hard,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register.