未验证 提交 246ac976 编写于 作者: Y Yuang Liu 提交者: GitHub

[operator migration] Migrate infer shape for merged momentum (#44338)

上级 4baf0dbe
......@@ -45,6 +45,7 @@ yaml_types_mapping = {
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \
'str[]' : 'std::vector<std::string>', 'float[]' : 'std::vector<float>', \
'Place' : 'paddle::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
......
......@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -22,8 +25,6 @@ class MergedMomentumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto param_dtype =
......@@ -100,6 +101,11 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(merged_momentum,
MergedMomentumInferShapeFunctor,
PD_INFER_META(phi::MergedMomentumInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
ops::MergedMomentumOp,
ops::MergedMomentumOpMaker);
ops::MergedMomentumOpMaker,
MergedMomentumInferShapeFunctor);
......@@ -284,5 +284,15 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
return pt_tensors;
}
paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const paddle::optional<std::vector<Tensor>>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
if (inputs) {
return {*PrepareData(*inputs, target_args_def, transform_flag)};
}
return paddle::none;
}
} // namespace experimental
} // namespace paddle
......@@ -76,5 +76,10 @@ std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
paddle::optional<std::vector<phi::DenseTensor>> PrepareData(
const paddle::optional<std::vector<Tensor>>& inputs,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
} // namespace experimental
} // namespace paddle
......@@ -131,9 +131,11 @@ class BaseAPI(object):
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'float[]': 'const std::vector<float>&',
'double': 'double',
'bool': 'bool',
'str': 'const std::string&',
'str[] ': 'const std::vector<std::string>&',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
......
......@@ -1549,6 +1549,22 @@ void MergedAdamInferMeta(
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out) {}
void MergedMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out) {}
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
......
......@@ -255,6 +255,22 @@ void MergedAdamInferMeta(
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out);
void MergedMomentumInferMeta(
const std::vector<const MetaTensor*>& param,
const std::vector<const MetaTensor*>& grad,
const std::vector<const MetaTensor*>& velocity,
const std::vector<const MetaTensor*>& learning_rate,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out);
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册