Unverified commit 888a30c9, authored by HappyHeavyRain, committed by GitHub

Automatically generate 'assign' operator (#51940)

* support assign op

* support assign infer_var_type

* change code according to review

* change code according to review

* only save 'get_infer_var_type_func'

* reset file mode
Parent: 97fc2a0f
@@ -169,7 +169,7 @@ if (WITH_ASCEND)
endif()
if (WITH_ASCEND_CL)
-  cc_test(assign_op_npu_test SRCS assign_op_npu_test.cc DEPS assign_op)
+  cc_test(assign_op_npu_test SRCS assign_op_npu_test.cc DEPS generated_static_op)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner)
endif()
@@ -185,7 +185,7 @@ set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(test_common_infer_shape_functions SRCS test_common_infer_shape_functions.cc DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op elementwise_add_op softmax_op softmax)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
-cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op)
+cc_test(assign_op_test SRCS assign_op_test.cc DEPS generated_static_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/assign_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
class AssignOp : public framework::OperatorWithKernel {
public:
AssignOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
phi::KernelKey GetKernelTypeForVar(
const std::string &var_name,
const phi::DenseTensor &tensor,
const phi::KernelKey &expected_kernel_type) const override {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
tensor.layout(),
expected_kernel_type.dtype());
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const framework::Variable *var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = var->Get<framework::LoDTensorArray>();
// NOTE(liym27): Support an empty tensor array as Input.
// And set the kernel type is float.
if (t_arr.size() == 0) {
return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context().GetPlace());
}
}
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
class AssignInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The input "
"variable "
"could be phi::DenseTensor, SelectedRows or phi::DenseTensorArray.")
.AsDispensable();
AddOutput("Out",
"(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The "
"type of output "
"is the same as input X.");
AddComment(R"DOC(Assign Operator
Out = X, when type in [phi::DenseTensor/SelectedRows/phi::DenseTensorArray]
raise error if the type is not listed above.
)DOC");
}
};
template <typename T>
class AssignGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("X"));
}
};
class AssignCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor input_grad = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&input_grad);
std::string dx_name = this->GetOutputName(input_grad);
VLOG(6) << "Running assign_grad composite func";
prim::assign_grad<prim::DescTensor>(out_grad, dx_ptr);
this->RecoverOutputName(input_grad, dx_name);
}
};
DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(assign,
AssignInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(assign,
ops::AssignOp,
ops::AssignCompositeGradOpMaker,
ops::AssignGradMaker<paddle::framework::OpDesc>,
ops::AssignGradMaker<paddle::imperative::OpBase>,
ops::AssignOpProtoMaker,
ops::AssignOpInplaceInferer,
ops::AssignInferVarType,
AssignInferShapeFunctor);
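For orientation, the hand-written operator above implements a plain identity copy; at the Python level it surfaces as paddle.assign. A minimal usage sketch of the op's semantics (not part of this commit):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.assign(x)   # Out = X: y receives a copy of x's data
print(y.numpy())       # [1. 2. 3.]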
@@ -30,6 +30,20 @@ from type_mapping import (
)
def get_infer_var_type_func(op_name):
    if op_name == "assign":
        return f"""
class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {{
    ctx->SyncTypeAndDataType("X", "Out");
  }}
}};
"""
    else:
        return None


def quote(s):
    return '"{}"'.format(s)
......
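The new filter returns a ready-to-paste C++ class body for ops that need a hand-written VarTypeInference (currently only assign) and None for everything else. A quick sanity check of that contract, assuming the generator's filters module is importable from the code-generation scripts' directory:

# Illustrative only: run from the generator scripts' directory.
from filters import get_infer_var_type_func

print(get_infer_var_type_func("assign"))  # emits the AssignInferVarType class
print(get_infer_var_type_func("relu"))    # None -> the templates skip the block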
@@ -23,6 +23,7 @@ from filters import (
    cartesian_prod_mapping,
    delete_last_underline,
    find_optinal_inputs_name,
    get_infer_var_type_func,
    to_composite_grad_opmaker_name,
    to_input_name,
    to_int_array_tensor_name,
@@ -66,6 +67,7 @@ env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.filters["to_variable_names"] = to_variable_names
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
env.filters["assert_dense_or_sr"] = assert_dense_or_sr
env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name
env.tests["base_op"] = is_base_op
......
@@ -21,6 +21,7 @@ from filters import (
    assert_dense_or_sr,
    cartesian_prod_mapping,
    find_optinal_inputs_name,
    get_infer_var_type_func,
    to_composite_grad_opmaker_name,
    to_input_name,
    to_int_array_tensor_name,
@@ -67,6 +68,7 @@ env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.filters["to_variable_names"] = to_variable_names
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["vec"] = is_vec
......
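Registering the callable under env.filters is what lets the templates below write op["op_name"] | get_infer_var_type_func. A self-contained sketch of that Jinja2 mechanism, using a trimmed stand-in for the real filter:

from jinja2 import Environment

def get_infer_var_type_func(op_name):  # trimmed stand-in for the real filter
    return "class AssignInferVarType { ... };" if op_name == "assign" else None

env = Environment()
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
tpl = env.from_string(
    "{% set s = op_name | get_infer_var_type_func %}"
    "{% if s is not none %}{{ s }}{% endif %}"
)
print(tpl.render(op_name="assign"))  # prints the class body
print(tpl.render(op_name="relu"))    # prints an empty string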
@@ -101,6 +101,24 @@ phi::KernelKey GetReduceGradExpectedKernelType(
  return phi::KernelKey(input_data_type, ctx.GetPlace());
}

phi::KernelKey GetAssignExpectedKernelType(
    const framework::ExecutionContext& ctx,
    const framework::OperatorWithKernel* op_ptr) {
  const framework::Variable* var = ctx.InputVar("X");
  if (var->IsType<framework::LoDTensorArray>()) {
    auto t_arr = var->Get<framework::LoDTensorArray>();
    // NOTE(liym27): Support an empty tensor array as input and fall back
    // to a float32 kernel key in that case.
    if (t_arr.size() == 0) {
      return phi::KernelKey(framework::proto::VarType::FP32,
                            ctx.device_context().GetPlace());
    }
  }
  return phi::KernelKey(
      op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X"),
      ctx.device_context().GetPlace());
}

phi::KernelKey GetSgdExpectedKernelType(
    const framework::ExecutionContext& ctx,
    const framework::OperatorWithKernel* op_ptr) {
......
@@ -28,6 +28,10 @@ phi::KernelKey GetReduceGradExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetAssignExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetSgdExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
......
@@ -438,6 +438,11 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel
{%- endif %}
};
{% set infer_var_type_func_str = op["op_name"] | get_infer_var_type_func %}
{% if infer_var_type_func_str is not none %}
{{infer_var_type_func_str}}
{% endif %}
DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{op["infer_meta"]["func"]}}));
{# inplace inferer #}
@@ -475,6 +480,10 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %}
{% set infer_var_type_func_str = op["op_name"] | get_infer_var_type_func %}
{% if infer_var_type_func_str is not none %}
ops::{{name | to_pascal_case}}InferVarType,
{% endif %}
{% if "backward_composite" in op and op["backward_composite"] is not none %}
ops::{{op["backward_composite"] | to_composite_grad_opmaker_name}},
{% endif %}
......
@@ -15,7 +15,6 @@ register_unity_group(
    argsort_op.cc
    array_to_lod_tensor_op.cc
    assert_op.cc
-    assign_op.cc
    assign_value_op.cc
    attention_lstm_op.cc
    average_accumulates_op.cc
......
@@ -131,6 +131,16 @@
  extra :
    attrs : [bool use_mkldnn = false, bool use_cudnn = false]

- op : assign
  backward : assign_grad
  inputs :
    x : X
  outputs :
    out : Out
  manual_signature : [assign, assign_grad]
  get_expected_kernel_type :
    assign : GetAssignExpectedKernelType

- op : atan
  inputs :
    x : X
......
# This file is to support static ops that differ from their dynamic counterparts.

- backward_op : assign_grad
  forward : assign (Tensor x) -> Tensor(out)
  args : (Tensor out_grad)
  output : Tensor(x_grad)
  composite : assign_grad(out_grad, x_grad)
  invoke : assign(out_grad)

- backward_op : frobenius_norm_grad
  forward : frobenius_norm (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) -> Tensor(out)
  args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1)
......
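Since assign_grad simply invokes assign(out_grad), the gradient of assign is the identity map. A small eager-mode check (illustrative, not part of this commit):

import paddle

x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
y = paddle.assign(x)
y.sum().backward()
print(x.grad)  # all ones: d assign(x) / dx is the identity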
@@ -26,6 +26,17 @@
    func : all_reduce
    param : [x, reduce_type]

- op : assign
  args : (Tensor x)
  output : Tensor
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : assign
  optional : x
  inplace : (x -> out)
  backward : assign_grad

- op : broadcast
  args : (Tensor x, int ring_id = 0, int root = 0)
  output : Tensor(out)
......
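The optional : x and inplace : (x -> out) entries mirror the Python-level behavior, where assign may write into an existing variable via its output argument. A brief sketch (again only illustrating the public API, not this commit's code):

import paddle

x = paddle.ones([2])
out = paddle.zeros([2])
paddle.assign(x, output=out)  # writes x's data into out, matching (x -> out)
print(out.numpy())            # [1. 1.]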
@@ -122,7 +122,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array,
                           CPU,
                           ALL_LAYOUT,
                           phi::AssignArrayKernel<phi::CPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}

PD_REGISTER_KERNEL(assign_value,
                   CPU,
                   ALL_LAYOUT,
@@ -146,7 +148,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array,
                           GPU,
                           ALL_LAYOUT,
                           phi::AssignArrayKernel<phi::GPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}

PD_REGISTER_KERNEL(assign_value,
                   GPU,
                   ALL_LAYOUT,
@@ -171,7 +175,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array,
                           XPU,
                           ALL_LAYOUT,
                           phi::AssignArrayKernel<phi::XPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}

PD_REGISTER_KERNEL(assign_value,
                   XPU,
                   ALL_LAYOUT,
......
@@ -38,14 +38,18 @@ PD_REGISTER_GENERAL_KERNEL(assign_sr,
                           CPU,
                           ALL_LAYOUT,
                           phi::sr::AssignKernel<phi::CPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(assign_sr,
                           GPU,
                           ALL_LAYOUT,
                           phi::sr::AssignKernel<phi::GPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
#endif
#ifdef PADDLE_WITH_XPU
@@ -53,5 +57,7 @@ PD_REGISTER_GENERAL_KERNEL(assign_sr,
                           XPU,
                           ALL_LAYOUT,
                           phi::sr::AssignKernel<phi::XPUContext>,
-                          ALL_DTYPE) {}
+                          ALL_DTYPE) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
#endif
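The hook body added to each registration marks the first input as acceptable on any backend, so kernel selection does not force a device transfer just to run assign. A rough Python-level illustration (hedged: the hook itself only affects kernel selection and is not directly visible in the API):

import paddle

# Illustrative: assign accepts its input wherever it already lives.
if paddle.device.is_compiled_with_cuda():
    x = paddle.to_tensor([1.0], place=paddle.CUDAPlace(0))
else:
    x = paddle.to_tensor([1.0])
y = paddle.assign(x)
print(y.place)  # same placement as x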