Unverified commit d8845735 authored by HappyHeavyRain, committed by GitHub

Support bw invoke fw (#50260)

* support bw invoke fw

* fix scale in static_backward.yaml

* fix the bug in tensorrt/convert

* move 'scale','sign' into ops.yaml

* add scale_grad of scale in op_compat.yaml

* change generated_static_op in CMakeLists.txt
Parent 9af23f1d
......@@ -77,6 +77,6 @@ cc_library(
op_registry
variable_helper
memcpy
scale_op
generated_op
autograd_meta
hook_utils)
......@@ -7,7 +7,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
${generated_deps}
eager_scale
scale_node
scale_op
generated_op
matmul_v2_op
dygraph_function
eager_prim_api)
......
......@@ -1051,7 +1051,7 @@ if(WITH_PSCORE)
heter_pipeline_trainer_test
SRCS heter_pipeline_trainer_test.cc
DEPS conditional_block_op
scale_op
generated_op
heter_listen_and_serv_op
executor
heter_server
......@@ -1068,7 +1068,7 @@ if(WITH_PSCORE)
heter_pipeline_trainer_test
SRCS heter_pipeline_trainer_test.cc
DEPS conditional_block_op
scale_op
generated_op
heter_listen_and_serv_op
executor
heter_server
......
......@@ -76,5 +76,5 @@ cc_library(
cc_test(
test_reference_count_pass_last_lived_ops
SRCS test_reference_count_pass_last_lived_ops.cc
DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op
DEPS parallel_executor elementwise_mul_op elementwise_add_op generated_op
eigen_function)
......@@ -65,7 +65,7 @@ if(WITH_TESTING AND NOT WIN32)
reduce_mean_op
feed_op
fetch_op
scale_op
generated_op
transfer_layout_op
jit_layer)
cc_test(
......
......@@ -11,6 +11,8 @@ set(legacy_bw_op_yaml_file
set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml)
set(sparse_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml)
set(static_bw_op_yaml_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/static_backward.yaml)
if(NOT PYTHONINTERP_FOUND)
find_package(PythonInterp REQUIRED)
......@@ -66,6 +68,9 @@ execute_process(
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file}
--output_path ./parsed_ops/sparse_backward.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${static_bw_op_yaml_file}
--output_path ./parsed_ops/static_backward.parsed.yaml --backward
RESULTS_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
......@@ -82,14 +87,24 @@ execute_process(
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml
./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths
./parsed_ops/backward_ops.parsed.yaml
--backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml
./parsed_ops/legacy_backward_ops.parsed.yaml
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths
./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths
./parsed_ops/static_backward.parsed.yaml
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "static ops validation failed, exiting.")
endif()
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
......@@ -124,8 +139,9 @@ endif()
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
COMMAND
${PYTHON_EXECUTABLE} generate_static_op.py --ops_yaml_path
./parsed_ops/static_ops.parsed.yaml --op_version_yaml_path
${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path
./parsed_ops/static_ops.parsed.yaml --backward_yaml_path
./parsed_ops/static_backward.parsed.yaml --op_version_yaml_path
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
--output_op_path "${generated_static_op_path}.tmp" --output_arg_map_path
......
......@@ -407,6 +407,7 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
invoke_op = bw_op['invoke']['func']
args_list = bw_op['invoke']['args']
args_index = 0
# backward invoke forward
if invoke_op in forward_op_dict:
reuse_op = forward_op_dict[invoke_op]
bw_op['invoke']['func'] = reuse_op['op_name']
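For orientation, here is a minimal, self-contained sketch of the "backward invokes forward" resolution this hunk extends; the dict shapes below are simplified assumptions for illustration only, not the generator's real parsed-YAML schema:

# Simplified sketch: resolve a backward op whose `invoke` names a forward op.
forward_op_dict = {
    "scale": {"op_name": "scale", "attrs": ["scale", "bias", "bias_after_scale"]},
}
backward_op_dict = {
    "sign_grad": {"invoke": {"func": "scale", "args": "out_grad, 0.0f, 0.0f, true"}},
}

def process_invoke_op(forward_op_dict, backward_op_dict):
    for bw_op in backward_op_dict.values():
        if "invoke" not in bw_op:
            continue
        invoke_op = bw_op["invoke"]["func"]
        # backward invoke forward: reuse the forward op's registered name
        if invoke_op in forward_op_dict:
            bw_op["invoke"]["func"] = forward_op_dict[invoke_op]["op_name"]

process_invoke_op(forward_op_dict, backward_op_dict)
print(backward_op_dict["sign_grad"]["invoke"]["func"])  # -> "scale"

After resolution the backward op's invoke target carries the forward op's registered name, which the later generation steps can reuse directly.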
......@@ -460,17 +461,16 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
for bw_name in op_op['backward'].split(',')
]
for bw_name in bw_names:
assert (
bw_name in bw_op_dict
), f"backward {bw_name} is not existed"
for out_grad in op_op['drop_empty_grad']:
assert (
out_grad in bw_op_dict[bw_name]['output_dict']
), f'''
{bw_name} with {out_grad} is not existed in output_dict '''
bw_op_dict[bw_name]['output_dict'][out_grad][
'drop_empty_grad'
] = False
# static_ops.yaml and ops.yaml use the common op_compat.yaml
if bw_name in bw_op_dict:
for out_grad in op_op['drop_empty_grad']:
assert (
out_grad in bw_op_dict[bw_name]['output_dict']
), f'''
{bw_name} with {out_grad} does not exist in output_dict '''
bw_op_dict[bw_name]['output_dict'][out_grad][
'drop_empty_grad'
] = False
def main(
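A hedged sketch of the relaxed drop_empty_grad handling above (the entry names and dict layout here are illustrative only): because static_ops.yaml and ops.yaml now share a single op_compat.yaml, a backward name listed there may belong to the other YAML family, so it is skipped rather than asserted:

def apply_drop_empty_grad(op_compat_entry, bw_op_dict):
    # op_compat_entry layout is an assumption made only for this illustration.
    bw_names = [name.strip() for name in op_compat_entry["backward"].split(",")]
    for bw_name in bw_names:
        # static_ops.yaml and ops.yaml use the common op_compat.yaml, so a
        # backward op mentioned here may not exist in this dict: skip it.
        if bw_name not in bw_op_dict:
            continue
        for out_grad in op_compat_entry["drop_empty_grad"]:
            bw_op_dict[bw_name]["output_dict"][out_grad]["drop_empty_grad"] = False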
......@@ -493,7 +493,8 @@ def main(
op_versions = yaml.safe_load(f)
# add op version info into op
for op_version in op_versions:
forward_op_dict[op_version['op']]['version'] = op_version['version']
if op_version['op'] in forward_op_dict:
forward_op_dict[op_version['op']]['version'] = op_version['version']
with open(op_compat_yaml_path, "rt") as f:
op_fluid_map_list = yaml.safe_load(f)
......
......@@ -81,6 +81,7 @@ def restruct_io(op):
def main(
ops_yaml_path,
backward_yaml_path,
op_compat_yaml_path,
op_version_yaml_path,
output_op_path,
......@@ -91,6 +92,11 @@ def main(
ops = [restruct_io(op) for op in ops]
forward_op_dict = to_named_dict(ops)
with open(backward_yaml_path, "rt") as f:
backward_ops = yaml.safe_load(f)
backward_ops = [restruct_io(op) for op in backward_ops]
backward_op_dict = to_named_dict(backward_ops)
with open(op_version_yaml_path, "rt") as f:
op_versions = yaml.safe_load(f)
......@@ -139,6 +145,11 @@ if __name__ == "__main__":
parser.add_argument(
'--ops_yaml_path', type=str, help="parsed static ops yaml file."
)
parser.add_argument(
'--backward_yaml_path',
type=str,
help="parsed static backward ops yaml file.",
)
parser.add_argument(
'--op_compat_yaml_path', type=str, help="ops args compat yaml file."
)
......@@ -157,6 +168,7 @@ if __name__ == "__main__":
args = parser.parse_args()
main(
args.ops_yaml_path,
args.backward_yaml_path,
args.op_compat_yaml_path,
args.op_version_yaml_path,
args.output_op_path,
......
......@@ -522,6 +522,31 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
true)}});
{% endfor %}
{% for attr in invoke_op["attrs"] %}
{% set attr_name = attr["fluid_name"] %}
{% set fw_attrs = forward_op["attrs"] %}
{% if attr_name in forward_attr_names %}
{# invoke_op's attrs and fw_attr's attrs must be the same#}
{% set fw_attr = fw_attrs[loop.index0] %}
{% if fw_attr["typename"] == "IntArray" %}
{% if 'tensor_name' in attr or 'manual_flag' not in attr %}
if (this->HasInput("{{fw_attr | to_int_array_tensor_name}}")) {
grad_op->SetInput("{{fw_attr | to_int_array_tensor_name}}", this->Input("{{fw_attr | to_int_array_tensor_name}}"));
}
{% endif %}
{% if 'tensors_name' in fw_attr or 'manual_flag' not in fw_attr %}
if (this->HasInput("{{fw_attr | to_int_array_tensors_name}}")) {
grad_op->SetInput("{{fw_attr | to_int_array_tensors_name}}", this->Input("{{fw_attr | to_int_array_tensors_name}}"));
}
{% endif %}
{% elif fw_attr["typename"] == "Scalar" %}
if (this->HasInput("{{fw_attr | to_scalar_tensor_name}}")) {
grad_op->SetInput("{{fw_attr | to_scalar_tensor_name}}", this->Input("{{fw_attr | to_scalar_tensor_name}}"));
}
{% endif %}
{% endif %}
{% endfor %}
{% for attr in invoke_op["attrs"] %}
grad_op->SetAttr("{{attr["fluid_name"]}}", {{attr["value"]}});
{% endfor %}
......
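To see roughly what the new Scalar branch of this template emits, here is a standalone sketch using jinja2; the to_scalar_tensor_name filter below is a stand-in assumption (the real filter lives in the generator's own filter utilities) defined only so the snippet runs by itself:

from jinja2 import Environment

def to_scalar_tensor_name(attr):
    # Stand-in for the generator's real filter; naming is assumed for illustration.
    return attr["fluid_name"].capitalize() + "Tensor"

env = Environment()
env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name

template = env.from_string(
    'if (this->HasInput("{{ attr | to_scalar_tensor_name }}")) {\n'
    '  grad_op->SetInput("{{ attr | to_scalar_tensor_name }}",\n'
    '                    this->Input("{{ attr | to_scalar_tensor_name }}"));\n'
    '}'
)

# For scale's Scalar attribute this prints the C++ that forwards ScaleTensor
# from the forward op's inputs to the generated grad op.
print(template.render(attr={"fluid_name": "scale", "typename": "Scalar"}))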
......@@ -86,7 +86,7 @@ cc_test_old(
executor
scope
proto_desc
scale_op
generated_op
eigen_function)
set_source_files_properties(
......@@ -100,7 +100,7 @@ cc_test_old(
executor
scope
proto_desc
scale_op
generated_op
send_and_recv_op
${RPC_DEPS}
${DISTRIBUTE_DEPS}
......@@ -117,7 +117,7 @@ cc_test_old(
executor
scope
proto_desc
scale_op
generated_op
send_and_recv_op
${RPC_DEPS}
${DISTRIBUTE_DEPS}
......@@ -134,14 +134,14 @@ cc_test_old(
executor
scope
proto_desc
scale_op
generated_op
heter_listen_and_serv_op
${RPC_DEPS}
${DISTRIBUTE_DEPS}
eigen_function)
#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc generated_static_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
set_source_files_properties(
switch_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......@@ -153,7 +153,7 @@ cc_binary(
executor
scope
proto_desc
scale_op
generated_op
heter_listen_and_serv_op
${RPC_DEPS}
${DISTRIBUTE_DEPS}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class ScaleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddInput("ScaleTensor",
"(Tensor) If provided, use this as "
"scale factor, this has a higher priority than "
"attr(scale), the shape of this tensor MUST BE 1.")
.AsDispensable();
AddOutput("Out", "(Tensor) Output tensor of scale operator.");
AddComment(R"DOC(
**Scale operator**
Apply scaling and bias addition to the input tensor.
if bias_after_scale=True:
$$Out = scale*X + bias$$
else:
$$Out = scale*(X + bias)$$
)DOC");
AddAttr<float>("scale", "The scaling factor of the scale operator.")
.SetDefault(1.0);
AddAttr<float>("bias", "The bias of the scale operator.").SetDefault(0.0);
AddAttr<bool>(
"bias_after_scale",
"Apply bias addition after or before scaling. It is useful for "
"numeric stability in some circumstances.")
.SetDefault(true);
}
};
class ScaleOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
template <typename T>
class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("scale");
grad_op->SetInput("X", this->OutputGrad("Out"));
if (this->HasInput("ScaleTensor") > 0) {
grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor"));
}
grad_op->SetOutput("Out", this->InputGrad("X"));
VLOG(6) << "Finish SetOutput";
grad_op->SetAttr("scale", this->GetAttr("scale"));
VLOG(6) << "Finish Set Attr scale";
grad_op->SetAttr("bias", 0.0f);
VLOG(6) << "Finish Set Attr bias";
grad_op->SetAttr("bias_after_scale", true);
VLOG(6) << "Finish Set Attr bias_after_scale";
VLOG(6) << "Finish Apply";
}
};
DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(scale,
ScaleInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(scale,
ops::ScaleOp,
ops::ScaleOpMaker,
ops::ScaleGradMaker<paddle::framework::OpDesc>,
ops::ScaleGradMaker<paddle::imperative::OpBase>,
ScaleInferShapeFunctor,
ops::ScaleOpVarTypeInference,
ops::ScaleOpInplaceInferer);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class SignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC(
Sign operator
$$Out = X.sign()$$
)DOC");
}
};
template <typename T>
class SignGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("scale");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("scale", 0.0f);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(sign,
SignInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(sign,
ops::SignOp,
ops::SignOpMaker<float>,
ops::SignGradMaker<paddle::framework::OpDesc>,
ops::SignGradMaker<paddle::imperative::OpBase>,
SignInferShapeFunctor);
......@@ -29,7 +29,6 @@ cc_test_old(
prim_utils
operator
elementwise_mul_op
scale_op
activation_op
phi_api
phi_dygraph_api
......
......@@ -1103,6 +1103,12 @@
backward : rsqrt_double_grad
inplace : (out_grad -> x_grad)
- backward_op : scale_grad
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale=1.0)
output : Tensor(x_grad)
invoke : scale(out_grad, scale, 0.0f, true)
- backward_op : scatter_grad
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out)
args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite)
......@@ -1207,6 +1213,12 @@
optional : grad_grad_out_grad
inplace : (grad_grad_x -> fwd_grad_out_grad)
- backward_op : sign_grad
forward : sign (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : scale(out_grad, 0.0f, 0.0f, true)
- backward_op : silu_grad
forward : silu (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -1144,12 +1144,6 @@
func : rrelu_grad
data_type : x
- backward_op : scale_grad
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale=1.0, bool bias_after_scale=true)
output : Tensor(x_grad)
invoke : scale(out_grad, scale, 0.0, bias_after_scale)
- backward_op : segment_pool_grad
forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids)
args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype)
......@@ -1173,12 +1167,6 @@
func : sigmoid_cross_entropy_with_logits_grad
inplace : (out_grad -> x_grad)
- backward_op : sign_grad
forward : sign (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
invoke : scale(out_grad, 0.0, 0.0, true)
- backward_op : slice_double_grad
forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input)
args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
......
......@@ -1559,18 +1559,6 @@
intermediate : noise
backward : rrelu_grad
- op : scale
args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : scale {dense -> dense},
scale_sr {selected_rows -> selected_rows}
inplace : (x -> out)
backward : scale_grad
- op : segment_pool
args : (Tensor x, Tensor segment_ids, str pooltype)
output : Tensor(out), Tensor(summed_ids)
......@@ -1616,15 +1604,6 @@
func : sigmoid_cross_entropy_with_logits
backward : sigmoid_cross_entropy_with_logits_grad
- op : sign
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sign
backward : sign_grad
- op : slice
args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
......
......@@ -1326,10 +1326,15 @@
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : scale
backward : scale_grad
inputs :
x : X
outputs :
out: Out
out : Out
scalar :
scale :
data_type : float
tensor_name : ScaleTensor
extra :
attrs : [bool use_mkldnn = false]
......@@ -1425,6 +1430,13 @@
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : sign
backward : sign_grad
inputs :
x : X
outputs :
out : Out
- op : silu
backward : silu_grad
inputs :
......
......@@ -1022,6 +1022,19 @@
inplace : (x -> out)
backward : rsqrt_grad
- op : scale
args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : scale {dense -> dense},
scale_sr {selected_rows -> selected_rows}
data_type : x
inplace : (x -> out)
backward : scale_grad
- op : scatter
args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true)
output : Tensor(out)
......@@ -1111,6 +1124,15 @@
func : sigmoid
backward : sigmoid_grad
- op : sign
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : sign
backward : sign_grad
- op : silu
args : (Tensor x)
output : Tensor
......
# This file is to support those static ops that differ from the dynamic ones.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
/**
* Note [ Why does the ArgumentMapping function need to be so complicated? ]
*
* In order to meet the requirements of infrt, the function used to match Op
* and Kernel parameters needs to be placed in phi as a compatible component
* and must not depend on fluid.
*
* Because infrt not only needs to dynamically call this argument mapping
* function at runtime, but also needs to statically declare all possible
* results of the function before running without any information.
*
* The infrt declare like:
*
* def PDKEL_Reshape_to_CPU : Pat<
* (PD_ReshapeOp $x, $shape_tensor, $shape_attr), // OpMaker arguments
* (PDKEL_ReshapeKernelAttr $x, fn($shape_attr)>; // Kernel arguments
* def PDKEL_Reshape_to_CPU : Pat<
* (PD_ReshapeOp $x, $shape_tensor, $shape_attr),
* (PDKEL_ReshapeKernelAttr $x, fn($shape_tensor)>;
*
* Therefore, we need to write out each result of the argument mapping function,
* like `KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"})`; it
* cannot contain variables and can only contain const char* strings.
*
* Infrt will parse all results before running for the generation of the above
* static declare, which leads to some functions being written in a long way,
* and the complicated ones may have hundreds of lines, which has certain side
* effects on the programming experience.
*/
KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
if (ctx.HasInput("ScaleTensor")) {
return KernelSignature(
"scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"});
} else {
return KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
}
} else if (ctx.IsSelectedRowsInput("X")) {
if (ctx.HasInput("ScaleTensor")) {
return KernelSignature("scale_sr",
{"X"},
{"ScaleTensor", "bias", "bias_after_scale"},
{"Out"});
} else {
return KernelSignature(
"scale_sr", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
}
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
// op_type, api_name, arg_mapping_fn
PD_REGISTER_ARG_MAPPING_FN(scale, phi::ScaleOpArgumentMapping);