未验证 提交 868a3203 编写于 作者: H hong 提交者: GitHub

Add infer meta (#41054)

* add some infer meta

* fix bug

* fix bugs;

* fix bug and add set data type

* revert infer shape of lookup table

* recover test
上级 2bc72a06
......@@ -19,6 +19,10 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -28,30 +32,6 @@ class MeshgridOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(
ctx->Inputs("X").size(), 1UL,
platform::errors::InvalidArgument("Input(X) should not be empty."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument("Output(Out) should not be empty."));
auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size();
auto outs_names = ctx->Outputs("Out");
const size_t outputs_num = outs_names.size();
auto out_shape = std::vector<int>(inputs_num);
for (size_t i = 0; i < inputs_num; i++) {
out_shape[i] = inputs_dims[i][0];
}
auto out_dims = phi::make_ddim(std::vector<int>(out_shape));
std::vector<framework::DDim> outs_dims(outputs_num, out_dims);
ctx->SetOutputsDim("Out", outs_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -142,7 +122,10 @@ class MeshgridGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(meshgrid, MeshgridInferShapeFunctor,
PD_INFER_META(phi::MeshgridInferMeta));
REGISTER_OPERATOR(meshgrid, ops::MeshgridOp, ops::MeshgridOpMaker,
ops::MeshgridGradOpMaker<paddle::framework::OpDesc>,
ops::MeshgridGradOpMaker<paddle::imperative::OpBase>);
ops::MeshgridGradOpMaker<paddle::imperative::OpBase>,
MeshgridInferShapeFunctor);
REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp);
......@@ -19,6 +19,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -27,39 +31,6 @@ class AdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adagrad");
OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adagrad");
OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adagrad");
OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
"Adagrad");
OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adagrad");
OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
"Adagrad");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
platform::errors::InvalidArgument(
"LearningRate should have one element"));
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument("Param and Grad input of AdagradOp "
"should have the same dimension."));
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment"),
platform::errors::InvalidArgument("Param and Moment input of AdagradOp "
"should have the same dimension."));
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -105,4 +76,7 @@ for numerical stability to avoid the division by zero error.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
DECLARE_INFER_SHAPE_FUNCTOR(adagrad, AdagradInferShapeFunctor,
PD_INFER_META(phi::AdagradInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker,
AdagradInferShapeFunctor);
......@@ -14,91 +14,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class RmspropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("MeanSquare"), true,
platform::errors::NotFound(
"Input(MeanSquare) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::NotFound(
"Input(Grad) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment"), true,
platform::errors::NotFound(
"Input(Moment) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type in RmspropOp should be "
"LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(param_out) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MomentOut"), true,
platform::errors::NotFound(
"Output(MomentOut) of RmspropOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanSquareOut"), true,
platform::errors::NotFound(
"Output(MeanSquareOut) of RmspropOp should not be null."));
if (ctx->Attrs().Get<bool>("centered")) {
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanGradOut"), true,
platform::errors::NotFound(
"Output(MeanGradOut) of RmspropOp should not be null."));
}
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and grad input of RmspropOp should have the same dimension. "
"But received Param's dim [%s] and Grad's dim [%s].",
param_dim, ctx->GetInputDim("Grad")));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
platform::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment [%s]",
param_dim, ctx->GetInputDim("Moment")));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
platform::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and MeanSquare [%s]",
param_dim, ctx->GetInputDim("MeanSquare")));
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(phi::product(lr_dim), 1,
platform::errors::InvalidArgument(
"Learning Rate of RmspropOp should be a scalar. But "
"received LearningRate's dim [%s]",
phi::product(lr_dim)));
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
ctx->SetOutputDim("MeanSquareOut", param_dim);
if (ctx->Attrs().Get<bool>("centered")) {
ctx->SetOutputDim("MeanGradOut", param_dim);
}
}
};
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -169,4 +94,7 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
DECLARE_INFER_SHAPE_FUNCTOR(rmsprop, RmspropInferShapeFunctor,
PD_INFER_META(phi::RmspropInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker,
RmspropInferShapeFunctor);
......@@ -19,6 +19,10 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -26,46 +30,6 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Grad"), true,
platform::errors::NotFound("Input(Grad) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
platform::errors::NotFound(
"Input(LearningRate) of SGDOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
platform::errors::NotFound(
"Output(ParamOut) of SGDOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
platform::errors::NotFound(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 element. But received "
"LearningRate dims [%s]",
phi::product(lr_dims)));
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"SGD Operator's input Param and Grad dimensions do not match. "
"The Param %s shape is [%s], but the Grad %s shape is [%s].",
ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0],
ctx->GetInputDim("Grad")));
}
ctx->SetOutputDim("ParamOut", param_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -161,8 +125,10 @@ $$param\_out = param - learning\_rate * grad$$
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(sgd, SGDInferShapeFunctor,
PD_INFER_META(phi::SGDInferMeta));
REGISTER_OPERATOR(
sgd, ops::SGDOp, ops::SGDOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::SGDOpInferVarType);
ops::SGDOpInferVarType, SGDInferShapeFunctor);
......@@ -15,6 +15,10 @@
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -24,49 +28,6 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
platform::errors::InvalidArgument(
"Input(X) rank should be 4 in shape of [N*T, C, H, "
"W], but received X rank(%d)",
dim_x.size()));
int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(
seg_num, 0,
platform::errors::InvalidArgument(
"Attr(seg_num) should be greater than 0, but received %d",
seg_num));
PADDLE_ENFORCE_GT(
shift_ratio, 0.,
platform::errors::InvalidArgument(
"Attr(shift_ratio) should be greater than 0, but received %d",
shift_ratio));
PADDLE_ENFORCE_LT(
shift_ratio, 0.5,
platform::errors::InvalidArgument(
"Attr(shift_ratio) should be less than 0.5, but received %d",
shift_ratio));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
platform::errors::InvalidArgument(
"Input(X) dimension[0] should be divided exactly "
"by Attr(seg_num), but received X dimension[0](%d) "
"mod seg_num(%d) != 0",
dim_x[0], seg_num));
}
ctx->SetOutputDim("Out", dim_x);
ctx->ShareLoD("X", "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -186,10 +147,13 @@ class TemporalShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(temporal_shift, TemporalShiftInferShapeFunctor,
PD_INFER_META(phi::TemporalShiftInferMeta));
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
ops::TemporalShiftGradOpMaker<paddle::framework::OpDesc>,
ops::TemporalShiftGradOpMaker<paddle::imperative::OpBase>);
ops::TemporalShiftGradOpMaker<paddle::imperative::OpBase>,
TemporalShiftInferShapeFunctor);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
ops::TemporalShiftKernel<double>);
......
......@@ -75,6 +75,32 @@ void AllValueCompareInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL);
}
void EmbeddingInferMeta(const MetaTensor& input,
const MetaTensor& weight,
int64_t padding_idx,
MetaTensor* out) {
auto table_dims = weight.dims();
auto ids_dims = input.dims();
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(
table_dims.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(),
table_dims));
auto output_dims = phi::vectorize(ids_dims);
output_dims.push_back(table_dims[1]);
out->set_dims(phi::make_ddim(output_dims));
out->set_dtype(weight.dtype());
out->share_lod(input);
}
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
......
......@@ -37,6 +37,11 @@ void AllValueCompareInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void EmbeddingInferMeta(const MetaTensor& input,
const MetaTensor& weight,
int64_t padding_idx,
MetaTensor* out);
void KLDivInferMeta(const MetaTensor& x,
const MetaTensor& label,
const std::string& reduction,
......
......@@ -66,6 +66,32 @@ void AdadeltaInferMeta(const MetaTensor& param,
avg_squared_update_out->set_dtype(avg_squared_update.dtype());
}
void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
1,
phi::errors::InvalidArgument("LearningRate should have one element"));
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
moment.dims(),
phi::errors::InvalidArgument("Param and Moment input of AdagradOp "
"should have the same dimension."));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
moment_out->set_dims(param_dims);
moment_out->set_dtype(moment.dtype());
}
void AdamInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
......@@ -1390,6 +1416,22 @@ void InterpolateInferMeta(
}
}
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
auto out_shape = std::vector<int>(inputs_num);
for (size_t i = 0; i < inputs.size(); i++) {
out_shape[i] = inputs[i]->dims()[0];
}
auto out_dims = phi::make_ddim(std::vector<int>(out_shape));
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i]->set_dims(out_dims);
outputs[i]->set_dtype(inputs[0]->dtype());
}
}
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
......@@ -1582,6 +1624,65 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> mean_grad,
float epsilon,
float decay,
float momentum,
bool centered,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* mean_square_out,
MetaTensor* mean_grad_out) {
if (centered) {
PADDLE_ENFORCE_NOT_NULL(
mean_grad_out,
phi::errors::InvalidArgument(
"Output(MeanGradOut) of RmspropOp should not be null."));
}
auto param_dim = param.dims();
PADDLE_ENFORCE_EQ(param_dim,
moment.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment [%s]",
param_dim,
moment.dims()));
PADDLE_ENFORCE_EQ(param_dim,
mean_square.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of RmspropOp "
"should have the same dimension. But received "
"Param's dim [%s] and MeanSquare [%s]",
param_dim,
mean_square.dims()));
auto lr_dim = learning_rate.dims();
PADDLE_ENFORCE_EQ(phi::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning Rate of RmspropOp should be a scalar. But "
"received LearningRate's dim [%s]",
phi::product(lr_dim)));
param_out->set_dims(param_dim);
param_out->set_dtype(param.dtype());
moment_out->set_dims(param_dim);
moment_out->set_dtype(moment.dtype());
mean_square_out->set_dims(param_dim);
mean_square_out->set_dtype(mean_square.dtype());
if (centered) {
mean_grad_out->set_dims(param_dim);
mean_grad_out->set_dtype(mean_grad.get_ptr()->dtype());
}
}
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
......@@ -1667,6 +1768,29 @@ void RnnInferMeta(const MetaTensor& x,
}
}
void SGDInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
paddle::optional<const MetaTensor&> master_param,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* master_param_out) {
PADDLE_ENFORCE_NOT_NULL(param_out,
phi::errors::InvalidArgument(
"Output(ParamOut) of SGDOp should not be null."));
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning rate should have 1 element. But received "
"LearningRate dims [%s]",
phi::product(lr_dims)));
param_out->set_dims(param.dims());
param_out->set_dtype(param.dtype());
}
void StackInferMeta(const std::vector<MetaTensor*>& x,
int axis,
MetaTensor* out) {
......
......@@ -47,6 +47,14 @@ void AdadeltaInferMeta(const MetaTensor& param,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out);
void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out);
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
......@@ -215,6 +223,9 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
......@@ -230,6 +241,21 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale,
MetaTensor* out);
void RmspropInferMeta(const MetaTensor& param,
const MetaTensor& mean_square,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> mean_grad,
float epsilon,
float decay,
float momentum,
bool centered,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* mean_square_out,
MetaTensor* mean_grad_out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
......@@ -247,6 +273,14 @@ void RnnInferMeta(const MetaTensor& x,
std::vector<MetaTensor*> state,
MetaTensor* reserve);
void SGDInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
paddle::optional<const MetaTensor&> master_param,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* master_param_out);
void StackInferMeta(const std::vector<MetaTensor*>& x,
int axis,
MetaTensor* out);
......
......@@ -2102,6 +2102,52 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format,
MetaTensor* out,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE_EQ(dim_x.size(),
4,
phi::errors::InvalidArgument(
"Input(X) rank should be 4 in shape of [N*T, C, H, "
"W], but received X rank(%d)",
dim_x.size()));
PADDLE_ENFORCE_GT(
seg_num,
0,
phi::errors::InvalidArgument(
"Attr(seg_num) should be greater than 0, but received %d", seg_num));
PADDLE_ENFORCE_GT(
shift_ratio,
0.,
phi::errors::InvalidArgument(
"Attr(shift_ratio) should be greater than 0, but received %d",
shift_ratio));
PADDLE_ENFORCE_LT(
shift_ratio,
0.5,
phi::errors::InvalidArgument(
"Attr(shift_ratio) should be less than 0.5, but received %d",
shift_ratio));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num,
0,
phi::errors::InvalidArgument(
"Input(X) dimension[0] should be divided exactly "
"by Attr(seg_num), but received X dimension[0](%d) "
"mod seg_num(%d) != 0",
dim_x[0],
seg_num));
}
out->share_meta(x);
}
void TileInferMeta(const MetaTensor& x,
const IntArray& repeat_times,
MetaTensor* out,
......
......@@ -315,6 +315,13 @@ void SumRawInferMeta(const MetaTensor& x,
DataType dtype,
MetaTensor* out);
void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TileInferMeta(const MetaTensor& x,
const IntArray& repeat_times,
MetaTensor* out,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册